pub(super) mod gdt;
mod idt;
mod syscall;
use align_ext::AlignExt;
use cfg_if::cfg_if;
use log::debug;
use spin::Once;
use super::ex_table::ExTable;
use crate::{
arch::{
if_tdx_enabled,
irq::{disable_local, enable_local},
},
cpu::context::{CpuException, CpuExceptionInfo, PageFaultErrorCode},
cpu_local_cell,
mm::{
kspace::{KERNEL_PAGE_TABLE, LINEAR_MAPPING_BASE_VADDR, LINEAR_MAPPING_VADDR_RANGE},
page_prop::{CachePolicy, PageProperty},
PageFlags, PrivilegedPageFlags as PrivFlags, MAX_USERSPACE_VADDR, PAGE_SIZE,
},
task::disable_preempt,
trap::call_irq_callback_functions,
};
cfg_if! {
if #[cfg(feature = "cvm_guest")] {
use tdx_guest::{tdcall, handle_virtual_exception};
use crate::arch::tdx_guest::TrapFrameWrapper;
}
}
cpu_local_cell! {
    // Per-CPU nesting depth of kernel-mode hardware interrupt handling.
    // Incremented/decremented around IRQ dispatch in `trap_handler`; a
    // non-zero value means the CPU is currently in interrupt context.
    static KERNEL_INTERRUPT_NESTED_LEVEL: u8 = 0;
}
// General-purpose register state saved when a kernel-mode trap occurs.
//
// `#[repr(C)]`: the field order presumably mirrors the push/pop sequence of
// the assembly trap entry stub that hands this frame to `trap_handler` --
// do not reorder fields (confirm against the entry asm).
#[derive(Debug, Default, Clone, Copy)]
#[repr(C)]
#[expect(missing_docs)]
pub struct TrapFrame {
    // Software-saved general-purpose registers.
    pub rax: usize,
    pub rbx: usize,
    pub rcx: usize,
    pub rdx: usize,
    pub rsi: usize,
    pub rdi: usize,
    pub rbp: usize,
    pub rsp: usize,
    pub r8: usize,
    pub r9: usize,
    pub r10: usize,
    pub r11: usize,
    pub r12: usize,
    pub r13: usize,
    pub r14: usize,
    pub r15: usize,
    // Padding slot -- purpose not visible here; presumably keeps the frame
    // aligned with the hardware-pushed part. TODO confirm against the asm.
    pub _pad: usize,
    // Trap vector number and the (possibly synthesized) error code.
    pub trap_num: usize,
    pub error_code: usize,
    // Interrupt return frame: faulting RIP, code segment, and saved RFLAGS
    // (the IF bit of `rflags` is inspected by `trap_handler`).
    pub rip: usize,
    pub cs: usize,
    pub rflags: usize,
}
/// Initializes trap handling on the current CPU: the GDT, the IDT, and the
/// fast-syscall entry point.
///
/// # Safety
///
/// NOTE(review): `gdt::init` and `syscall::init` are unsafe; presumably this
/// must be called exactly once per CPU during early boot, before any trap or
/// syscall can occur -- confirm against the boot sequence and the contracts
/// of the sub-initializers.
pub unsafe fn init() {
    // SAFETY: the caller upholds the (unseen) contract of `gdt::init`.
    unsafe { gdt::init() };
    idt::init();
    // SAFETY: the caller upholds the (unseen) contract of `syscall::init`.
    unsafe { syscall::init() };
}
// User-mode CPU context: the general-purpose registers plus the trap number
// and error code of the event that last brought control back to the kernel.
//
// `#[repr(C)]`: layout is presumably shared with assembly and/or the
// user-return path -- do not reorder fields (confirm against the asm).
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
#[repr(C)]
#[expect(missing_docs)]
pub struct UserContext {
    pub general: GeneralRegs,
    pub trap_num: usize,
    pub error_code: usize,
}
// Complete user-visible general-purpose register file for x86-64, including
// RIP/RFLAGS and the FS/GS segment bases used for thread-local storage.
//
// `#[repr(C)]`: layout presumably matches the user context switch assembly --
// do not reorder fields (confirm against the asm).
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
#[repr(C)]
#[expect(missing_docs)]
pub struct GeneralRegs {
    pub rax: usize,
    pub rbx: usize,
    pub rcx: usize,
    pub rdx: usize,
    pub rsi: usize,
    pub rdi: usize,
    pub rbp: usize,
    pub rsp: usize,
    pub r8: usize,
    pub r9: usize,
    pub r10: usize,
    pub r11: usize,
    pub r12: usize,
    pub r13: usize,
    pub r14: usize,
    pub r15: usize,
    pub rip: usize,
    pub rflags: usize,
    // Segment bases; `fsbase` is what `UserContext::set_tls` writes.
    pub fsbase: usize,
    pub gsbase: usize,
}
impl UserContext {
pub fn get_syscall_num(&self) -> usize {
self.general.rax
}
pub fn get_syscall_ret(&self) -> usize {
self.general.rax
}
pub fn set_syscall_ret(&mut self, ret: usize) {
self.general.rax = ret;
}
pub fn get_syscall_args(&self) -> [usize; 6] {
[
self.general.rdi,
self.general.rsi,
self.general.rdx,
self.general.r10,
self.general.r8,
self.general.r9,
]
}
pub fn set_ip(&mut self, ip: usize) {
self.general.rip = ip;
}
pub fn set_sp(&mut self, sp: usize) {
self.general.rsp = sp;
}
pub fn get_sp(&self) -> usize {
self.general.rsp
}
pub fn set_tls(&mut self, tls: usize) {
self.general.fsbase = tls;
}
}
/// Tells whether the current CPU is inside a kernel-mode interrupt handler
/// (i.e. the per-CPU interrupt nesting counter is non-zero).
pub fn is_kernel_interrupted() -> bool {
    let nesting_level = KERNEL_INTERRUPT_NESTED_LEVEL.load();
    nesting_level > 0
}
// Central dispatcher for all kernel-mode traps (CPU exceptions and external
// interrupts).
//
// `#[no_mangle]` + `extern "sysv64"`: presumably invoked from the assembly
// trap entry stub with the register frame it saved -- confirm against the
// entry asm.
#[no_mangle]
extern "sysv64" fn trap_handler(f: &mut TrapFrame) {
    fn enable_local_if(cond: bool) {
        if cond {
            enable_local();
        }
    }
    fn disable_local_if(cond: bool) {
        if cond {
            disable_local();
        }
    }
    // Whether IRQs were enabled (IF set in the saved RFLAGS) when the trap
    // hit. Exception handling below re-enables IRQs only in that case, and
    // restores the disabled state before returning to the trapped context.
    let was_irq_enabled =
        f.rflags as u64 & x86_64::registers::rflags::RFlags::INTERRUPT_FLAG.bits() > 0;
    match CpuException::to_cpu_exception(f.trap_num as u16) {
        #[cfg(feature = "cvm_guest")]
        Some(CpuException::VIRTUALIZATION_EXCEPTION) => {
            // Fetch the #VE info *before* re-enabling IRQs so a nested #VE
            // cannot overwrite it.
            let ve_info = tdcall::get_veinfo().expect("#VE handler: fail to get VE info\n");
            enable_local_if(was_irq_enabled);
            let mut trapframe_wrapper = TrapFrameWrapper(&mut *f);
            handle_virtual_exception(&mut trapframe_wrapper, &ve_info);
            // Propagate any register updates made by the #VE handler back
            // into the frame that will be restored on return.
            *f = *trapframe_wrapper.0;
            disable_local_if(was_irq_enabled);
        }
        Some(CpuException::PAGE_FAULT) => {
            // Read CR2 *before* re-enabling IRQs: a nested page fault would
            // clobber the faulting address stored there.
            let page_fault_addr = x86_64::registers::control::Cr2::read_raw();
            enable_local_if(was_irq_enabled);
            // The page fault can be caused by the kernel touching either a
            // user-range address or a kernel-range address; dispatch on that.
            if (0..MAX_USERSPACE_VADDR).contains(&(page_fault_addr as usize)) {
                handle_user_page_fault(f, page_fault_addr);
            } else {
                handle_kernel_page_fault(f, page_fault_addr);
            }
            disable_local_if(was_irq_enabled);
        }
        Some(exception) => {
            // Any other CPU exception raised in kernel mode is unrecoverable.
            enable_local_if(was_irq_enabled);
            panic!(
                "cannot handle kernel CPU exception: {:?}, trapframe: {:?}",
                exception, f
            );
        }
        None => {
            // Not a CPU exception, so this is an external interrupt. Track the
            // nesting depth so `is_kernel_interrupted` can report interrupt
            // context while callbacks run.
            KERNEL_INTERRUPT_NESTED_LEVEL.add_assign(1);
            call_irq_callback_functions(f, f.trap_num);
            KERNEL_INTERRUPT_NESTED_LEVEL.sub_assign(1);
        }
    }
}
// Handler for page faults on user-range addresses, registered once via
// `inject_user_page_fault_handler`. Returns `Ok(())` when it resolves the
// fault; `Err(())` falls through to exception-table recovery.
#[expect(clippy::type_complexity)]
static USER_PAGE_FAULT_HANDLER: Once<fn(&CpuExceptionInfo) -> core::result::Result<(), ()>> =
    Once::new();
/// Injects the handler used to resolve page faults on user-range addresses.
///
/// Only the first call takes effect: `Once::call_once` runs its closure at
/// most once, so later registrations are silently ignored.
pub fn inject_user_page_fault_handler(
    handler: fn(info: &CpuExceptionInfo) -> core::result::Result<(), ()>,
) {
    USER_PAGE_FAULT_HANDLER.call_once(|| handler);
}
// Handles a kernel-mode page fault whose faulting address lies in the user
// address range (presumably the kernel touching user memory, e.g. a
// user-copy routine -- confirm with callers of the injected handler).
fn handle_user_page_fault(f: &mut TrapFrame, page_fault_addr: u64) {
    let info = CpuExceptionInfo {
        page_fault_addr: page_fault_addr as usize,
        id: f.trap_num,
        error_code: f.error_code,
    };
    // The upper layer must have registered a handler before such a fault can
    // be taken; a missing handler is a boot-ordering bug.
    let handler = USER_PAGE_FAULT_HANDLER
        .get()
        .expect("a page fault handler is missing");
    let res = handler(&info);
    // The handler resolved the fault (e.g. mapped the page); just retry the
    // faulting instruction on return.
    if res.is_ok() {
        return;
    }
    // Otherwise, if the faulting instruction has an exception-table entry,
    // resume at its registered recovery address instead of panicking.
    if let Some(addr) = ExTable::find_recovery_inst_addr(f.rip) {
        f.rip = addr;
    } else {
        panic!("Cannot handle user page fault; Trapframe:{:#x?}.", f);
    }
}
// Handles a page fault on a kernel-range address by lazily populating the
// linear (direct) mapping of physical memory. Any fault that cannot be fixed
// this way is treated as a kernel bug and panics via the asserts below.
fn handle_kernel_page_fault(f: &TrapFrame, page_fault_vaddr: u64) {
    // The page-table cursor below requires a preemption guard.
    let preempt_guard = disable_preempt();
    let error_code = PageFaultErrorCode::from_bits_truncate(f.error_code);
    debug!(
        "kernel page fault: address {:?}, error code {:?}",
        page_fault_vaddr as *const (), error_code
    );
    // Only faults inside the linear-mapping region are recoverable here.
    assert!(
        LINEAR_MAPPING_VADDR_RANGE.contains(&(page_fault_vaddr as usize)),
        "kernel page fault: the address is outside the range of the linear mapping",
    );
    const SUPPORTED_ERROR_CODES: PageFaultErrorCode = PageFaultErrorCode::PRESENT
        .union(PageFaultErrorCode::WRITE)
        .union(PageFaultErrorCode::INSTRUCTION);
    assert!(
        SUPPORTED_ERROR_CODES.contains(error_code),
        "kernel page fault: the error code is not supported",
    );
    // The linear mapping is data-only: instruction fetches from it are bugs.
    assert!(
        !error_code.contains(PageFaultErrorCode::INSTRUCTION),
        "kernel page fault: the direct mapping cannot be executed",
    );
    // A PRESENT fault means the page is already mapped and was accessed in a
    // disallowed way; lazy population cannot fix that.
    assert!(
        !error_code.contains(PageFaultErrorCode::PRESENT),
        "kernel page fault: the direct mapping already exists",
    );
    let page_table = KERNEL_PAGE_TABLE
        .get()
        .expect("kernel page fault: the kernel page table is not initialized");
    // The linear mapping is a fixed-offset mapping: paddr = vaddr - base.
    let vaddr = (page_fault_vaddr as usize).align_down(PAGE_SIZE);
    let paddr = vaddr - LINEAR_MAPPING_BASE_VADDR;
    // Under TDX the lazily-mapped page is additionally marked SHARED
    // (host-visible); otherwise it is just a global kernel mapping.
    let priv_flags = if_tdx_enabled!({
        PrivFlags::SHARED | PrivFlags::GLOBAL
    } else {
        PrivFlags::GLOBAL
    });
    // NOTE(review): `Uncacheable` suggests lazily-faulted linear mappings are
    // expected to target MMIO/device memory -- confirm with the mm design.
    let prop = PageProperty {
        flags: PageFlags::RW,
        cache: CachePolicy::Uncacheable,
        priv_flags,
    };
    let mut cursor = page_table
        .cursor_mut(&preempt_guard, &(vaddr..vaddr + PAGE_SIZE))
        .unwrap();
    // SAFETY: the one-page range lies within the linear mapping and is mapped
    // to exactly the physical frame the linear mapping assigns to it.
    unsafe { cursor.map(crate::mm::kspace::MappedItem::Untracked(paddr, 1, prop)) }.unwrap();
}