#include <inttypes.h>
#include <xen/xen.h>

#include "emu.h"
#include "emu-mm.c"

/* --------------------------------------------------------------------- */

/*
 * The MAPS_MAX refcounted map slots are split into MAPS_R_COUNT equally
 * sized ranges.  map_32()/map_pae() pick range = mfn & MAPS_R_MASK, so a
 * given machine frame is only ever looked up within one range.
 * NOTE(review): assumes MAPS_MAX is a multiple of MAPS_R_COUNT — confirm.
 */
#define MAPS_R_BITS        4
#define MAPS_R_COUNT       (1 << MAPS_R_BITS)
#define MAPS_R_MASK        (MAPS_R_COUNT - 1)
#define MAPS_R_SIZE        (MAPS_MAX / MAPS_R_COUNT)
/* first slot of range r (inclusive) */
#define MAPS_R_LOW(r)      (MAPS_R_SIZE * (r))
/* one past the last slot of range r (exclusive) */
#define MAPS_R_HIGH(r)     (MAPS_R_SIZE * (r) + MAPS_R_SIZE)
/* per-range rotating cursor used by find_slot(); zero-initialized */
static int maps_next[MAPS_R_COUNT];

/* held by map_32()/map_pae()/free_page() around slot lookup and refcounting */
static spinlock_t maplock = SPIN_LOCK_UNLOCKED;

/* --------------------------------------------------------------------- */

/*
 * Pick a free (refcount zero) map slot inside the window belonging to
 * @range, i.e. [MAPS_R_LOW(range), MAPS_R_HIGH(range)).
 *
 * maps_next[range] is a rotating cursor, so free slots are handed out
 * round-robin.  A slot with refcount zero may still hold a present
 * mapping that mfn_to_slot_*() can hit, so rotating delays its reuse.
 * Callers in this file invoke it with maplock held.
 *
 * Returns the slot index, or -1 if every slot of the range is in use.
 */
static int find_slot(int range)
{
    int low   = MAPS_R_LOW(range);
    int high  = MAPS_R_HIGH(range);
    int *next = maps_next + range;
    int start;
    int slot;

    /*
     * maps_next[] starts out zeroed, so for every range but range 0 the
     * cursor initially lies outside its own window.  Without clamping we
     * would hand out slots of other ranges (which mfn_to_slot_*() never
     * scans for this range).  We would also never wrap back to `start`,
     * so a full range would loop forever instead of returning -1.
     */
    if (*next < low || *next >= high)
	*next = low;
    start = *next;

    while (0 != maps_refcnt[*next]) {
	(*next)++;
	if (*next == high)
	    *next = low;
	if (*next == start)
	    return -1;
    }
    slot = *next;
    (*next)++;
    if (*next == high)
	*next = low;
    return slot;
}

/*
 * Search the slots of @range for a present non-PAE entry mapping the
 * machine frame @mfn.
 *
 * The refcount is not consulted: a slot released via free_page() keeps
 * its entry and still counts as a cache hit.
 * Returns the matching slot, or -1 if the frame is not mapped there.
 */
static int mfn_to_slot_32(uint32_t mfn, int range)
{
    int end = MAPS_R_HIGH(range);
    int idx;

    for (idx = MAPS_R_LOW(range); idx < end; ++idx) {
	if (test_pgflag_32(maps_32[idx], _PAGE_PRESENT) &&
	    get_pgframe_32(maps_32[idx]) == mfn)
	    return idx;        /* cache hit */
    }
    return -1;
}

/*
 * Search the slots of @range for a present PAE entry mapping the machine
 * frame @mfn.
 *
 * The refcount is not consulted: a slot released via free_page() keeps
 * its entry and still counts as a cache hit.
 * Returns the matching slot, or -1 if the frame is not mapped there.
 */
static int mfn_to_slot_pae(uint32_t mfn, int range)
{
    int end = MAPS_R_HIGH(range);
    int idx;

    for (idx = MAPS_R_LOW(range); idx < end; ++idx) {
	if (test_pgflag_pae(maps_pae[idx], _PAGE_PRESENT) &&
	    get_pgframe_pae(maps_pae[idx]) == mfn)
	    return idx;        /* cache hit */
    }
    return -1;
}

/*
 * Map machine address @maddr into the emulator's non-PAE map window and
 * return a pointer to it (the offset within the page is preserved).
 *
 * The slot range is chosen by the low bits of the mfn.  A frame already
 * present in its range is reused; otherwise a free slot is taken from
 * find_slot(), its entry rewritten and its TLB entry flushed.  Every call
 * takes one reference on the slot, to be dropped with free_page().
 * Panics when the range has no free slot.
 */
void *map_32(uint32_t maddr)
{
    uint32_t frame  = addr_to_frame(maddr);
    uint32_t offset = addr_offset(maddr);
    int      range  = frame & MAPS_R_MASK;
    uint32_t vaddr;
    int      idx;

    spin_lock(&maplock);
    idx = mfn_to_slot_32(frame, range);
    if (idx != -1) {
	/* frame is already mapped in its range */
	vminfo.faults[XEN_FAULT_MAPS_REUSE]++;
	vaddr = XEN_MAP_32 + idx*PAGE_SIZE;
    } else {
	idx = find_slot(range);
	if (idx == -1)
	    panic("out of map slots", NULL);
	printk(3, "%s: mfn %5x range %d [%3d - %3d], slot %3d\n", __FUNCTION__,
	       frame, range, MAPS_R_LOW(range), MAPS_R_HIGH(range), idx);
	maps_32[idx] = get_pgentry_32(frame, EMU_PGFLAGS);
	vminfo.faults[XEN_FAULT_MAPS_MAPIT]++;
	vaddr = XEN_MAP_32 + idx*PAGE_SIZE;
	/* the slot may previously have mapped another frame */
	flush_tlb_addr(vaddr);
    }
    maps_refcnt[idx]++;
    spin_unlock(&maplock);

    return (void*)vaddr + offset;
}

/*
 * Map machine address @maddr into the emulator's PAE map window and
 * return a pointer to it (the offset within the page is preserved).
 *
 * The slot range is chosen by the low bits of the mfn.  A frame already
 * present in its range is reused; otherwise a free slot is taken from
 * find_slot(), its entry rewritten and its TLB entry flushed.  Every call
 * takes one reference on the slot, to be dropped with free_page().
 * Panics when the range has no free slot.
 */
void *map_pae(uint64_t maddr)
{
    uint32_t mfn = addr_to_frame(maddr);
    uint32_t off = addr_offset(maddr);
    uint32_t va;
    int range, slot;

    spin_lock(&maplock);
    range = mfn & MAPS_R_MASK;
    slot = mfn_to_slot_pae(mfn, range);
    if (-1 == slot) {
	slot = find_slot(range);
	if (-1 == slot)
	    panic("out of map slots", NULL);
	printk(3, "%s: mfn %5x range %d [%3d - %3d], slot %3d\n", __FUNCTION__,
	       mfn, range, MAPS_R_LOW(range), MAPS_R_HIGH(range), slot);
	maps_pae[slot] = get_pgentry_pae(mfn, EMU_PGFLAGS);
	vminfo.faults[XEN_FAULT_MAPS_MAPIT]++;
	va = XEN_MAP_PAE + slot*PAGE_SIZE;
	flush_tlb_addr(va);
    } else {
	vminfo.faults[XEN_FAULT_MAPS_REUSE]++;
	va = XEN_MAP_PAE + slot*PAGE_SIZE;
    }
    /*
     * Take the reference while still holding maplock (as map_32() does).
     * Incrementing after the unlock left a window in which another CPU
     * could see refcount 0 in find_slot() and recycle this very slot for
     * a different frame.
     */
    maps_refcnt[slot]++;
    spin_unlock(&maplock);

    return (void*)va + off;
}

/*
 * Map machine address @maddr using whichever paging mode (PAE or
 * non-PAE) is active.  Release the result with free_page().
 */
void *map_page(uint64_t maddr)
{
    return is_pae() ? map_pae(maddr) : map_32(maddr);
}

/*
 * Drop one reference on the map slot that contains @ptr (any address
 * inside the mapped page works, since only the page index is used).
 *
 * The page table entry is left in place, so the mapping stays cached
 * and can be reused by a later map_page() of the same frame.
 * NOTE(review): no range check — @ptr must come from map_32()/map_pae().
 */
void free_page(void *ptr)
{
    uintptr_t window;
    int idx;

    if (is_pae())
	window = XEN_MAP_PAE;
    else
	window = XEN_MAP_32;
    idx = ((uintptr_t)ptr - window) >> PAGE_SHIFT;

    spin_lock(&maplock);
    maps_refcnt[idx]--;
    spin_unlock(&maplock);
}

/*
 * Install a non-refcounted mapping of @maddr in the non-PAE map window.
 *
 * Fixmap slots are allocated sequentially starting at MAPS_MAX, beyond
 * the slot ranges used by map_32().  No locking, TLB flush or upper-bound
 * check is done here.
 * NOTE(review): presumably only used during single-threaded setup — confirm.
 */
static void *fixmap_32(uint32_t maddr)
{
    static int next_fixed = MAPS_MAX;
    uint32_t frame  = addr_to_frame(maddr);
    uint32_t offset = addr_offset(maddr);
    int      slot   = next_fixed++;
    uint32_t vaddr  = XEN_MAP_32 + slot*PAGE_SIZE;

    printk(2, "%s: mfn %5x slot %3d\n", __FUNCTION__, frame, slot);
    maps_32[slot] = get_pgentry_32(frame, EMU_PGFLAGS);
    return (void*)vaddr + offset;
}

/*
 * Install a non-refcounted mapping of @maddr in the PAE map window.
 *
 * Fixmap slots are allocated sequentially starting at MAPS_MAX, beyond
 * the slot ranges used by map_pae().  No locking, TLB flush or upper-bound
 * check is done here.
 * NOTE(review): presumably only used during single-threaded setup — confirm.
 */
static void *fixmap_pae(uint64_t maddr)
{
    static int next_fixed = MAPS_MAX;
    uint32_t frame  = addr_to_frame(maddr);
    uint32_t offset = addr_offset(maddr);
    int      slot   = next_fixed++;
    uint32_t vaddr  = XEN_MAP_PAE + slot*PAGE_SIZE;

    printk(2, "%s: mfn %5x slot %3d\n", __FUNCTION__, frame, slot);
    maps_pae[slot] = get_pgentry_pae(frame, EMU_PGFLAGS);
    return (void*)vaddr + offset;
}

/*
 * Install a permanent mapping of @maddr for the active paging mode.
 * @cpu is currently unused.
 */
void *fixmap_page(struct xen_cpu *cpu, uint64_t maddr)
{
    return is_pae() ? fixmap_pae(maddr) : fixmap_32(maddr);
}

/* --------------------------------------------------------------------- */

/*
 * Return the address of the non-PAE pte for @va inside the linear page
 * table window at XEN_LPT_32 (one 32-bit entry per 4k virtual page).
 */
uint32_t *find_pte_32_lpt(uint32_t va)
{
    return (uint32_t *)XEN_LPT_32 + (va >> PAGE_SHIFT);
}

/*
 * Locate the non-PAE pte for @va by walking the page tables rooted at the
 * vcpu's cr3 through temporary map_32() mappings.
 *
 * Returns NULL if the pgd entry is not present.  Otherwise the returned
 * pointer lives in a map slot whose reference is handed to the caller,
 * who must release it with free_page().
 * NOTE(review): a PSE (large page) pgd entry is not special-cased here.
 */
uint32_t *find_pte_32_map(struct xen_cpu *cpu, uint32_t va)
{
    uint32_t *pgd;
    uint32_t *pte;
    int g,t;

    g = PGD_INDEX_32(va);
    t = PTE_INDEX_32(va);
    printk(1, "va %" PRIx32 " | 32 %d -> %d\n", va, g, t);

    pgd  = map_32(frame_to_addr(read_cr3_mfn(cpu)));
    printk(1, "  pgd   %3d = %08" PRIx32 "\n", g, pgd[g]);
    if (!test_pgflag_32(pgd[g], _PAGE_PRESENT)) {
	/* don't leak the pgd slot reference on the error path */
	free_page(pgd);
	return NULL;
    }

    pte  = map_32(frame_to_addr(get_pgframe_32(pgd[g])));
    printk(1, "    pte %3d = %08" PRIx32 "\n", t, pte[t]);
    free_page(pgd);
    return pte+t;
}

/*
 * Return the address of the PAE pte for @va inside the linear page table
 * window at XEN_LPT_PAE (one 64-bit entry per 4k virtual page).
 */
uint64_t *find_pte_pae_lpt(uint32_t va)
{
    return (uint64_t *)XEN_LPT_PAE + (va >> PAGE_SHIFT);
}

/*
 * Locate the PAE pte for @va by walking pgd -> pmd -> pte from the vcpu's
 * cr3 through temporary map_pae() mappings.
 *
 * Returns NULL if the pgd or pmd entry is not present; all intermediate
 * mappings are released in that case.  On success, the pgd and pmd
 * mappings are released.  The reference on the pte page's map slot is
 * handed to the caller, who must release it with free_page().
 */
uint64_t *find_pte_pae_map(struct xen_cpu *cpu, uint32_t va)
{
    int gi = PGD_INDEX_PAE(va);
    int mi = PMD_INDEX_PAE(va);
    int ti = PTE_INDEX_PAE(va);
    uint64_t *dir, *mid, *leaf;

    dir = map_pae(frame_to_addr(read_cr3_mfn(cpu)));
    if (!test_pgflag_pae(dir[gi], _PAGE_PRESENT)) {
	free_page(dir);
	return NULL;
    }

    mid = map_pae(frame_to_addr(get_pgframe_pae(dir[gi])));
    if (!test_pgflag_pae(mid[mi], _PAGE_PRESENT)) {
	free_page(mid);
	free_page(dir);
	return NULL;
    }

    leaf = map_pae(frame_to_addr(get_pgframe_pae(mid[mi])));
    free_page(dir);
    free_page(mid);
    return leaf + ti;
}

/*
 * Render the page flag bits set in @flags as a list of names, each
 * preceded by a space, followed by a newline (e.g. " user write present\n").
 *
 * Returns a static buffer: not reentrant, and every call overwrites the
 * previous result.
 */
static char *print_pgflags(uint32_t flags)
{
    static const struct {
	uint32_t    mask;
	const char *name;
    } pgflag_names[] = {
	{ _PAGE_GLOBAL,   " global"   },
	{ _PAGE_PSE,      " pse"      },
	{ _PAGE_DIRTY,    " dirty"    },
	{ _PAGE_ACCESSED, " accessed" },
	{ _PAGE_PCD,      " pcd"      },
	{ _PAGE_PWT,      " pwt"      },
	{ _PAGE_USER,     " user"     },
	{ _PAGE_RW,       " write"    },
	{ _PAGE_PRESENT,  " present"  },
    };
    static char buf[80];
    char *out = buf;
    const char *s;
    unsigned int i;

    /* worst case: all nine names (53 chars) + '\n' + NUL = 55 bytes */
    for (i = 0; i < sizeof(pgflag_names) / sizeof(pgflag_names[0]); i++) {
	if (!(flags & pgflag_names[i].mask))
	    continue;
	for (s = pgflag_names[i].name; *s != '\0'; s++)
	    *out++ = *s;
    }
    *out++ = '\n';
    *out = '\0';
    return buf;
}

/*
 * Debug helper: walk the guest's PAE page tables (rooted at the vcpu's
 * cr3) for @va and printk each level's raw entry, frame number and
 * decoded flags.
 *
 * The walk stops at the first non-present entry, or at a PSE (large
 * page) pmd entry.  All temporary mappings are released before returning.
 */
void pgtable_walk_pae(struct xen_cpu *cpu, uint32_t va)
{
    uint32_t g = PGD_INDEX_PAE(va);
    uint32_t m = PMD_INDEX_PAE(va);
    uint32_t t = PTE_INDEX_PAE(va);
    uint64_t *pgd, *pmd = NULL, *pte = NULL;
    uint64_t mfn;
    uint32_t flags;

    printk(1, "va %" PRIx32 " | pae %d -> %d -> %d\n", va, g, m, t);

    pgd   = map_pae(frame_to_addr(read_cr3_mfn(cpu)));
    mfn   = get_pgframe_64(pgd[g]);
    flags = get_pgflags_64(pgd[g]);
    printk(1, "  pgd     +%3d : %08" PRIx64 "  |  mfn %4" PRIx64 " | %s",
	   g, pgd[g], mfn, print_pgflags(flags));

    if (test_pgflag_pae(pgd[g], _PAGE_PRESENT)) {
	pmd   = map_pae(frame_to_addr(get_pgframe_pae(pgd[g])));
	mfn   = get_pgframe_64(pmd[m]);
	flags = get_pgflags_64(pmd[m]);
	printk(1, "    pmd   +%3d : %08" PRIx64 "  |  mfn %4" PRIx64 " | %s",
	       m, pmd[m], mfn, print_pgflags(flags));

	/* a pte level only exists below present, non-large-page entries */
	if (test_pgflag_pae(pmd[m], _PAGE_PRESENT) &&
	    !test_pgflag_pae(pmd[m], _PAGE_PSE)) {
	    pte   = map_pae(frame_to_addr(get_pgframe_pae(pmd[m])));
	    mfn   = get_pgframe_64(pte[t]);
	    flags = get_pgflags_64(pte[t]);
	    printk(1, "      pte +%3d : %08" PRIx64 "  |  mfn %4" PRIx64 " | %s",
		   t, pte[t], mfn, print_pgflags(flags));
	}
    }

    if (pgd)
	free_page(pgd);
    if (pmd)
	free_page(pmd);
    if (pte)
	free_page(pte);
}

/* --------------------------------------------------------------------- */

/*
 * Make sure the non-PAE page directory in frame @cr3_mfn carries the
 * emulator's own mappings.
 *
 * A directory whose XEN_M2P_32 entry is not yet present is treated as
 * new.  Every present emu_pgd_32 entry from that index upwards is copied
 * into it, except the linear page table slot.  The linear page table slot
 * (XEN_LPT_32) is then pointed at the directory itself, and only written
 * if it differs.
 */
static void update_emu_mappings_32(uint32_t cr3_mfn)
{
    uint32_t *new_pgd;
    uint32_t entry;
    int idx;

    new_pgd  = map_32(frame_to_addr(cr3_mfn));

    idx = PGD_INDEX_32(XEN_M2P_32);
    if (!test_pgflag_32(new_pgd[idx], _PAGE_PRESENT)) {
	/* new one, must init static mappings */
	for (; idx < PGD_COUNT_32; idx++) {
	    /*
	     * The loop header is the only place idx advances; a second
	     * increment here used to skip the entry after each copy.
	     */
	    if (!test_pgflag_32(emu_pgd_32[idx], _PAGE_PRESENT))
		continue;
	    if (idx == PGD_INDEX_32(XEN_LPT_32))
		continue;
	    new_pgd[idx] = emu_pgd_32[idx];
	}
    }

    /* linear pgtable mapping */
    idx = PGD_INDEX_32(XEN_LPT_32);
    entry = get_pgentry_32(cr3_mfn, LPT_PGFLAGS);
    if (new_pgd[idx] != entry)
	new_pgd[idx] = entry;
    free_page(new_pgd);
}

/*
 * Make sure the PAE page tables rooted at frame @cr3_mfn carry the
 * emulator's own mappings, which live in the pmd referenced by pgd[3].
 *
 * If that pmd's XEN_M2P_PAE entry is not yet present, the pmd is treated
 * as new and every present emu_pmd_pae entry from that index upwards is
 * copied in.  The four linear page table entries starting at XEN_LPT_PAE
 * are skipped by the copy.  Those four entries are then set to mirror
 * pgd[0..3]: a pointer to each present pgd entry's frame, or 0.  Each is
 * only written if it differs.
 * NOTE(review): pgd[3] is assumed present; it is not checked here.
 */
static void update_emu_mappings_pae(uint32_t cr3_mfn)
{
    uint64_t *pgd, *pmd3;
    uint64_t want;
    uint32_t frame;
    int lpt, slot, i;

    pgd  = map_pae(frame_to_addr(cr3_mfn));
    pmd3 = map_pae(frame_to_addr(get_pgframe_pae(pgd[3])));
    lpt  = PMD_INDEX_PAE(XEN_LPT_PAE);

    slot = PMD_INDEX_PAE(XEN_M2P_PAE);
    if (!test_pgflag_pae(pmd3[slot], _PAGE_PRESENT)) {
	/* fresh pmd: copy in the emulator's static mappings */
	while (slot < PMD_COUNT_PAE) {
	    if (test_pgflag_pae(emu_pmd_pae[slot], _PAGE_PRESENT) &&
		!(slot >= lpt && slot < lpt + 4))
		pmd3[slot] = emu_pmd_pae[slot];
	    slot++;
	}
    }

    /* linear pgtable mappings, one per pgd entry */
    for (i = 0; i < 4; i++) {
	if (test_pgflag_pae(pgd[i], _PAGE_PRESENT)) {
	    frame = get_pgframe_pae(pgd[i]);
	    want  = get_pgentry_pae(frame, LPT_PGFLAGS);
	} else {
	    want = 0;
	}
	if (pmd3[lpt + i] != want)
	    pmd3[lpt + i] = want;
    }
    free_page(pgd);
    free_page(pmd3);
}

/*
 * Install or refresh the emulator mappings in the page tables rooted at
 * @cr3_mfn, dispatching on the active paging mode.
 */
void update_emu_mappings(uint32_t cr3_mfn)
{
    if (is_pae()) {
	update_emu_mappings_pae(cr3_mfn);
	return;
    }
    update_emu_mappings_32(cr3_mfn);
}

/* --------------------------------------------------------------------- */

/*
 * Fill in the emulator's static non-PAE page directory entries and set
 * the global m2p pointer:
 *  - XEN_TXT_32: emulator image as a PSE (large page) entry
 *  - XEN_M2P_32: machine-to-phys table as a PSE entry
 *  - XEN_MAP_32: the page table (maps_32) backing the map slot window
 */
static void paging_init_32(void)
{
    emu_pgd_32[PGD_INDEX_32(XEN_TXT_32)] =
	get_pgentry_32(vmconf.mfn_emu, EMU_PGFLAGS | _PAGE_PSE);
    emu_pgd_32[PGD_INDEX_32(XEN_M2P_32)] =
	get_pgentry_32(vmconf.mfn_m2p, M2P_PGFLAGS_32 | _PAGE_PSE);
    emu_pgd_32[PGD_INDEX_32(XEN_MAP_32)] =
	get_pgentry_32(EMU_MFN(maps_32), PGT_PGFLAGS_32);

    m2p = (void*)XEN_M2P_32;
}

/*
 * Fill in the emulator's static PAE pmd entries and set the global m2p
 * pointer:
 *  - from XEN_TXT_PAE: the emulator image, one PSE entry per
 *    PMD_COUNT_PAE frames
 *  - from XEN_M2P_PAE: the machine-to-phys table, same layout
 *  - XEN_MAP_PAE: the page table (maps_pae) backing the map slot window
 * NOTE(review): unlike paging_init_32(), the M2P uses EMU_PGFLAGS and the
 * map window uses PGT_PGFLAGS_32 here — confirm this is intended.
 */
static void paging_init_pae(void)
{
    uint32_t frame;
    int slot;

    slot = PMD_INDEX_PAE(XEN_TXT_PAE);
    frame = vmconf.mfn_emu;
    while (frame < vmconf.mfn_emu + vmconf.pg_emu) {
	emu_pmd_pae[slot] = get_pgentry_pae(frame, EMU_PGFLAGS | _PAGE_PSE);
	frame += PMD_COUNT_PAE;
	slot++;
    }

    slot = PMD_INDEX_PAE(XEN_M2P_PAE);
    frame = vmconf.mfn_m2p;
    while (frame < vmconf.mfn_m2p + vmconf.pg_m2p) {
	emu_pmd_pae[slot] = get_pgentry_pae(frame, EMU_PGFLAGS | _PAGE_PSE);
	frame += PMD_COUNT_PAE;
	slot++;
    }

    emu_pmd_pae[PMD_INDEX_PAE(XEN_MAP_PAE)] =
	get_pgentry_pae(EMU_MFN(maps_pae), PGT_PGFLAGS_32);

    m2p = (void*)XEN_M2P_PAE;
}

/*
 * Set up the emulator's static page table entries for the active paging
 * mode.  @cpu is currently unused.
 */
void paging_init(struct xen_cpu *cpu)
{
    if (!is_pae()) {
	paging_init_32();
	return;
    }
    paging_init_pae();
}
