#include "emu.h"
#include "msr-index.h"
#include "cpufeature.h"

/*
 * Execute the host CPUID instruction for leaf entry->function and store
 * the resulting eax/ebx/ecx/edx values back into *entry.
 *
 * NOTE(review): ecx is listed only as an output operand, so the sub-leaf
 * CPUID sees is whatever ecx happens to contain at this point.  Leaves
 * whose output depends on the ecx sub-leaf (per the Intel SDM, e.g. 4, 7,
 * 0xb, 0xd) therefore return unpredictable data, and this interface gives
 * callers no way to choose a sub-leaf.
 */
void real_cpuid(struct kvm_cpuid_entry *entry)
{
    asm volatile("cpuid"
		 : "=a" (entry->eax),
		   "=b" (entry->ebx),
		   "=c" (entry->ecx),
		   "=d" (entry->edx)
		 : "a" (entry->function));
}

static void emulate_cpuid(struct regs *regs)
{
    struct kvm_cpuid_entry entry;

    entry.function = regs->rax;
    real_cpuid(&entry);
    regs->rax = entry.eax;
    regs->rbx = entry.ebx;
    regs->rcx = entry.ecx;
    regs->rdx = entry.edx;
    printk(2, "cpuid 0x%08x: eax 0x%08x ebx 0x%08x ecx 0x%08x edx 0x%08x\n",
	   entry.function, entry.eax, entry.ebx, entry.ecx, entry.edx);
}

/*
 * Emulate RDMSR for the guest.  A small whitelist of MSRs (EFER and the
 * FS/GS/kernel-GS bases) is read from the host.  Any other MSR is logged
 * and reads as zero.  The result is returned in edx:eax.
 */
static void emulate_rdmsr(struct regs *regs)
{
    uint32_t lo = 0;
    uint32_t hi = 0;

    switch (regs->rcx) {
    case MSR_EFER:
    case MSR_FS_BASE:
    case MSR_GS_BASE:
    case MSR_KERNEL_GS_BASE:
        /* white listed: pass the host value through */
        rdmsr(regs->rcx, &lo, &hi);
        break;
    default:
        /* not white listed: log it, lo/hi stay zero */
        printk(1, "%s: ignore: rcx 0x%" PRIxREG "\n", __FUNCTION__, regs->rcx);
        break;
    }

    regs->rax = lo;
    regs->rdx = hi;
}

/*
 * Emulate WRMSR for the guest.
 *
 * EFER writes are checked before they reach the host:
 *  - only NX/LMA/LME/SCE may be set, and the upper 32 bits (edx) must be
 *    zero;
 *  - LMA/LME/SCE must keep their current host value (only NX may change).
 * Writes that pass are forwarded, as are writes to the FS/GS/kernel-GS base
 * MSRs.  Everything else (including rejected EFER writes) is logged and
 * dropped.
 *
 * WRMSR only consumes EDX:EAX, so the upper halves of rax/rdx are ignored
 * here as well.  Otherwise stale high bits in rax would make a valid EFER
 * value look invalid.
 */
static void emulate_wrmsr(struct regs *regs)
{
    static const uint64_t known = (EFER_NX|EFER_LMA|EFER_LME|EFER_SCE);
    static const uint64_t fixed = (EFER_LMA|EFER_LME|EFER_SCE);
    const uint32_t lo = (uint32_t)regs->rax;
    const uint32_t hi = (uint32_t)regs->rdx;
    uint32_t ax, dx;

    switch (regs->rcx) {
    case MSR_EFER:
        /* every EFER bit in edx is outside the known set */
        if ((lo & ~known) || hi) {
            printk(1, "%s: efer: unknown bit set\n", __FUNCTION__);
            goto out;
        }

        rdmsr(regs->rcx, &ax, &dx);
        if ((lo & fixed) != (ax & fixed)) {
            printk(1, "%s: efer: modify fixed bit\n", __FUNCTION__);
            goto out;
        }

        printk(1, "%s: efer:%s%s%s%s\n", __FUNCTION__,
               lo & EFER_SCE ? " sce" : "",
               lo & EFER_LME ? " lme" : "",
               lo & EFER_LMA ? " lma" : "",
               lo & EFER_NX  ? " nx"  : "");
        /* fall through */
    case MSR_FS_BASE:
    case MSR_GS_BASE:
    case MSR_KERNEL_GS_BASE:
        wrmsr(regs->rcx, lo, hi);
        return;
    default:
        break;
    }

out:
    printk(1, "%s: ignore: 0x%" PRIxREG " 0x%" PRIxREG ":0x%" PRIxREG "\n",
           __FUNCTION__, regs->rcx, regs->rdx, regs->rax);
}

/*
 * Log the address of an instruction and its first eight bytes, tagged
 * with @prefix, at printk level @level.  The caller must ensure that at
 * least eight bytes are readable at @instr.
 */
void print_emu_instr(int level, const char *prefix, uint8_t *instr)
{
    const uint8_t *b = instr;

    printk(level,
           "%s: rip %p bytes %02x %02x %02x %02x  %02x %02x %02x %02x\n",
           prefix, instr,
           b[0], b[1], b[2], b[3],
           b[4], b[5], b[6], b[7]);
}

/*
 * Map a 3-bit register number from a ModRM byte to the matching slot in
 * the guest register frame.
 *
 * @modrm: the ModRM byte
 * @rm:    non-zero to decode the rm field (bits 0-2), zero to decode the
 *         reg field (bits 3-5)
 *
 * Only the eight legacy GPRs are reachable: REX-extended registers are not
 * decoded.
 */
static ureg_t *decode_reg(struct regs *regs, uint8_t modrm, int rm)
{
    /* index order follows the x86 register encoding */
    ureg_t *const gprs[8] = {
        (ureg_t*)&regs->rax, (ureg_t*)&regs->rcx,
        (ureg_t*)&regs->rdx, (ureg_t*)&regs->rbx,
        (ureg_t*)&regs->rsp, (ureg_t*)&regs->rbp,
        (ureg_t*)&regs->rsi, (ureg_t*)&regs->rdi,
    };
    unsigned int field = rm ? modrm : (unsigned int)(modrm >> 3);

    return gprs[field & 0x07];
}

/*
 * Append "<a><b><c>" to the NUL-terminated string of length @pos in @buf
 * (capacity @size).  Output that does not fit is truncated.  Returns the
 * new string length, which is always < @size, so callers can keep
 * appending safely.
 */
static size_t print_bits_append(char *buf, size_t size, size_t pos,
                                const char *a, const char *b, const char *c)
{
    int n;

    if (pos + 1 >= size)
        return pos;  /* already full */
    n = snprintf(buf + pos, size - pos, "%s%s%s", a, b, c);
    if (n < 0)
        return pos;
    if ((size_t)n >= size - pos)
        return size - 1;  /* truncated: snprintf filled the remaining space */
    return pos + (size_t)n;
}

/*
 * Log, as a single line, which of the 32 bits are set in @old and/or @new.
 * Each bit set in either word is printed by its name from @names:
 *   "name"  - set in both,
 *   "+name" - newly set in @new,
 *   "-name" - cleared in @new.
 * A NULL entry in @names prints as "???".  @names must have 32 entries.
 * Over-long output is truncated, but the line is always newline-terminated.
 */
void print_bits(int level, char *msg, uint32_t old, uint32_t new, const char *names[])
{
    char buf[128];
    /* keep one byte spare so the trailing newline always fits */
    const size_t room = sizeof(buf) - 1;
    size_t pos;
    int i;

    pos = print_bits_append(buf, room, 0, msg, ":", "");
    for (i = 0; i < 32; i++) {
        /* unsigned shift: 1 << 31 on a signed int is undefined behaviour */
        const uint32_t mask = (uint32_t)1 << i;
        const char *mod;

        if (new & mask)
            mod = (old & mask) ? "" : "+";  /* present / added */
        else if (old & mask)
            mod = "-";                      /* removed */
        else
            continue;                       /* not present in either */

        pos = print_bits_append(buf, room, pos, " ", mod,
                                names[i] ? names[i] : "???");
    }
    buf[pos++] = '\n';
    buf[pos] = '\0';
    printk(level, "%s", buf);
}

/*
 * Emulate "mov %crN,%reg" (0f 20 /r).  The ModRM reg field selects the
 * control register and the rm field selects the destination GPR.
 * cr0, cr3 and cr4 are supported; cr3 is rebuilt from read_cr3_mfn().
 * Returns 0 when emulated, -1 for any other control register.
 */
static int emulate_mov_from_cr(struct xen_cpu *cpu, struct regs *regs,
                               uint8_t modrm)
{
    ureg_t *reg = decode_reg(regs, modrm, 1);

    switch ((modrm >> 3) & 0x07) {
    case 0:
        *reg = read_cr0();
        return 0;
    case 3:
        *reg = frame_to_addr(read_cr3_mfn(cpu));
        return 0;
    case 4:
        *reg = read_cr4();
        return 0;
    default:
        return -1;
    }
}

/*
 * Emulate "mov %reg,%crN" (0f 22 /r) for cr0 and cr4.
 * An update that would change a bit the guest may not touch is logged and
 * dropped; the instruction still counts as emulated.
 * Returns 0 when handled, -1 for any other control register.
 */
static int emulate_mov_to_cr(struct regs *regs, uint8_t modrm)
{
    /* cr0: every bit except TS must keep its current value */
    static const ureg_t cr0_fixed = ~(X86_CR0_TS);
    /* cr4: TSD must keep its current value, other bits may change */
    static const ureg_t cr4_fixed = X86_CR4_TSD;
    ureg_t *reg = decode_reg(regs, modrm, 1);
    ureg_t cr;

    switch ((modrm >> 3) & 0x07) {
    case 0:
        cr = read_cr0();
        if (cr != *reg) {
            if ((cr & cr0_fixed) == (*reg & cr0_fixed)) {
                print_bits(2, "apply cr0 update", cr, *reg, cr0_bits);
                write_cr0(*reg);
            } else {
                print_bits(1, "IGNORE cr0 update", cr, *reg, cr0_bits);
            }
        }
        return 0;
    case 4:
        cr = read_cr4();
        if (cr != *reg) {
            if ((cr & cr4_fixed) == (*reg & cr4_fixed)) {
                print_bits(1, "apply cr4 update", cr, *reg, cr4_bits);
                write_cr4(*reg);
            } else {
                print_bits(1, "IGNORE cr4 update", cr, *reg, cr4_bits);
            }
        }
        return 0;
    default:
        return -1;
    }
}

/*
 * Store the result of an emulated "in" of @size bytes (1, 2 or 4).
 * The port number is never looked at, so every read yields all ones.
 * The value is written to al / ax / eax the way the hardware would.
 */
static void emulate_in_result(struct regs *regs, int size)
{
    switch (size) {
    case 1:
        regs->rax |= 0xff;
        break;
    case 2:
        regs->rax |= 0xffff;
        break;
    case 4:
        /* a 32-bit register write zero-extends into the full register */
        regs->rax = 0xffffffff;
        break;
    }
}

/*
 * Emulate the privileged instruction at regs->rip.
 *
 * Handled:
 *  - an optional 0x66 operand-size prefix;
 *  - clts, wbinvd, mov to/from cr0/cr3/cr4, wrmsr, rdmsr, cpuid;
 *  - in/out (reads return all ones, writes are dropped);
 *  - cli/sti.
 * The xen emulation marker "ud2a 'x' 'e' 'n'" is skipped (rip is advanced
 * past it) and the instruction that follows is emulated.
 *
 * Returns the length in bytes (prefix included) of the emulated
 * instruction, which does not advance rip itself.  Returns 0 for a plain
 * ud2a, which should be bounced to the guest.  Returns -1 for an
 * instruction that cannot be emulated.
 */
int emulate(struct xen_cpu *cpu, struct regs *regs)
{
    static const uint8_t xen_emu_prefix[5] = {0x0f, 0x0b, 'x','e','n'};
    uint8_t *instr;
    uint8_t *op;
    int prefix;    /* number of prefix bytes in front of the opcode */
    int len;       /* opcode + operand bytes consumed; 0 = not emulated */
    int opsize16;  /* 0x66 operand-size override present */
    int in_size;   /* non-zero: "in" instruction of this many bytes */

restart:
    instr = (void*)regs->rip;
    prefix = 0;
    len = 0;
    opsize16 = 0;
    in_size = 0;

    /* prefixes */
    if (instr[prefix] == 0x66) {
        opsize16 = 1;
        prefix++;
    }
    op = instr + prefix;

    /* instructions */
    switch (op[0]) {
    case 0x0f:
        switch (op[1]) {
        case 0x06:
            /* clts */
            clts();
            len = 2;
            break;
        case 0x09:
            /* wbinvd */
            __asm__("wbinvd" ::: "memory");
            len = 2;
            break;
        case 0x0b:
            /* ud2a */
            if (xen_emu_prefix[2] == op[2] &&
                xen_emu_prefix[3] == op[3] &&
                xen_emu_prefix[4] == op[4]) {
                printk(2, "%s: xen emu prefix\n", __FUNCTION__);
                /* step over the marker (and any prefix on it), decode anew */
                regs->rip += prefix + sizeof(xen_emu_prefix);
                goto restart;
            }
            printk(1, "%s: ud2a -- linux kernel BUG()?\n", __FUNCTION__);
            /* bounce to guest, hoping it prints more info */
            return 0;
        case 0x20:
            /* read control registers */
            if (emulate_mov_from_cr(cpu, regs, op[2]) == 0)
                len = 3;
            break;
        case 0x22:
            /* write control registers */
            if (emulate_mov_to_cr(regs, op[2]) == 0)
                len = 3;
            break;
        case 0x30:
            /* wrmsr */
            emulate_wrmsr(regs);
            len = 2;
            break;
        case 0x32:
            /* rdmsr */
            emulate_rdmsr(regs);
            len = 2;
            break;
        case 0xa2:
            /* cpuid */
            emulate_cpuid(regs);
            len = 2;
            break;
        }
        break;

    /* I/O: the port (imm8 or %dx) is not needed, only the length */
    case 0xe4: /* in     <next byte>,%al */
    case 0xe5: /* in     <next byte>,%ax/%eax */
        in_size = (op[0] & 1) ? (opsize16 ? 2 : 4) : 1;
        len = 2;
        break;
    case 0xec: /* in     (%dx),%al */
    case 0xed: /* in     (%dx),%ax/%eax */
        in_size = (op[0] & 1) ? (opsize16 ? 2 : 4) : 1;
        len = 1;
        break;
    case 0xe6: /* out    %al,<next byte> */
    case 0xe7:
        len = 2;
        break;
    case 0xee: /* out    %al,(%dx) */
    case 0xef:
        len = 1;
        break;

    case 0xfa:
        /* cli */
        guest_cli(cpu);
        len = 1;
        break;
    case 0xfb:
        /* sti */
        guest_sti(cpu);
        len = 1;
        break;
    }

    /* unknown instruction (a bare prefix does not count as emulated) */
    if (!len) {
        print_emu_instr(0, "instr emu failed", instr);
        return -1;
    }

    if (in_size)
        emulate_in_result(regs, in_size);

    return prefix + len;
}
