diff --git a/riscv-rt/macros/src/lib.rs b/riscv-rt/macros/src/lib.rs index 2eff2e82..f48f7870 100644 --- a/riscv-rt/macros/src/lib.rs +++ b/riscv-rt/macros/src/lib.rs @@ -313,11 +313,122 @@ pub fn loop_global_asm(input: TokenStream) -> TokenStream { res.parse().unwrap() } +#[derive(Clone, Copy)] enum RiscvArch { Rv32, Rv64, } +const TRAP_SIZE: usize = 16; + +#[rustfmt::skip] +const TRAP_FRAME: [&str; TRAP_SIZE] = [ + "ra", + "t0", + "t1", + "t2", + "t3", + "t4", + "t5", + "t6", + "a0", + "a1", + "a2", + "a3", + "a4", + "a5", + "a6", + "a7", +]; + +fn store_trap bool>(arch: RiscvArch, mut filter: T) -> String { + let (width, store) = match arch { + RiscvArch::Rv32 => (4, "sw"), + RiscvArch::Rv64 => (8, "sd"), + }; + let mut stores = Vec::new(); + for (i, reg) in TRAP_FRAME + .iter() + .enumerate() + .filter(|(_, ®)| filter(reg)) + { + stores.push(format!("{store} {reg}, {i}*{width}(sp)")); + } + stores.join("\n") +} + +fn load_trap(arch: RiscvArch) -> String { + let (width, load) = match arch { + RiscvArch::Rv32 => (4, "lw"), + RiscvArch::Rv64 => (8, "ld"), + }; + let mut loads = Vec::new(); + for (i, reg) in TRAP_FRAME.iter().enumerate() { + loads.push(format!("{load} {reg}, {i}*{width}(sp)")); + } + loads.join("\n") +} + +#[proc_macro] +pub fn weak_start_trap_riscv32(_input: TokenStream) -> TokenStream { + weak_start_trap(RiscvArch::Rv32) +} + +#[proc_macro] +pub fn weak_start_trap_riscv64(_input: TokenStream) -> TokenStream { + weak_start_trap(RiscvArch::Rv64) +} + +fn weak_start_trap(arch: RiscvArch) -> TokenStream { + let width = match arch { + RiscvArch::Rv32 => 4, + RiscvArch::Rv64 => 8, + }; + // ensure we do not break that sp is 16-byte aligned + if (TRAP_SIZE * width) % 16 != 0 { + return parse::Error::new(Span::call_site(), "Trap frame size must be 16-byte aligned") + .to_compile_error() + .into(); + } + let store = store_trap(arch, |_| true); + let load = load_trap(arch); + + #[cfg(feature = "s-mode")] + let ret = "sret"; + #[cfg(not(feature = "s-mode"))] + let ret = "mret"; + + let instructions: proc_macro2::TokenStream = format!( + " +core::arch::global_asm!( +\".section .trap, \\\"ax\\\" +.align {width} +.weak _start_trap +_start_trap: + addi sp, sp, - {TRAP_SIZE} * {width} + {store} + add a0, sp, zero + jal ra, _start_trap_rust + {load} + addi sp, sp, {TRAP_SIZE} * {width} + {ret} +\");" + ) + .parse() + .unwrap(); + + #[cfg(feature = "v-trap")] + let v_trap = v_trap::continue_interrupt_trap(arch); + #[cfg(not(feature = "v-trap"))] + let v_trap = proc_macro2::TokenStream::new(); + + quote!( + #instructions + #v_trap + ) + .into() +} + #[proc_macro_attribute] pub fn interrupt_riscv32(args: TokenStream, input: TokenStream) -> TokenStream { interrupt(args, input, RiscvArch::Rv32) @@ -376,7 +487,7 @@ fn interrupt(args: TokenStream, input: TokenStream, _arch: RiscvArch) -> TokenSt #[cfg(not(feature = "v-trap"))] let start_trap = proc_macro2::TokenStream::new(); #[cfg(feature = "v-trap")] - let start_trap = v_trap::start_interrupt_trap_asm(ident, _arch); + let start_trap = v_trap::start_interrupt_trap(ident, _arch); quote!( #start_trap @@ -390,45 +501,41 @@ fn interrupt(args: TokenStream, input: TokenStream, _arch: RiscvArch) -> TokenSt mod v_trap { use super::*; - const TRAP_SIZE: usize = 16; - - #[rustfmt::skip] - const TRAP_FRAME: [&str; TRAP_SIZE] = [ - "ra", - "t0", - "t1", - "t2", - "t3", - "t4", - "t5", - "t6", - "a0", - "a1", - "a2", - "a3", - "a4", - "a5", - "a6", - "a7", - ]; - - pub(crate) fn start_interrupt_trap_asm( + pub(crate) fn start_interrupt_trap( ident: &syn::Ident, arch: RiscvArch, ) -> proc_macro2::TokenStream { - let function = ident.to_string(); - let (width, store, load) = match arch { - RiscvArch::Rv32 => (4, "sw", "lw"), - RiscvArch::Rv64 => (8, "sd", "ld"), + let interrupt = ident.to_string(); + let width = match arch { + RiscvArch::Rv32 => 4, + RiscvArch::Rv64 => 8, }; + let store = store_trap(arch, |r| r == "a0"); - let (mut stores, mut loads) = (Vec::new(), Vec::new()); - for (i, r) in TRAP_FRAME.iter().enumerate() { - stores.push(format!(" {store} {r}, {i}*{width}(sp)")); - loads.push(format!(" {load} {r}, {i}*{width}(sp)")); - } - let store = stores.join("\n"); - let load = loads.join("\n"); + let instructions = format!( + " +core::arch::global_asm!( + \".section .trap, \\\"ax\\\" + .align {width} + .global _start_{interrupt}_trap + _start_{interrupt}_trap: + addi sp, sp, -{TRAP_SIZE} * {width} // allocate space for trap frame + {store} // store trap partially (only register a0) + la a0, {interrupt} // load interrupt handler address into a0 + j _continue_interrupt_trap // jump to common part of interrupt trap +\");" + ); + + instructions.parse().unwrap() + } + + pub(crate) fn continue_interrupt_trap(arch: RiscvArch) -> proc_macro2::TokenStream { + let width = match arch { + RiscvArch::Rv32 => 4, + RiscvArch::Rv64 => 8, + }; + let store = store_trap(arch, |reg| reg != "a0"); + let load = load_trap(arch); #[cfg(feature = "s-mode")] let ret = "sret"; @@ -439,16 +546,15 @@ mod v_trap { " core::arch::global_asm!( \".section .trap, \\\"ax\\\" - .align {width} - .global _start_{function}_trap - _start_{function}_trap: - addi sp, sp, - {TRAP_SIZE} * {width} -{store} - call {function} -{load} - addi sp, sp, {TRAP_SIZE} * {width} - {ret}\" -);" + .align {width} // TODO is this necessary? + .global _continue_interrupt_trap + _continue_interrupt_trap: + {store} // store trap partially (all registers except a0) + jalr ra, a0, 0 // jump to corresponding interrupt handler (address stored in a0) + {load} // restore trap frame + addi sp, sp, {TRAP_SIZE} * {width} // deallocate space for trap frame + {ret} // return from interrupt +\");" ); instructions.parse().unwrap() diff --git a/riscv-rt/src/asm.rs b/riscv-rt/src/asm.rs index d8f98e51..384b2b03 100644 --- a/riscv-rt/src/asm.rs +++ b/riscv-rt/src/asm.rs @@ -277,65 +277,10 @@ _pre_init_trap: j _pre_init_trap", ); -/// Trap entry point (_start_trap). It saves caller saved registers, calls -/// _start_trap_rust, restores caller saved registers and then returns. -/// -/// # Usage -/// -/// The macro takes 5 arguments: -/// - `$STORE`: the instruction used to store a register in the stack (e.g. `sd` for riscv64) -/// - `$LOAD`: the instruction used to load a register from the stack (e.g. `ld` for riscv64) -/// - `$BYTES`: the number of bytes used to store a register (e.g. 8 for riscv64) -/// - `$TRAP_SIZE`: the number of registers to store in the stack (e.g. 32 for all the user registers) -/// - list of tuples of the form `($REG, $LOCATION)`, where: -/// - `$REG`: the register to store/load -/// - `$LOCATION`: the location in the stack where to store/load the register -#[rustfmt::skip] -macro_rules! trap_handler { - ($STORE:ident, $LOAD:ident, $BYTES:literal, $TRAP_SIZE:literal, [$(($REG:ident, $LOCATION:literal)),*]) => { - // ensure we do not break that sp is 16-byte aligned - const _: () = assert!(($TRAP_SIZE * $BYTES) % 16 == 0); - global_asm!( - " - .section .trap, \"ax\" - .weak _start_trap - _start_trap:", - // save space for trap handler in stack - concat!("addi sp, sp, -", stringify!($TRAP_SIZE * $BYTES)), - // save registers in the desired order - $(concat!(stringify!($STORE), " ", stringify!($REG), ", ", stringify!($LOCATION * $BYTES), "(sp)"),)* - // call rust trap handler - "add a0, sp, zero - jal ra, _start_trap_rust", - // restore registers in the desired order - $(concat!(stringify!($LOAD), " ", stringify!($REG), ", ", stringify!($LOCATION * $BYTES), "(sp)"),)* - // free stack - concat!("addi sp, sp, ", stringify!($TRAP_SIZE * $BYTES)), - ); - cfg_global_asm!( - // return from trap - #[cfg(feature = "s-mode")] - "sret", - #[cfg(not(feature = "s-mode"))] - "mret", - ); - }; -} - -#[rustfmt::skip] #[cfg(riscv32)] -trap_handler!( - sw, lw, 4, 16, - [(ra, 0), (t0, 1), (t1, 2), (t2, 3), (t3, 4), (t4, 5), (t5, 6), (t6, 7), - (a0, 8), (a1, 9), (a2, 10), (a3, 11), (a4, 12), (a5, 13), (a6, 14), (a7, 15)] -); -#[rustfmt::skip] +riscv_rt_macros::weak_start_trap_riscv32!(); #[cfg(riscv64)] -trap_handler!( - sd, ld, 8, 16, - [(ra, 0), (t0, 1), (t1, 2), (t2, 3), (t3, 4), (t4, 5), (t5, 6), (t6, 7), - (a0, 8), (a1, 9), (a2, 10), (a3, 11), (a4, 12), (a5, 13), (a6, 14), (a7, 15)] -); +riscv_rt_macros::weak_start_trap_riscv64!(); #[cfg(feature = "v-trap")] cfg_global_asm!( @@ -345,7 +290,7 @@ cfg_global_asm!( .type _vector_table, @function .option push - .balign 0x100 // TODO check if this is the correct alignment + .balign 0x4 // TODO check if this is the correct alignment .option norelax .option norvc