/*
 * Copyright (c) 2007, 2008 University of Tsukuba
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 * 3. Neither the name of the University of Tsukuba nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
/*
 * Copyright (c) 2010-2012 Yuichi Watanabe
 */

#include <common/common.h>
#include <common/list.h>
#include <core/assert.h>
#include <core/initfunc.h>
#include <core/mm.h>
#include <core/gmm.h>
#include <core/printf.h>
#include <core/thread.h>
#include "apic_pass.h"
#include "asm.h"
#include "cpu_mmu.h"
#include "mm.h"
#include "mtrr.h"
#include "vt_internal.h"
#include "constants.h"
#include "current.h"

/* Number of EPT paging-structure pages preallocated per VM (see vt_ept_init). */
#define EPT_PAGE_NUM		1024
/* Depth of the EPT hierarchy used here: level 4 (top) down to level 1. */
#define EPT_MAX_LEVEL		4
/* Each level indexes 9 bits of the guest-physical address (512 entries). */
#define EPT_LEVEL_STRIDE	(9)
#define EPT_LEVEL_MASK		((1 << EPT_LEVEL_STRIDE) - 1)

/* EPT entry access-permission bits. */
#define EPT_READ		0x1
#define EPT_WRITE		0x2
#define EPT_EXECUTE		0x4
/* An entry with any of R/W/X set is treated as present (name kept for users). */
#define EPT_VAILED_MASK		0x7
/* Entry maps a large page instead of referencing a lower-level table. */
#define EPT_PS			0x80
/* Place an EPT memory type into bits 5:3 of a leaf entry. */
#define EPT_MEMTYPE_TO_PTE(memtype)	((memtype) << 3)

/* EPTP bits 2:0: memory type used for EPT paging-structure accesses (WB). */
#define VMCS_EPT_POINTER_ADDR_WB		0x6
/* EPTP bits 5:3: EPT page-walk length minus one. */
#define VMCS_EPT_POINTER_ADDR_EPT_LENGTH	((EPT_MAX_LEVEL - 1) << 3)

/* Exit-qualification bits reported for EPT violations. */
#define VMCS_EXIT_QUALIFICATION_READ_BIT		0x1
#define VMCS_EXIT_QUALIFICATION_WRITE_BIT		0x2
#define VMCS_EXIT_QUALIFICATION_INST_BIT		0x4
#define VMCS_EXIT_QUALIFICATION_LINEAR_ADDR_VALID_BIT	0x80

/*
 * Index of the entry for guest-physical address `addr` inside a level-`level`
 * EPT table.  Both arguments are parenthesized so that expression arguments
 * (e.g. `base + off`) expand with the intended precedence.
 */
#define ept_offset(addr, level) \
	(((addr) >> (PAGE_SHIFT + ((level) - 1) * EPT_LEVEL_STRIDE)) \
	 & EPT_LEVEL_MASK)

/*
 * Bookkeeping for one preallocated EPT paging-structure page.
 *
 * Instances come from the pool built in vt_ept_init().  An unused page sits
 * on vt_data->ept_tbl_free_list; a page in use as a level-1 table sits on
 * vt_data->ept_tbl_l1_list, from which vt_ept_walk() may reclaim it.  Pages
 * used as higher-level tables are kept on neither list.
 */
struct ept_tbl {
	/* Link in either the free list or the level-1 list. */
	LIST2_DEFINE(struct ept_tbl, list);
	/* Entry in the parent table that points at this page; zeroed on reclaim. */
	u64	*referenced_pte;
	/* Virtual address of the page, as returned by alloc_page(). */
	void	*virt;
	/* Physical address of the page; stored into the parent entry. */
	phys_t	phys;
};

static void
vt_ept_invalidate_tlb(void)
{
	struct {
		u64 ept_l4tbl;
		u64 reserved;
	} desc;

	desc.ept_l4tbl = current->vm->vt.ept_l4tbl;
	desc.reserved = 0;
	asm_invept(&desc);
}

void
vt_ept_init(void)
{
	struct vt_vm_data	*vt_data = &current->vm->vt;
	u32			ctls;
	u32			ctls_or, ctls_and;
	u32			ctls2_or, ctls2_and;
	u32			exit_ctls_or, exit_ctls_and;
	u32			entry_ctls_or, entry_ctls_and;
	void			*vaddr;
	vmmerr_t		err;
	u64			pat;
	struct ept_tbl		*tbls, *tbl;
	int			i;

	asm_rdmsr32(MSR_IA32_VMX_PROCBASED_CTLS, &ctls_or, &ctls_and);
	if ((ctls_and & VMCS_PROC_BASED_VMEXEC_CTL_SEC_CTL_BIT) == 0 ||
	    ctls_or & VMCS_PROC_BASED_VMEXEC_CTL_INVLPGEXIT_BIT) {
		return;
	}
	asm_rdmsr32(MSR_IA32_VMX_PROCBASED_CTLS2, &ctls2_or, &ctls2_and);
	if ((ctls2_and & VMCS_PROC_BASED_VMEXEC_CTL2_EPT_BIT) == 0) {
		return;
	}

	asm_rdmsr32(MSR_IA32_VMX_EXIT_CTLS, &exit_ctls_or, &exit_ctls_and);
	if ((exit_ctls_or & VMCS_VMEXIT_CTL_LOAD_IA32_PAT_BIT) ||
	    (exit_ctls_and & VMCS_VMEXIT_CTL_LOAD_IA32_PAT_BIT) == 0) {
		return;
	}
	asm_rdmsr32 (MSR_IA32_VMX_ENTRY_CTLS, &entry_ctls_or, &entry_ctls_and);
	if ((entry_ctls_or & VMCS_VMENTRY_CTL_LOAD_IA32_PAT_BIT) ||
	    (entry_ctls_and & VMCS_VMENTRY_CTL_LOAD_IA32_PAT_BIT) == 0) {
		return;
	}

	if (cpu_is_bsp()) {
		printf("Enabling EPT\n");
	}

	current->u.vt.ept_enabled = true;

	spinlock_lock(&vt_data->ept_lock);
	if (vt_data->ept_l4tbl == 0) {
		LIST2_HEAD_INIT(vt_data->ept_tbl_free_list);
		LIST2_HEAD_INIT(vt_data->ept_tbl_l1_list);

		err = alloc_page(&vaddr, &vt_data->ept_l4tbl);
		if (err) {
			panic("Failed to allocate a ept l4tbl.");
		}
		memset(vaddr, 0, PAGE_SIZE);

		tbls = alloc(sizeof(struct ept_tbl) * EPT_PAGE_NUM);
		if (tbls == NULL) {
			panic("Failed to allocate struct ept_tbl");
		}
		for (i = 0; i < EPT_PAGE_NUM; i++) {
			tbl = tbls + i;
			err = alloc_page(&tbl->virt, &tbl->phys);
			if (err) {
				panic("Failed to allocate ept tbl.");
			}
			memset(tbl->virt, 0, PAGE_SIZE);
			LIST2_ADD(vt_data->ept_tbl_free_list, list, tbl);
		}
	}
	spinlock_unlock(&vt_data->ept_lock);

	ctls = ctls2_or | VMCS_PROC_BASED_VMEXEC_CTL2_EPT_BIT;
	asm_vmwrite32(VMCS_PROC_BASED_VMEXEC_CTL2, ctls);

	asm_vmread32(VMCS_PROC_BASED_VMEXEC_CTL, &ctls);
	ctls &= ~VMCS_PROC_BASED_VMEXEC_CTL_INVLPGEXIT_BIT;
	ctls |= VMCS_PROC_BASED_VMEXEC_CTL_SEC_CTL_BIT;
	asm_vmwrite32(VMCS_PROC_BASED_VMEXEC_CTL, ctls);

	asm_vmwrite32(VMCS_EXCEPTION_BMP, 0xffffbfff);
	asm_vmwrite32(VMCS_PAGEFAULT_ERRCODE_MASK, 0);
	asm_vmwrite32(VMCS_PAGEFAULT_ERRCODE_MATCH, 0xffffffff);

	asm_vmwrite64(VMCS_EPT_POINTER_ADDR,
		      vt_data->ept_l4tbl |
		      VMCS_EPT_POINTER_ADDR_WB |
		      VMCS_EPT_POINTER_ADDR_EPT_LENGTH);

	asm_vmread32(VMCS_VMEXIT_CTL, &ctls);
	ctls |= VMCS_VMEXIT_CTL_LOAD_IA32_PAT_BIT;
	asm_vmwrite32(VMCS_VMEXIT_CTL, ctls);

	asm_vmread32(VMCS_VMENTRY_CTL, &ctls);
	ctls |= VMCS_VMENTRY_CTL_LOAD_IA32_PAT_BIT;
	asm_vmwrite32(VMCS_VMENTRY_CTL, ctls);

	asm_rdmsr64(MSR_IA32_PAT, &pat);
	asm_vmwrite64(VMCS_HOST_IA32_PAT, pat);

	/*
	 * Guest IA32_PAT will be set by msr_init
	 * so that we don't need to set guest IA32_PAT here.
	 */
}

/*
 * Find the EPT entry translating gphys at level req_level, building any
 * missing intermediate tables on the way down.
 *
 * Starting at the level-4 table, every non-present entry above req_level is
 * filled with a fresh zeroed table taken from vt_data->ept_tbl_free_list.
 * When the free list is empty, a level-1 table is reclaimed from
 * vt_data->ept_tbl_l1_list: the parent entry that referenced it is cleared,
 * so the 4K mappings it held disappear until they are mapped again.
 * The descent stops early at a large-page entry (EPT_PS set).
 *
 * Both callers hold vt_data->ept_lock around this call; the lists modified
 * here are shared per VM.
 *
 * @gphys:        guest-physical address to look up
 * @ept_l4tbl:    virtual address of the level-4 EPT table
 * @req_level:    level whose entry is wanted (1 = 4K entry, 2 = 2M entry)
 * @result_level: out; level of the returned entry, which is above req_level
 *                only when a large page was encountered
 *
 * Returns a pointer to the entry for gphys at *result_level.  Panics if no
 * table page can be obtained, so the current code never returns NULL.
 */
static u64 *
vt_ept_walk(phys_t gphys, void *ept_l4tbl, int req_level, int *result_level)
{
	u64		*tbl;
	int		offset;
	u64		*pte;
	int		level;
	struct ept_tbl	*stbl;
	void		*virt;
	phys_t		phys;
	struct vt_vm_data	*vt_data = &current->vm->vt;

	tbl = ept_l4tbl;

	for (level = EPT_MAX_LEVEL; level > req_level; level--) {
		offset = ept_offset(gphys, level);
		pte = &tbl[offset];

		if ((*pte & EPT_VAILED_MASK) == 0) {
			/*
			 * Allocate new ept table.
			 */
			stbl = LIST2_POP(vt_data->ept_tbl_free_list, list);
			if (stbl == NULL) {
				stbl = LIST2_POP(vt_data->ept_tbl_l1_list, list);
				if (stbl == NULL) {
					panic("No usable ept tbl");
				}
				/*
				 * Detach the reclaimed level-1 table from its old
				 * parent, then flush cached EPT translations before
				 * the page is refilled for another guest-physical
				 * range.  Otherwise a cached level-2 entry could
				 * keep pointing at this page and translate the old
				 * range through the new range's entries.
				 * NOTE(review): this flushes only the executing CPU;
				 * other CPUs running this VM are not shot down here.
				 */
				*stbl->referenced_pte = 0;
				vt_ept_invalidate_tlb();
			}
			/*
			 * A table hung off a level-2 entry is a level-1 table.
			 * Only those are tracked as reclaimable.
			 */
			if (level == 2) {
				LIST2_ADD(vt_data->ept_tbl_l1_list, list, stbl);
			}
			stbl->referenced_pte = pte;
			virt = stbl->virt;
			phys = stbl->phys;
			memset(virt, 0, PAGE_SIZE);
			/*
			 * Set pte.
			 */
			*pte = (phys & PAGE_MASK) |
				EPT_READ | EPT_WRITE | EPT_EXECUTE;
		} else {
			if (*pte & EPT_PS) {
				/* Large page: there is no lower table. */
				break;
			}
			phys = *pte & PAGE_MASK;
			virt = (void *)phys_to_virt(phys);
		}
		tbl = (u64 *)virt;
	}

	*result_level = level;
	return tbl + ept_offset(gphys, level);
}

/*
 * Turn a failed guest access, identified by the current EPT-violation exit,
 * into a guest page fault via mmu_generate_pagefault().
 *
 * The exit qualification must report a valid guest linear address; otherwise
 * this panics.  The write flag comes from the exit qualification and the
 * user/supervisor flag from seg_user_mode().  Each call is logged with
 * printf().
 *
 * @vmmerr: error code passed through to mmu_generate_pagefault()
 */
static void
vt_ept_generate_pagefault(vmmerr_t vmmerr)
{
	ulong	qual;
	ulong	fault_la;
	bool	is_write;
	bool	is_user;

	asm_vmread(VMCS_EXIT_QUALIFICATION, &qual);
	if (!(qual & VMCS_EXIT_QUALIFICATION_LINEAR_ADDR_VALID_BIT)) {
		panic("EPT violation without linear address. qual 0x%lx",
		      qual);
	}

	is_write = (qual & VMCS_EXIT_QUALIFICATION_WRITE_BIT) != 0;
	is_user = seg_user_mode();
	asm_vmread(VMCS_GUEST_LINEAR_ADDR, &fault_la);

	printf("vt_ept_generate_pagefault 0x%lx wr %d us %d\n",
	       fault_la, is_write, is_user);
	mmu_generate_pagefault(vmmerr, is_write, is_user, fault_la);
}

/*
 * Install a 4K EPT mapping for the guest page containing gphys.
 *
 * The host-physical address comes from the VM's gp2hp() translation, and the
 * memory type from the MTRRs for that host address.  An entry that is already
 * present is accepted only if it maps the same host page.
 *
 * Returns VMMERR_SUCCESS when mapped (or already mapped identically),
 * VMMERR_RANGE when gphys has no host mapping, and VMMERR_NOMEM if the walk
 * yields no entry.  Panics if a large page already covers gphys or if a
 * different host page is already mapped.
 */
static vmmerr_t
vt_ept_map_4kpage(phys_t gphys)
{
	struct vt_vm_data	*vt_data = &current->vm->vt;
	phys_t			page_gphys;
	phys_t			hphys;
	u64			*entry;
	u8			mtype;
	int			hit_level;

	page_gphys = ROUND_DOWN(gphys, PAGESIZE);

	hphys = current->vm->gmm.gp2hp(page_gphys);
	if (hphys == GMM_NO_MAPPING) {
		return VMMERR_RANGE;
	}
	mtype = mtrr_get_mem_type(hphys);

	spinlock_lock(&vt_data->ept_lock);
	entry = vt_ept_walk(page_gphys,
			    (void *)phys_to_virt(vt_data->ept_l4tbl),
			    1, &hit_level);
	if (entry == NULL) {
		spinlock_unlock(&vt_data->ept_lock);
		return VMMERR_NOMEM;
	}
	if (hit_level != 1) {
		panic("EPT 2M page is already mapped. "
		      "gphys 0x%llx, hphys 0x%llx, pte 0x%llx",
		      page_gphys, hphys, *entry);
	}

	if ((*entry & EPT_VAILED_MASK) == 0) {
		*entry = (hphys & PAGE_MASK) |
			EPT_READ | EPT_WRITE | EPT_EXECUTE |
			EPT_MEMTYPE_TO_PTE(mtype);
	} else if ((*entry & PAGE_MASK) != (hphys & PAGE_MASK)) {
		panic("EPT PTE is already set. "
		      "gphys 0x%llx, hphys 0x%llx, pte 0x%llx",
		      page_gphys, hphys, *entry);
	}

	spinlock_unlock(&vt_data->ept_lock);
	return VMMERR_SUCCESS;
}

/*
 * Map the 2M-aligned guest-physical chunk containing gphys with a single
 * 2M EPT entry.
 *
 * The chunk qualifies only if its host backing starts 2M aligned, is
 * contiguous for all 512 pages, and has one MTRR memory type throughout.
 *
 * Returns VMMERR_SUCCESS when mapped (or already mapped identically),
 * VMMERR_RANGE if the first page has no host mapping, VMMERR_NOSUP if the
 * host address is not 2M aligned, VMMERR_NOT_CONTIGUOUS if the backing is
 * discontiguous or mixes memory types, and VMMERR_NOMEM if the walk yields no
 * entry.  Panics if a conflicting entry is already present.
 */
static vmmerr_t
vt_ept_map_2mpage(phys_t gphys)
{
	struct vt_vm_data	*vt_data = &current->vm->vt;
	phys_t			base;
	phys_t			hphys;
	phys_t			off;
	u64			*entry;
	u8			mtype;
	int			hit_level;

	base = ROUND_DOWN(gphys, PAGESIZE2M);

	hphys = current->vm->gmm.gp2hp(base);
	if (hphys == GMM_NO_MAPPING) {
		return VMMERR_RANGE;
	}
	if (hphys & PAGESIZE2M_MASK) {
		return VMMERR_NOSUP;
	}
	mtype = mtrr_get_mem_type(hphys);

	/* Every following 4K page must be host-contiguous with the same type. */
	for (off = PAGESIZE; off < PAGESIZE2M; off += PAGESIZE) {
		if (current->vm->gmm.gp2hp(base + off) != hphys + off ||
		    mtrr_get_mem_type(hphys + off) != mtype) {
			return VMMERR_NOT_CONTIGUOUS;
		}
	}

	spinlock_lock(&vt_data->ept_lock);
	entry = vt_ept_walk(base, (void *)phys_to_virt(vt_data->ept_l4tbl),
			    2, &hit_level);
	if (entry == NULL) {
		spinlock_unlock(&vt_data->ept_lock);
		return VMMERR_NOMEM;
	}
	if (hit_level != 2) {
		panic("EPT 2M page is already mapped. "
		      "gphys 0x%llx, hphys 0x%llx, pte 0x%llx",
		      base, hphys, *entry);
	}

	if ((*entry & EPT_VAILED_MASK) == 0) {
		*entry = (hphys & PAGE_MASK) | EPT_PS |
			EPT_READ | EPT_WRITE | EPT_EXECUTE |
			EPT_MEMTYPE_TO_PTE(mtype);
	} else {
		/* Already present: must be the very same 2M mapping. */
		if ((*entry & EPT_PS) == 0) {
			panic("EPT L2 PTE is already set, but not 2M page. "
			      "gphys 0x%llx, hphys 0x%llx, pte 0x%llx",
			      base, hphys, *entry);
		}
		if ((*entry & PAGE_MASK) != (hphys & PAGE_MASK)) {
			panic("EPT L2 PTE (2M) is already set. "
			      "gphys 0x%llx, hphys 0x%llx, pte 0x%llx",
			      base, hphys, *entry);
		}
	}

	spinlock_unlock(&vt_data->ept_lock);
	return VMMERR_SUCCESS;
}

void
vt_ept_violation(void)
{
	phys_t			gphys;
	vmmerr_t		ret;

	asm_vmread64(VMCS_GUEST_PHYS_ADDR, &gphys);

	mmio_lock();
	ret = mmio_pagefault(gphys);
	mmio_unlock();
	if (ret != VMMERR_NODEV) {
		if (ret == VMMERR_SUCCESS) {
			return;
		}
		if (ret < VMMERR_PAGE_NOT_PRESENT ||
		    ret > VMMERR_PAGE_BAD_RESERVED_BIT) {
			panic("Failed to emulate accessing to MMIO area. ret 0x%x",
			      ret);
		} 
		vt_ept_generate_pagefault(ret);
		return;
	}

	ret = vt_ept_map_4kpage(gphys);
	if (ret != VMMERR_SUCCESS) {
		if (ret != VMMERR_RANGE) {
			panic("vt_ept_violation: Unknown error 0x%x", ret);
		}
		ret = cpu_interpreter();
		if (ret == VMMERR_SUCCESS) {
			return;
		}
		if (ret < VMMERR_PAGE_NOT_PRESENT ||
		    ret > VMMERR_PAGE_BAD_RESERVED_BIT) {
			panic("Failed to emulate accessing to no mapping area. ret 0x%x",
			      ret);
		}
		vt_ept_generate_pagefault(ret);
	}
}

/*
 * Switch the current vCPU between EPT and shadow paging when the guest
 * toggles paging.  Does nothing unless vt_ept_init() enabled EPT.
 *
 * @pg: true when the guest has enabled paging (use EPT), false when it has
 *      disabled paging (fall back to shadow paging).
 *
 * In both cases, CR4 is read and written back through
 * vt_read/write_control_reg(), and EPT translations are invalidated.
 */
void
vt_ept_pg_change(bool pg)
{
	ulong			cr4;
	u32			ctls;

	if (!current->u.vt.ept_enabled) {
		return;
	}

	if (pg) {
		/*
		 * - Enable EPT.
		 * - Don't cause VM-exit on invlpg.
		 * - Don't cause VM-exit on page fault.
		 * - Load IA32_PAT on VM-exit and VM-entry.
		 * - Set Guest CR3 to the value set by guest software
		 *   instead of shadow table.
		 */
		asm_vmread32(VMCS_PROC_BASED_VMEXEC_CTL2, &ctls);
		ctls |= VMCS_PROC_BASED_VMEXEC_CTL2_EPT_BIT;
		asm_vmwrite32(VMCS_PROC_BASED_VMEXEC_CTL2, ctls);

		asm_vmread32(VMCS_PROC_BASED_VMEXEC_CTL, &ctls);
		ctls &= ~VMCS_PROC_BASED_VMEXEC_CTL_INVLPGEXIT_BIT;
		asm_vmwrite32(VMCS_PROC_BASED_VMEXEC_CTL, ctls);

		/* Intercept every exception except #PF (bit 14). */
		asm_vmwrite32(VMCS_EXCEPTION_BMP, 0xffffbfff);
		asm_vmwrite32(VMCS_PAGEFAULT_ERRCODE_MASK, 0);
		asm_vmwrite32(VMCS_PAGEFAULT_ERRCODE_MATCH, 0xffffffff);

		asm_vmread32(VMCS_VMEXIT_CTL, &ctls);
		ctls |= VMCS_VMEXIT_CTL_LOAD_IA32_PAT_BIT;
		asm_vmwrite32(VMCS_VMEXIT_CTL, ctls);

		asm_vmread32(VMCS_VMENTRY_CTL, &ctls);
		ctls |= VMCS_VMENTRY_CTL_LOAD_IA32_PAT_BIT;
		asm_vmwrite32(VMCS_VMENTRY_CTL, ctls);

		/*
		 * Guest CR3 is a natural-width VMCS field.  Write it with the
		 * natural-width accessor, the counterpart of asm_vmread()
		 * used elsewhere in this file.  asm_vmwrite32() would drop
		 * bits 63:32 on a 64-bit host.
		 */
		asm_vmwrite(VMCS_GUEST_CR3, current->u.vt.vr.cr3);
	} else {
		/*
		 * - Disable EPT.
		 * - Cause VM-exit on invlpg.
		 * - Cause VM-exit on page fault.
		 * - Don't load IA32_PAT on VM-exit and VM-entry.
		 *
		 * Guest CR3 will be updated by cpu_mmu_spt_updatecr3,
		 * so that we don't need to set guest CR3 here.
		 */
		asm_vmread32(VMCS_PROC_BASED_VMEXEC_CTL2, &ctls);
		ctls &= ~VMCS_PROC_BASED_VMEXEC_CTL2_EPT_BIT;
		asm_vmwrite32(VMCS_PROC_BASED_VMEXEC_CTL2, ctls);

		asm_vmread32(VMCS_PROC_BASED_VMEXEC_CTL, &ctls);
		ctls |= VMCS_PROC_BASED_VMEXEC_CTL_INVLPGEXIT_BIT;
		asm_vmwrite32(VMCS_PROC_BASED_VMEXEC_CTL, ctls);

		asm_vmwrite32(VMCS_EXCEPTION_BMP, 0xffffffff);
		asm_vmwrite32(VMCS_PAGEFAULT_ERRCODE_MASK, 0);
		asm_vmwrite32(VMCS_PAGEFAULT_ERRCODE_MATCH, 0);

		asm_vmread32(VMCS_VMEXIT_CTL, &ctls);
		ctls &= ~VMCS_VMEXIT_CTL_LOAD_IA32_PAT_BIT;
		asm_vmwrite32(VMCS_VMEXIT_CTL, ctls);

		asm_vmread32(VMCS_VMENTRY_CTL, &ctls);
		ctls &= ~VMCS_VMENTRY_CTL_LOAD_IA32_PAT_BIT;
		asm_vmwrite32(VMCS_VMENTRY_CTL, ctls);
	}
	/*
	 * Rewrite CR4 with its current value.  Presumably this lets
	 * vt_write_control_reg() recompute CR4-dependent guest state for the
	 * new paging mode -- TODO confirm.
	 */
	vt_read_control_reg(CONTROL_REG_CR4, &cr4);
	vt_write_control_reg(CONTROL_REG_CR4, cr4);

	vt_ept_invalidate_tlb();
}

/*
 * Map the guest-physical range [gphys, gphys + len) through EPT.
 *
 * The range, rounded outwards to whole 4K pages, is handled in three parts:
 *   1. head: 4K pages up to the first 2M boundary (or the end of the range);
 *   2. body: whole 2M-aligned chunks, each mapped with one 2M entry when
 *      vt_ept_map_2mpage() accepts it, otherwise as 512 4K pages;
 *   3. tail: 4K pages from the last 2M boundary to the end.
 * Any 4K mapping failure panics.  An empty range maps nothing.
 */
static void
vt_ept_map_memory_region(phys_t gphys, phys_t len)
{
	phys_t cur, cur2;
	phys_t start, end, final_end;
	vmmerr_t err;

	/*
	 * Without this check, an unaligned gphys with len == 0 would still map
	 * the page containing gphys (start rounds down, final_end rounds up).
	 */
	if (len == 0) {
		return;
	}

	final_end = ROUND_UP(gphys + len, PAGESIZE);

	/* Head: 4K pages up to the first 2M boundary. */
	start = ROUND_DOWN(gphys, PAGESIZE);
	end = MIN(ROUND_UP(gphys, PAGESIZE2M), final_end);

	for (cur = start; cur < end; cur += PAGESIZE) {
		err = vt_ept_map_4kpage(cur);
		if (err != VMMERR_SUCCESS) {
			panic("Failed to map 4K page via EPT. 0x%llx", cur);
		}
	}

	/*
	 * Body: whole 2M chunks.  MAX() keeps end from moving backwards when
	 * the region lies inside a single 2M chunk.
	 */
	start = end;
	end = MAX(ROUND_DOWN(gphys + len, PAGESIZE2M), end);

	for (cur = start; cur < end; cur += PAGESIZE2M) {
		if (vt_ept_map_2mpage(cur) == VMMERR_SUCCESS) {
			continue;
		}
		/*
		 * The chunk cannot use a 2M entry (e.g. misaligned or
		 * discontiguous host backing, or mixed memory types), so
		 * map it page by page.
		 */
		for (cur2 = cur; cur2 < cur + PAGESIZE2M; cur2 += PAGESIZE) {
			err = vt_ept_map_4kpage(cur2);
			if (err != VMMERR_SUCCESS) {
				panic("Failed to map 4k page via EPT. 0x%llx", cur2);
			}
		}
	}

	/* Tail: remaining 4K pages after the last 2M boundary. */
	start = end;
	end = final_end;

	for (cur = start; cur < end; cur += PAGESIZE) {
		err = vt_ept_map_4kpage(cur);
		if (err != VMMERR_SUCCESS) {
			panic("Failed to map 4K page via EPT. 0x%llx", cur);
		}
	}
}

/*
 * Walk the guest memory map from gmm_get_mem_map() and map every
 * MEM_TYPE_AVAILABLE region into EPT.  Entries of other types are skipped.
 */
static void
vt_ept_map_all_memory(void)
{
	phys_t	base;
	phys_t	size;
	u32	kind;
	int	i;

	for (i = 0; gmm_get_mem_map(i, &base, &size, &kind, NULL); i++) {
		if (kind == MEM_TYPE_AVAILABLE) {
			vt_ept_map_memory_region(base, size);
		}
	}
}

/*
 * Init hook, registered under "setupvm3": pre-populate EPT with all
 * available guest memory.  Runs only on a VT-virtualized CPU whose vCPU has
 * EPT enabled.
 */
static void
vt_ept_setup_vm(void)
{
	if (currentcpu->fullvirtualize == FULLVIRTUALIZE_VT &&
	    current->u.vt.ept_enabled) {
		vt_ept_map_all_memory();
	}
}

INITFUNC("setupvm3", vt_ept_setup_vm);
