
#include "kernel.h"
#include "screen.h"
#include "memmgr.h"

// The single MemMgr object. Its default constructor links every Regions[]
// descriptor into NotAssigns. NOTE(review): as a namespace-scope static with a
// non-trivial constructor, this only happens if the boot code runs global
// constructors before MemMgr::init() is called -- confirm. Presumably this is
// the object MemMgr::getInstance() hands out; verify in memmgr.h.
MemMgr MemMgr::Instance;

// Default constructor: every descriptor in the Regions[] pool starts out on
// NotAssigns, i.e. spare and available to describe a free or allocated block.
MemMgr::MemMgr(void) : NotAssigns(), Frees() {
    uint32 idx = 0;
    while (idx < REGION_NUM) {
        NotAssigns.add(Regions[idx]);
        ++idx;
    }
}

// Initializes the heap from the multiboot information handed over by the
// boot loader.
//
// The managed area starts at the kernel's load address (_start) and spans
// mem_upper KiB. Everything from _start up to the end of the kernel image
// (_end) or the end of the highest-ending boot module, whichever is further,
// is reserved. The rest of the area becomes the single initial free region.
//
// NOTE(review): Start + Size is the top of upper memory only if the kernel is
// loaded at the 1 MiB mark (mem_upper is measured from 1 MiB). Also,
// mem_upper is only valid when the boot loader sets the memory-info bit in
// mbi->flags, and that bit is not checked here. Confirm both.
void MemMgr::init(multiboot_info_t* mbi) {
    extern int _start;
    extern void* _end;

    Start = uint32(&_start);
    Size  = mbi->mem_upper * 1024;

    // End of the reserved area: the kernel image, extended past any boot
    // modules. Taking the maximum over all modules (rather than indexing the
    // last one) is safe when mods_count == 0, and does not assume that
    // modules are ordered by address.
    uint32 reserved_end = uint32(&_end);
    if ((mbi->flags & MBI_FLAG_MODS_VALID) != 0) { // mods valid ?
        module_t* mods = (module_t*)mbi->mods_addr;
        for (uint32 i = 0; i < mbi->mods_count; i++) {
            if (mods[i].mod_end > reserved_end) {
                reserved_end = mods[i].mod_end;
            }
        }
    }
    size_t kernel_size = reserved_end - Start;

    // Create the free list. If the reserved area already covers the whole
    // managed area, there is nothing to hand out, and Size - kernel_size
    // would wrap around.
    if (kernel_size >= Size) {
        return;
    }

    Region* free_region = NotAssigns.get(0);
    if (free_region != 0) {
        NotAssigns.remove(*free_region);
        free_region->Start = Start + kernel_size;
        free_region->Size  = Size - kernel_size;
        Frees.add(*free_region);
    }
}

void* MemMgr::alloc(size_t size) {
    void* addr = 0;

    if (size != 0) {
        IntMgr::disable();

        size = getAllocSize(size);
        Region* free_region  = Frees.searchBySize(size);
        Region* alloc_region = NotAssigns.get(0);

        if ((free_region != 0) && (alloc_region != 0)) {
            NotAssigns.remove(*alloc_region);
            alloc_region->Start = free_region->Start;
            alloc_region->Size  = size;
            Allocs.add(*alloc_region);
            addr = (void*)(alloc_region->Start);

            free_region->Start += size;
            free_region->Size  -= size;
            if (free_region->Size == 0) {
                Frees.remove(*free_region);
                NotAssigns.add(*free_region);
            }
        }

        IntMgr::enable();
    }

    return addr;
}

void MemMgr::free(void* addr) {
    if (addr != 0) {
        IntMgr::disable();

        Region* alloc_region = Allocs.searchByAddr(addr);
        if (alloc_region != 0) {
            Allocs.remove(*alloc_region);

            // merge free regions, if you can.
            Region* merge_region;
            while ((merge_region = Frees.merge(*alloc_region)) != 0) {
                NotAssigns.add(*alloc_region);
                alloc_region = merge_region;
                Frees.remove(*alloc_region);
            }
            Frees.add(*alloc_region);
        }

        IntMgr::enable();
    }
}

// Appends `region` at the tail (Bottom) of the list. Its Prev/Next links are
// overwritten, so it must not currently be linked into any list.
void RegionList::add(Region& region) {
    Region* const old_tail = Bottom;

    region.Prev = old_tail;
    region.Next = 0;

    if (old_tail != 0) {
        old_tail->Next = &region;
    } else {
        Top = &region;  // list was empty
    }
    Bottom = &region;
}

// Unlinks `region` from the list and clears its Prev/Next links. The region
// is assumed to belong to this list; membership is not checked.
void RegionList::remove(Region& region) {
    Region* const before = region.Prev;
    Region* const after  = region.Next;

    if (before != 0) {
        before->Next = after;
    } else {
        Top = after;     // region was the head
    }

    if (after != 0) {
        after->Prev = before;
    } else {
        Bottom = before; // region was the tail
    }

    region.Next = 0;
    region.Prev = 0;
}

// Finds the first region in the list (scanning from Top) that is directly
// adjacent to `region` and grows it to also cover `region`. Returns the grown
// region, or 0 if no neighbour exists. `region` itself is not modified and is
// not linked into or out of the list.
Region* RegionList::merge(Region& region) {
    Region* r = Top;
    while (r != 0) {
        if (region.Start + region.Size == r->Start) {
            // region ends where r begins: extend r downward
            r->Start = region.Start;
            r->Size += region.Size;
            return r;
        }
        if (r->Start + r->Size == region.Start) {
            // r ends where region begins: extend r upward
            r->Size += region.Size;
            return r;
        }
        r = r->Next;
    }
    return 0;
}

// Returns the element at zero-based position `i` (counting from Top), or 0
// if the list has no such element.
Region* RegionList::get(uint32 i) {
    Region* cur = Top;
    uint32 remaining = i;
    while (cur != 0 && remaining != 0) {
        cur = cur->Next;
        --remaining;
    }
    return cur;
}

// First-fit lookup: returns the first region (from Top) whose Size is at
// least `size`, or 0 if none is large enough.
Region* RegionList::searchBySize(size_t size) {
    Region* r = Top;
    while (r != 0 && r->Size < size) {
        r = r->Next;
    }
    return r;
}

// Returns the first region (from Top) whose Start equals `addr`, or 0 if no
// region begins exactly at that address.
Region* RegionList::searchByAddr(void* addr) {
    const uint32 wanted = uint32(addr);
    Region* r = Top;
    while (r != 0 && r->Start != wanted) {
        r = r->Next;
    }
    return r;
}

void RegionList::show(void) {
    Console.printf("list :: top(%x), bottom(%x)\n", uint32(Top), uint32(Bottom));
    Region* region = Top;
    while (region != 0) {
        Console.printf("  reg[%x] :: start(%x), size(%x)\n", 
                       uint32(region), region->Start, region->Size);
        region = region->Next;
    }
}

// new / delete

void* operator new(size_t size) {
    return MemMgr::getInstance()->alloc(size);
}

// Global array new: served from the kernel heap.
//
// A zero-byte request (e.g. `new T[0]` for a type that needs no array cookie)
// is rounded up to one byte. MemMgr::alloc treats 0 as "nothing to allocate"
// and returns 0, which callers would mistake for an out-of-memory failure.
// C++ also requires a successful zero-size allocation to return a unique,
// non-null pointer.
// Throws nothing; still returns 0 when the heap is exhausted.
void* operator new[](size_t size) {
    if (size == 0) {
        size = 1;
    }
    return MemMgr::getInstance()->alloc(size);
}

// Global scalar delete: returns the block to the kernel heap. Deleting a
// null pointer is a no-op because MemMgr::free ignores 0.
void operator delete(void* addr) {
    MemMgr::getInstance()->free(addr);
}

// Global array delete, paired with the operator new[] defined in this file.
// Without this replacement, `delete[]` on kernel-heap memory would fall back
// to the implementation's default operator delete[]. A hosted library's
// default forwards to operator delete, but in a freestanding link the
// default may simply be missing.
void operator delete[](void* addr) {
    MemMgr::getInstance()->free(addr);
}
