diff --git a/Cargo.lock b/Cargo.lock index 53bedfee..6ff8e273 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -458,6 +458,16 @@ version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.20" @@ -584,6 +594,7 @@ dependencies = [ "log", "polkavm-assembler", "proptest", + "spin", ] [[package]] @@ -868,6 +879,12 @@ dependencies = [ "hashbrown 0.13.2", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "sdl2" version = "0.35.2" @@ -956,6 +973,15 @@ dependencies = [ "serde_json", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index 01ab5657..03b45ac0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,7 @@ ruzstd = { version = "0.4.0", default-features = false } schnellru = { version = "0.2.3" } serde = { version = "1.0.203", features = ["derive"] } serde_json = { version = "1.0.117" } +spin = { version = "0.9.8", default-features = false, features = ["lock_api", "spin_mutex", "rwlock", "lazy"] } syn = "2.0.25" yansi = "0.5.1" diff --git a/crates/polkavm-common/Cargo.toml b/crates/polkavm-common/Cargo.toml index 612afc0c..2efd5f8e 100644 --- a/crates/polkavm-common/Cargo.toml +++ b/crates/polkavm-common/Cargo.toml @@ -12,6 +12,7 @@ description = 
"The common crate for PolkaVM" log = { workspace = true, optional = true } polkavm-assembler = { workspace = true, optional = true } blake3 = { workspace = true, optional = true } +spin = { workspace = true } [features] default = [] diff --git a/crates/polkavm-common/src/program.rs b/crates/polkavm-common/src/program.rs index d10f2fc2..a179a4a7 100644 --- a/crates/polkavm-common/src/program.rs +++ b/crates/polkavm-common/src/program.rs @@ -4,6 +4,13 @@ use crate::varint::{read_simple_varint, read_varint, write_simple_varint, write_ use core::fmt::Write; use core::ops::Range; +#[cfg(feature = "unique-id")] +use spin::RwLock; +#[cfg(feature = "unique-id")] +struct UniqueId(u64); +#[cfg(feature = "unique-id")] +static ID_COUNTER: RwLock<UniqueId> = RwLock::new(UniqueId(0)); + #[derive(Copy, Clone)] #[repr(transparent)] pub struct RawReg(u32); @@ -3980,8 +3987,10 @@ impl ProgramBlob { #[cfg(feature = "unique-id")] { - static ID_COUNTER: core::sync::atomic::AtomicU64 = core::sync::atomic::AtomicU64::new(0); - blob.unique_id = ID_COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed); + let mut counter = ID_COUNTER.write(); + *counter = UniqueId(counter.0 + 1); + blob.unique_id = counter.0; + // The lock is dropped here. } Ok(blob) diff --git a/crates/polkavm-common/src/zygote.rs b/crates/polkavm-common/src/zygote.rs index 28b51c0a..269136d6 100644 --- a/crates/polkavm-common/src/zygote.rs +++ b/crates/polkavm-common/src/zygote.rs @@ -4,7 +4,7 @@ //! is recompiled. use core::cell::UnsafeCell; -use core::sync::atomic::{AtomicBool, AtomicI64, AtomicU32, AtomicU64}; +use core::sync::atomic::{AtomicBool, AtomicU32}; // Due to the limitations of Rust's compile time constant evaluation machinery // we need to define this struct multiple times. 
@@ -137,24 +137,24 @@ pub const VM_SANDBOX_MAXIMUM_NATIVE_CODE_SIZE: u32 = 2176 * 1024 * 1024 - 1; #[repr(C)] pub struct JmpBuf { - pub rip: AtomicU64, - pub rbx: AtomicU64, - pub rsp: AtomicU64, - pub rbp: AtomicU64, - pub r12: AtomicU64, - pub r13: AtomicU64, - pub r14: AtomicU64, - pub r15: AtomicU64, + pub rip: u64, + pub rbx: u64, + pub rsp: u64, + pub rbp: u64, + pub r12: u64, + pub r13: u64, + pub r14: u64, + pub r15: u64, } #[repr(C)] pub struct VmInit { - pub stack_address: AtomicU64, - pub stack_length: AtomicU64, - pub vdso_address: AtomicU64, - pub vdso_length: AtomicU64, - pub vvar_address: AtomicU64, - pub vvar_length: AtomicU64, + pub stack_address: u64, + pub stack_length: u64, + pub vdso_address: u64, + pub vdso_length: u64, + pub vvar_address: u64, + pub vvar_length: u64, /// Whether userfaultfd-based memory management is available. pub uffd_available: AtomicBool, @@ -230,7 +230,7 @@ pub struct VmCtx { _align_1: CacheAligned<()>, /// The current gas counter. - pub gas: AtomicI64, + pub gas: i64, _align_2: CacheAligned<()>, @@ -238,7 +238,7 @@ pub struct VmCtx { pub futex: AtomicU32, /// Address to which to jump to. - pub jump_into: AtomicU64, + pub jump_into: u64, /// The address of the instruction currently being executed. pub program_counter: AtomicU32, @@ -253,10 +253,10 @@ pub struct VmCtx { pub arg: AtomicU32, /// A dump of all of the registers of the VM. - pub regs: [AtomicU64; REG_COUNT], + pub regs: [u64; REG_COUNT], /// The address of the native code to call inside of the VM process, if non-zero. - pub next_native_program_counter: AtomicU64, + pub next_native_program_counter: u64, /// The state of the program's heap. pub heap_info: VmCtxHeapInfo, @@ -264,20 +264,20 @@ pub struct VmCtx { pub arg2: AtomicU32, /// Offset in shared memory to this sandbox's memory map. - pub shm_memory_map_offset: AtomicU64, + pub shm_memory_map_offset: u64, /// Number of maps to map. 
- pub shm_memory_map_count: AtomicU64, + pub shm_memory_map_count: u64, /// Offset in shared memory to this sandbox's code. - pub shm_code_offset: AtomicU64, + pub shm_code_offset: u64, /// Length this sandbox's code. - pub shm_code_length: AtomicU64, + pub shm_code_length: u64, /// Offset in shared memory to this sandbox's jump table. - pub shm_jump_table_offset: AtomicU64, + pub shm_jump_table_offset: u64, /// Length of sandbox's jump table, in bytes. - pub shm_jump_table_length: AtomicU64, + pub shm_jump_table_length: u64, /// Address of the sysreturn routine. - pub sysreturn_address: AtomicU64, + pub sysreturn_address: u64, /// Whether userfaultfd-based memory management is enabled. pub uffd_enabled: AtomicBool, @@ -328,7 +328,7 @@ pub const VMCTX_FUTEX_GUEST_SIGNAL: u32 = VMCTX_FUTEX_IDLE | (3 << 1); pub const VMCTX_FUTEX_GUEST_STEP: u32 = VMCTX_FUTEX_IDLE | (4 << 1); #[allow(clippy::declare_interior_mutable_const)] -const ATOMIC_U64_ZERO: AtomicU64 = AtomicU64::new(0); +const ATOMIC_U64_ZERO: u64 = 0; #[allow(clippy::new_without_default)] impl VmCtx { @@ -338,25 +338,25 @@ impl VmCtx { _align_1: CacheAligned(()), _align_2: CacheAligned(()), - gas: AtomicI64::new(0), + gas: 0, program_counter: AtomicU32::new(0), next_program_counter: AtomicU32::new(0), arg: AtomicU32::new(0), arg2: AtomicU32::new(0), regs: [ATOMIC_U64_ZERO; REG_COUNT], - jump_into: AtomicU64::new(0), - next_native_program_counter: AtomicU64::new(0), + jump_into: 0, + next_native_program_counter: 0, futex: AtomicU32::new(VMCTX_FUTEX_BUSY), - shm_memory_map_offset: AtomicU64::new(0), - shm_memory_map_count: AtomicU64::new(0), - shm_code_offset: AtomicU64::new(0), - shm_code_length: AtomicU64::new(0), - shm_jump_table_offset: AtomicU64::new(0), - shm_jump_table_length: AtomicU64::new(0), + shm_memory_map_offset: 0, + shm_memory_map_count: 0, + shm_code_offset: 0, + shm_code_length: 0, + shm_jump_table_offset: 0, + shm_jump_table_length: 0, uffd_enabled: AtomicBool::new(false), - 
sysreturn_address: AtomicU64::new(0), + sysreturn_address: 0, heap_base: UnsafeCell::new(0), heap_initial_threshold: UnsafeCell::new(0), heap_max_size: UnsafeCell::new(0), @@ -373,24 +373,24 @@ impl VmCtx { }), init: VmInit { - stack_address: AtomicU64::new(0), - stack_length: AtomicU64::new(0), - vdso_address: AtomicU64::new(0), - vdso_length: AtomicU64::new(0), - vvar_address: AtomicU64::new(0), - vvar_length: AtomicU64::new(0), + stack_address: 0, + stack_length: 0, + vdso_address: 0, + vdso_length: 0, + vvar_address: 0, + vvar_length: 0, uffd_available: AtomicBool::new(false), sandbox_disabled: AtomicBool::new(false), logging_enabled: AtomicBool::new(false), idle_regs: JmpBuf { - rip: AtomicU64::new(0), - rbx: AtomicU64::new(0), - rsp: AtomicU64::new(0), - rbp: AtomicU64::new(0), - r12: AtomicU64::new(0), - r13: AtomicU64::new(0), - r14: AtomicU64::new(0), - r15: AtomicU64::new(0), + rip: 0, + rbx: 0, + rsp: 0, + rbp: 0, + r12: 0, + r13: 0, + r14: 0, + r15: 0, }, }, diff --git a/crates/polkavm-zygote/Cargo.lock b/crates/polkavm-zygote/Cargo.lock index e37eb010..3f8d2879 100644 --- a/crates/polkavm-zygote/Cargo.lock +++ b/crates/polkavm-zygote/Cargo.lock @@ -2,6 +2,22 @@ # It is not intended for manual editing. 
version = 3 +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.22" @@ -20,6 +36,7 @@ name = "polkavm-common" version = "0.15.0" dependencies = [ "polkavm-assembler", + "spin", ] [[package]] @@ -33,3 +50,18 @@ dependencies = [ "polkavm-common", "polkavm-linux-raw", ] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] diff --git a/crates/polkavm-zygote/src/main.rs b/crates/polkavm-zygote/src/main.rs index cfc9e647..bbfabb0f 100644 --- a/crates/polkavm-zygote/src/main.rs +++ b/crates/polkavm-zygote/src/main.rs @@ -4,7 +4,7 @@ use core::ptr::addr_of_mut; use core::sync::atomic::Ordering; -use core::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize}; +use core::sync::atomic::{AtomicBool, AtomicUsize}; #[rustfmt::skip] use polkavm_common::{ @@ -353,14 +353,14 @@ unsafe extern "C" fn signal_handler(signal: u32, _info: &linux_raw::siginfo_t, c } static mut RESUME_MAIN_LOOP_JMPBUF: JmpBuf = JmpBuf { - rip: AtomicU64::new(0), - rbx: AtomicU64::new(0), - rsp: AtomicU64::new(0), - rbp: AtomicU64::new(0), - r12: AtomicU64::new(0), - r13: AtomicU64::new(0), - r14: AtomicU64::new(0), - r15: AtomicU64::new(0), + rip: 0, + rbx: 0, + rsp: 0, + rbp: 0, + r12: 0, + r13: 0, + r14: 0, + 
r15: 0, }; extern "C" { diff --git a/crates/polkavm/src/sandbox/linux.rs b/crates/polkavm/src/sandbox/linux.rs index d74cc884..ea2a5abe 100644 --- a/crates/polkavm/src/sandbox/linux.rs +++ b/crates/polkavm/src/sandbox/linux.rs @@ -32,6 +32,7 @@ use crate::config::GasMeteringKind; use crate::page_set::PageSet; use crate::shm_allocator::{ShmAllocation, ShmAllocator}; use crate::{Gas, InterruptKind, ProgramCounter, RegValue, Segfault}; +use std::sync::RwLock; pub struct GlobalState { shared_memory: ShmAllocator, @@ -972,6 +973,7 @@ enum SandboxState { pub struct Sandbox { _lifetime_pipe: Fd, + vmctx: RwLock>, vmctx_mmap: Mmap, memory_mmap: Mmap, iouring: Option, @@ -997,22 +999,24 @@ pub struct Sandbox { idle_regs: linux_raw::user_regs_struct, } -impl Drop for Sandbox { +impl<'a> Drop for Sandbox { fn drop(&mut self) { - let vmctx = self.vmctx(); - let child_futex_wait = unsafe { *vmctx.counters.syscall_futex_wait.get() }; - let child_loop_start = unsafe { *vmctx.counters.syscall_wait_loop_start.get() }; log::debug!( "Host futex wait count: {}/{} ({:.02}%)", self.count_futex_wait, self.count_wait_loop_start, self.count_futex_wait as f64 / self.count_wait_loop_start as f64 * 100.0 ); + + let vmctx = &mut self.vmctx.write().unwrap(); + let child_futex_wait = vmctx.counters.syscall_futex_wait.get(); + let child_loop_start = vmctx.counters.syscall_wait_loop_start.get(); + log::debug!( "Child futex wait count: {}/{} ({:.02}%)", - child_futex_wait, - child_loop_start, - child_futex_wait as f64 / child_loop_start as f64 * 100.0 + unsafe { *child_futex_wait }, + unsafe { *child_loop_start }, + unsafe { *child_futex_wait as f64 }/ unsafe { *child_loop_start } as f64 * 100.0 ); } } @@ -1219,10 +1223,10 @@ impl super::Sandbox for Sandbox { let (memory_memfd, memory_mmap) = prepare_memory()?; let (vmctx_memfd, vmctx_mmap) = prepare_vmctx()?; - let vmctx = unsafe { &*vmctx_mmap.as_ptr().cast::() }; - vmctx.init.logging_enabled.store(config.enable_logger, Ordering::Relaxed); - 
vmctx.init.uffd_available.store(global.uffd_available, Ordering::Relaxed); - vmctx.init.sandbox_disabled.store(cfg!(polkavm_dev_debug_zygote), Ordering::Relaxed); + let vmctx = RwLock::new(unsafe { Box::from_raw(vmctx_mmap.as_ptr() as *mut VmCtx) }); + vmctx.write().unwrap().init.logging_enabled.store(config.enable_logger, Ordering::Relaxed); + vmctx.write().unwrap().init.uffd_available.store(global.uffd_available, Ordering::Relaxed); + vmctx.write().unwrap().init.sandbox_disabled.store(cfg!(polkavm_dev_debug_zygote), Ordering::Relaxed); let sandbox_flags = if !cfg!(polkavm_dev_debug_zygote) { u64::from( @@ -1297,9 +1301,7 @@ impl super::Sandbox for Sandbox { abort(); } Err(error) => { - let vmctx = &*vmctx_mmap.as_ptr().cast::(); - set_message(vmctx, format_args!("fatal error while spawning child: {error}")); - + set_message(&vmctx.write().unwrap(), format_args!("fatal error while spawning child: {error}")); abort(); } } @@ -1465,7 +1467,7 @@ impl super::Sandbox for Sandbox { } // Wait until the child process receives the vmctx memfd. - wait_for_futex(vmctx, &mut child, VMCTX_FUTEX_BUSY, VMCTX_FUTEX_IDLE)?; + wait_for_futex(&vmctx.write().unwrap(), &mut child, VMCTX_FUTEX_BUSY, VMCTX_FUTEX_IDLE)?; // Grab the child process' maps and see what we can unmap. 
// @@ -1481,16 +1483,16 @@ impl super::Sandbox for Sandbox { let map = Map::parse(line).ok_or_else(|| Error::from_str("failed to parse the maps of the child process"))?; match map.name { b"[stack]" => { - vmctx.init.stack_address.store(map.start, Ordering::Relaxed); - vmctx.init.stack_length.store(map.end - map.start, Ordering::Relaxed); + vmctx.write().unwrap().init.stack_address = map.start; + vmctx.write().unwrap().init.stack_length = map.end - map.start; } b"[vdso]" => { - vmctx.init.vdso_address.store(map.start, Ordering::Relaxed); - vmctx.init.vdso_length.store(map.end - map.start, Ordering::Relaxed); + vmctx.write().unwrap().init.vdso_address = map.start; + vmctx.write().unwrap().init.vdso_length = map.end - map.start; } b"[vvar]" => { - vmctx.init.vvar_address.store(map.start, Ordering::Relaxed); - vmctx.init.vvar_length.store(map.end - map.start, Ordering::Relaxed); + vmctx.write().unwrap().init.vvar_address = map.start; + vmctx.write().unwrap().init.vvar_length = map.end - map.start; } b"[vsyscall]" => { if map.is_readable { @@ -1502,15 +1504,15 @@ impl super::Sandbox for Sandbox { } // Wake the child so that it finishes initialization. - vmctx.futex.store(VMCTX_FUTEX_BUSY, Ordering::Release); - linux_raw::sys_futex_wake_one(&vmctx.futex)?; + vmctx.write().unwrap().futex.store(VMCTX_FUTEX_BUSY, Ordering::Release); + linux_raw::sys_futex_wake_one(&vmctx.write().unwrap().futex)?; let (iouring, userfaultfd) = if global.uffd_available { let iouring = linux_raw::IoUring::new(3)?; let userfaultfd = linux_raw::recvfd(socket.borrow()).map_err(|error| { let mut error = format!("failed to fetch the userfaultfd from the child process: {error}"); - if let Some(message) = get_message(vmctx) { + if let Some(message) = get_message(&vmctx.write().unwrap()) { use core::fmt::Write; write!(&mut error, " (root cause: {message})").unwrap(); } @@ -1537,7 +1539,7 @@ impl super::Sandbox for Sandbox { socket.close()?; // Wait for the child to finish initialization. 
- wait_for_futex(vmctx, &mut child, VMCTX_FUTEX_BUSY, VMCTX_FUTEX_IDLE)?; + wait_for_futex(&vmctx.write().unwrap(), &mut child, VMCTX_FUTEX_BUSY, VMCTX_FUTEX_IDLE)?; let mut idle_regs = linux_raw::user_regs_struct::default(); if global.uffd_available { @@ -1556,24 +1558,25 @@ impl super::Sandbox for Sandbox { linux_raw::sys_ptrace_continue(child.pid, None)?; // Then grab the worker's idle longjmp registers. - vmctx.jump_into.store(ZYGOTE_TABLES.1.ext_fetch_idle_regs, Ordering::Relaxed); - vmctx.futex.store(VMCTX_FUTEX_BUSY, Ordering::Release); - linux_raw::sys_futex_wake_one(&vmctx.futex)?; - wait_for_futex(vmctx, &mut child, VMCTX_FUTEX_BUSY, VMCTX_FUTEX_IDLE)?; + vmctx.write().unwrap().jump_into = ZYGOTE_TABLES.1.ext_fetch_idle_regs; + vmctx.write().unwrap().futex.store(VMCTX_FUTEX_BUSY, Ordering::Release); + linux_raw::sys_futex_wake_one(&vmctx.write().unwrap().futex)?; + wait_for_futex(&vmctx.write().unwrap(), &mut child, VMCTX_FUTEX_BUSY, VMCTX_FUTEX_IDLE)?; idle_regs.rax = 1; - idle_regs.rip = vmctx.init.idle_regs.rip.load(Ordering::Relaxed); - idle_regs.rbx = vmctx.init.idle_regs.rbx.load(Ordering::Relaxed); - idle_regs.sp = vmctx.init.idle_regs.rsp.load(Ordering::Relaxed); - idle_regs.rbp = vmctx.init.idle_regs.rbp.load(Ordering::Relaxed); - idle_regs.r12 = vmctx.init.idle_regs.r12.load(Ordering::Relaxed); - idle_regs.r13 = vmctx.init.idle_regs.r13.load(Ordering::Relaxed); - idle_regs.r14 = vmctx.init.idle_regs.r14.load(Ordering::Relaxed); - idle_regs.r15 = vmctx.init.idle_regs.r15.load(Ordering::Relaxed); + idle_regs.rip = vmctx.write().unwrap().init.idle_regs.rip; + idle_regs.rbx = vmctx.write().unwrap().init.idle_regs.rbx; + idle_regs.sp = vmctx.write().unwrap().init.idle_regs.rsp; + idle_regs.rbp = vmctx.write().unwrap().init.idle_regs.rbp; + idle_regs.r12 = vmctx.write().unwrap().init.idle_regs.r12; + idle_regs.r13 = vmctx.write().unwrap().init.idle_regs.r13; + idle_regs.r14 = vmctx.write().unwrap().init.idle_regs.r14; + idle_regs.r15 = 
vmctx.write().unwrap().init.idle_regs.r15; } Ok(Sandbox { _lifetime_pipe: lifetime_pipe_host, + vmctx, vmctx_mmap, memory_mmap, iouring, @@ -1648,9 +1651,7 @@ impl super::Sandbox for Sandbox { }; } - self.vmctx() - .shm_memory_map_count - .store(program.memory_map.len() as u64, Ordering::Relaxed); + self.vmctx.write().unwrap().shm_memory_map_count = program.memory_map.len() as u64; memory_map } else { let Some(memory_map) = global.shared_memory.alloc(core::mem::size_of::()) else { @@ -1667,42 +1668,34 @@ impl super::Sandbox for Sandbox { fd_offset: 0x10000, }; - self.vmctx().shm_memory_map_count.store(1, Ordering::Relaxed); + self.vmctx.write().unwrap().shm_memory_map_count = 1; memory_map }; - self.vmctx() - .shm_memory_map_offset - .store(memory_map.offset() as u64, Ordering::Relaxed); + self.vmctx.write().unwrap().shm_memory_map_offset = memory_map.offset() as u64; unsafe { - *self.vmctx().heap_info.heap_top.get() = u64::from(module.memory_map().heap_base()); - *self.vmctx().heap_info.heap_threshold.get() = u64::from(module.memory_map().rw_data_range().end); - *self.vmctx().heap_base.get() = module.memory_map().heap_base(); - *self.vmctx().heap_initial_threshold.get() = module.memory_map().rw_data_range().end; - *self.vmctx().heap_max_size.get() = module.memory_map().max_heap_size(); - *self.vmctx().page_size.get() = module.memory_map().page_size(); - } - - self.vmctx() - .shm_code_offset - .store(program.shm_code.offset() as u64, Ordering::Relaxed); - self.vmctx().shm_code_length.store(program.shm_code.len() as u64, Ordering::Relaxed); - self.vmctx() - .shm_jump_table_offset - .store(program.shm_jump_table.offset() as u64, Ordering::Relaxed); - self.vmctx() - .shm_jump_table_length - .store(program.shm_jump_table.len() as u64, Ordering::Relaxed); - self.vmctx().sysreturn_address.store(program.sysreturn_address, Ordering::Relaxed); - - self.vmctx().program_counter.store(0, Ordering::Relaxed); - self.vmctx().next_program_counter.store(0, Ordering::Relaxed); - 
self.vmctx().next_native_program_counter.store(0, Ordering::Relaxed); - self.vmctx().jump_into.store(ZYGOTE_TABLES.1.ext_load_program, Ordering::Relaxed); - self.vmctx().gas.store(0, Ordering::Relaxed); - for reg in &self.vmctx().regs { - reg.store(0, Ordering::Relaxed); + *self.vmctx.write().unwrap().heap_info.heap_top.get() = u64::from(module.memory_map().heap_base()); + *self.vmctx.write().unwrap().heap_info.heap_threshold.get() = u64::from(module.memory_map().rw_data_range().end); + *self.vmctx.write().unwrap().heap_base.get() = module.memory_map().heap_base(); + *self.vmctx.write().unwrap().heap_initial_threshold.get() = module.memory_map().rw_data_range().end; + *self.vmctx.write().unwrap().heap_max_size.get() = module.memory_map().max_heap_size(); + *self.vmctx.write().unwrap().page_size.get() = module.memory_map().page_size(); + } + + self.vmctx.write().unwrap().shm_code_offset = program.shm_code.offset() as u64; + self.vmctx.write().unwrap().shm_code_length = program.shm_code.len() as u64; + self.vmctx.write().unwrap().shm_jump_table_offset = program.shm_jump_table.offset() as u64; + self.vmctx.write().unwrap().shm_jump_table_length = program.shm_jump_table.len() as u64; + self.vmctx.write().unwrap().sysreturn_address = program.sysreturn_address; + + self.vmctx.write().unwrap().program_counter.store(0, Ordering::Relaxed); + self.vmctx.write().unwrap().next_program_counter.store(0, Ordering::Relaxed); + self.vmctx.write().unwrap().next_native_program_counter = 0; + self.vmctx.write().unwrap().jump_into = ZYGOTE_TABLES.1.ext_load_program; + self.vmctx.write().unwrap().gas = 0; + for reg in &mut self.vmctx.write().unwrap().regs { + *reg = 0; } self.dynamic_paging_enabled = module.is_dynamic_paging(); @@ -1743,7 +1736,7 @@ impl super::Sandbox for Sandbox { self.cancel_pagefault()?; Ok(()) } else { - self.vmctx().jump_into.store(ZYGOTE_TABLES.1.ext_recycle, Ordering::Relaxed); + self.vmctx.write().unwrap().jump_into = ZYGOTE_TABLES.1.ext_recycle; 
self.wake_oneshot_and_expect_idle() } } @@ -1762,27 +1755,29 @@ impl super::Sandbox for Sandbox { let Some(address) = compiled_module.lookup_native_code_address(pc) else { log::debug!("Tried to call into {pc} which doesn't have any native code associated with it"); self.is_program_counter_valid = true; - self.vmctx().program_counter.store(pc.0, Ordering::Relaxed); + + let vmctx = &mut self.vmctx.write().unwrap(); + vmctx.program_counter.store(pc.0, Ordering::Relaxed); if self.module.as_ref().unwrap().is_step_tracing() { - self.vmctx().next_program_counter.store(pc.0, Ordering::Relaxed); - self.vmctx() - .next_native_program_counter - .store(compiled_module.invalid_code_offset_address, Ordering::Relaxed); + vmctx.next_program_counter.store(pc.0, Ordering::Relaxed); + vmctx.next_native_program_counter = compiled_module.invalid_code_offset_address; return Ok(InterruptKind::Step); } else { - self.vmctx().next_native_program_counter.store(0, Ordering::Relaxed); + vmctx.next_native_program_counter = 0; return Ok(InterruptKind::Trap); } }; + let vmctx = &mut self.vmctx.write().unwrap(); log::trace!("Jumping into: {pc} (0x{address:x})"); - self.vmctx().next_program_counter.store(pc.0, Ordering::Relaxed); - self.vmctx().next_native_program_counter.store(address, Ordering::Relaxed); + vmctx.next_program_counter.store(pc.0, Ordering::Relaxed); + vmctx.next_native_program_counter = address; } else { + let vmctx = self.vmctx.write().unwrap(); log::trace!( "Resuming into: {} (0x{:x})", - self.vmctx().next_program_counter.load(Ordering::Relaxed), - self.vmctx().next_native_program_counter.load(Ordering::Relaxed) + vmctx.next_program_counter.load(Ordering::Relaxed), + vmctx.next_native_program_counter ); }; @@ -1799,19 +1794,19 @@ impl super::Sandbox for Sandbox { ); linux_raw::sys_ptrace_continue(self.child.pid, None)?; } else { + let vmctx = &mut self.vmctx.write().unwrap(); let compiled_module = Self::downcast_module(self.module.as_ref().unwrap()); - 
debug_assert_eq!(self.vmctx().futex.load(Ordering::Relaxed) & 1, VMCTX_FUTEX_IDLE); - self.vmctx() - .jump_into - .store(compiled_module.sandbox_program.0.sysenter_address, Ordering::Relaxed); + debug_assert_eq!(vmctx.futex.load(Ordering::Relaxed) & 1, VMCTX_FUTEX_IDLE); + vmctx.jump_into = compiled_module.sandbox_program.0.sysenter_address; self.wake_worker()?; self.is_program_counter_valid = true; } let result = self.wait()?; if self.module.as_ref().unwrap().gas_metering() == Some(GasMeteringKind::Async) && self.gas() < 0 { + let vmctx = &mut self.vmctx.write().unwrap(); self.is_program_counter_valid = false; - self.vmctx().next_native_program_counter.store(0, Ordering::Relaxed); + vmctx.next_native_program_counter = 0; return Ok(InterruptKind::NotEnoughGas); } @@ -1835,7 +1830,8 @@ impl super::Sandbox for Sandbox { } fn reg(&self, reg: Reg) -> RegValue { - let mut value = self.vmctx().regs[reg as usize].load(Ordering::Relaxed); + let vmctx = self.vmctx.write().unwrap(); + let mut value = vmctx.regs[reg as usize]; let compiled_module = Self::downcast_module(self.module.as_ref().unwrap()); if compiled_module.bitness == Bitness::B32 { value &= 0xffffffff; @@ -1854,15 +1850,15 @@ impl super::Sandbox for Sandbox { value &= 0xffffffff; } - self.vmctx().regs[reg as usize].store(value, Ordering::Relaxed) + self.vmctx.write().unwrap().regs[reg as usize] = value; } fn gas(&self) -> Gas { - self.vmctx().gas.load(Ordering::Relaxed) + self.vmctx.write().unwrap().gas } fn set_gas(&mut self, gas: Gas) { - self.vmctx().gas.store(gas, Ordering::Relaxed) + self.vmctx.write().unwrap().gas = gas; } fn program_counter(&self) -> Option { @@ -1870,7 +1866,8 @@ impl super::Sandbox for Sandbox { return None; } - Some(ProgramCounter(self.vmctx().program_counter.load(Ordering::Relaxed))) + let vmctx = self.vmctx.write().unwrap(); + Some(ProgramCounter(vmctx.program_counter.load(Ordering::Relaxed))) } fn next_program_counter(&self) -> Option { @@ -1878,10 +1875,11 @@ impl super::Sandbox 
for Sandbox { return self.next_program_counter; } - if self.vmctx().next_native_program_counter.load(Ordering::Relaxed) == 0 { + let vmctx = self.vmctx.write().unwrap(); + if vmctx.next_native_program_counter == 0 { None } else { - Some(ProgramCounter(self.vmctx().next_program_counter.load(Ordering::Relaxed))) + Some(ProgramCounter(vmctx.next_program_counter.load(Ordering::Relaxed))) } } @@ -1896,7 +1894,8 @@ impl super::Sandbox for Sandbox { return compiled_module.lookup_native_code_address(pc).map(|value| value as usize); } - let value = self.vmctx().next_native_program_counter.load(Ordering::Relaxed); + let vmctx = self.vmctx.write().unwrap(); + let value = vmctx.next_native_program_counter; if value == 0 { None } else { @@ -1910,7 +1909,7 @@ impl super::Sandbox for Sandbox { }; if !self.dynamic_paging_enabled { - self.vmctx().jump_into.store(ZYGOTE_TABLES.1.ext_reset_memory, Ordering::Relaxed); + self.vmctx.write().unwrap().jump_into = ZYGOTE_TABLES.1.ext_reset_memory; self.wake_oneshot_and_expect_idle() } else { self.free_pages(0x10000, 0xffff0000) @@ -1979,7 +1978,8 @@ impl super::Sandbox for Sandbox { } else if address >= memory_map.stack_address_low() { u64::from(address) + data.len() as u64 <= u64::from(memory_map.stack_range().end) } else if address >= memory_map.rw_data_address() { - let end = unsafe { *self.vmctx().heap_info.heap_threshold.get() }; + let vmctx = self.vmctx.write().unwrap(); + let end = unsafe { *vmctx.heap_info.heap_threshold.get() }; u64::from(address) + data.len() as u64 <= end } else { false @@ -2028,7 +2028,8 @@ impl super::Sandbox for Sandbox { } else if address >= memory_map.stack_address_low() { u64::from(address) + u64::from(length) <= u64::from(memory_map.stack_range().end) } else if address >= memory_map.rw_data_address() { - let end = unsafe { *self.vmctx().heap_info.heap_threshold.get() }; + let vmctx = self.vmctx.write().unwrap(); + let end = unsafe { *vmctx.heap_info.heap_threshold.get() }; u64::from(address) + 
u64::from(length) <= end } else { false @@ -2041,11 +2042,9 @@ impl super::Sandbox for Sandbox { }); } - self.vmctx().arg.store(address, Ordering::Relaxed); - self.vmctx().arg2.store(length, Ordering::Relaxed); - self.vmctx() - .jump_into - .store(ZYGOTE_TABLES.1.ext_zero_memory_chunk, Ordering::Relaxed); + self.vmctx.write().unwrap().arg.store(address, Ordering::Relaxed); + self.vmctx.write().unwrap().arg2.store(length, Ordering::Relaxed); + self.vmctx.write().unwrap().jump_into = ZYGOTE_TABLES.1.ext_zero_memory_chunk; if let Err(error) = self.wake_oneshot_and_expect_idle() { return Err(MemoryAccessError::Error(error.into())); } @@ -2106,22 +2105,23 @@ impl super::Sandbox for Sandbox { } fn heap_size(&self) -> u32 { - let heap_base = unsafe { *self.vmctx().heap_base.get() }; - let heap_top = unsafe { *self.vmctx().heap_info.heap_top.get() }; + let vmctx = self.vmctx.write().unwrap(); + let heap_base = unsafe { *vmctx.heap_base.get() }; + let heap_top = unsafe { *vmctx.heap_info.heap_top.get() }; (heap_top - u64::from(heap_base)) as u32 } fn sbrk(&mut self, size: u32) -> Result, Error> { if size == 0 { - return Ok(Some(unsafe { *self.vmctx().heap_info.heap_top.get() as u32 })); + return Ok(Some(unsafe { *self.vmctx.write().unwrap().heap_info.heap_top.get() as u32 })); } - self.vmctx().jump_into.store(ZYGOTE_TABLES.1.ext_sbrk, Ordering::Relaxed); - self.vmctx().arg.store(size, Ordering::Relaxed); + self.vmctx.write().unwrap().jump_into = ZYGOTE_TABLES.1.ext_sbrk; + self.vmctx.write().unwrap().arg.store(size, Ordering::Relaxed); self.wake_worker()?; self.wait()?.expect_idle()?; - let result = self.vmctx().arg.load(Ordering::Relaxed); + let result = self.vmctx.write().unwrap().arg.load(Ordering::Relaxed); if result == 0 { Ok(None) } else { @@ -2141,9 +2141,9 @@ impl super::Sandbox for Sandbox { fn offset_table() -> OffsetTable { OffsetTable { arg: get_field_offset!(VmCtx::new(), |base| base.arg.as_ptr()), - gas: get_field_offset!(VmCtx::new(), |base| 
base.gas.as_ptr()), + gas: get_field_offset!(VmCtx::new(), |base| &base.gas), heap_info: get_field_offset!(VmCtx::new(), |base| &base.heap_info), - next_native_program_counter: get_field_offset!(VmCtx::new(), |base| base.next_native_program_counter.as_ptr()), + next_native_program_counter: get_field_offset!(VmCtx::new(), |base| &base.next_native_program_counter), next_program_counter: get_field_offset!(VmCtx::new(), |base| base.next_program_counter.as_ptr()), program_counter: get_field_offset!(VmCtx::new(), |base| base.program_counter.as_ptr()), regs: get_field_offset!(VmCtx::new(), |base| base.regs.as_ptr()), @@ -2180,13 +2180,14 @@ impl Interrupt { impl Sandbox { #[inline] - fn vmctx(&self) -> &VmCtx { - unsafe { &*self.vmctx_mmap.as_ptr().cast::() } + fn vmctx(&mut self) -> &mut RwLock> { + &mut self.vmctx } fn wake_worker(&self) -> Result<(), Error> { - self.vmctx().futex.store(VMCTX_FUTEX_BUSY, Ordering::Release); - linux_raw::sys_futex_wake_one(&self.vmctx().futex).map(|_| ()) + let vmctx = self.vmctx.write().unwrap(); + vmctx.futex.store(VMCTX_FUTEX_BUSY, Ordering::Release); + linux_raw::sys_futex_wake_one(&vmctx.futex).map(|_| ()) } fn wake_oneshot_and_expect_idle(&mut self) -> Result<(), Error> { @@ -2202,7 +2203,7 @@ impl Sandbox { 'outer: loop { self.count_wait_loop_start += 1; - let state = self.vmctx().futex.load(Ordering::Relaxed); + let state = self.vmctx.write().unwrap().futex.load(Ordering::Relaxed); if state == VMCTX_FUTEX_IDLE { core::sync::atomic::fence(Ordering::Acquire); return Ok(Interrupt::Idle); @@ -2213,13 +2214,13 @@ impl Sandbox { let compiled_module = Self::downcast_module(self.module.as_ref().unwrap()); if compiled_module.bitness == Bitness::B32 { - for reg_value in &self.vmctx().regs { - reg_value.fetch_and(0xffffffff, Ordering::Relaxed); + for reg in &mut self.vmctx.write().unwrap().regs { + *reg &= 0xffffffff; } } - let address = self.vmctx().next_native_program_counter.load(Ordering::Relaxed); - let gas = 
self.vmctx().gas.load(Ordering::Relaxed); + let address = self.vmctx.write().unwrap().next_native_program_counter; + let gas = self.vmctx.write().unwrap().gas; if gas < 0 { // Read the gas cost from the machine code. let gas_metering_trap_offset = match compiled_module.bitness { @@ -2231,9 +2232,7 @@ impl Sandbox { return Err(Error::from_str("internal error: address underflow after a trap")); }; - self.vmctx() - .next_native_program_counter - .store(compiled_module.native_code_origin + offset, Ordering::Relaxed); + self.vmctx.write().unwrap().next_native_program_counter = compiled_module.native_code_origin + offset; let Some(program_counter) = compiled_module.program_counter_by_native_code_address(address, false) else { return Err(Error::from_str("internal error: failed to find the program counter based on the native program counter when running out of gas")); @@ -2251,16 +2250,20 @@ impl Sandbox { )); }; - let gas_cost = u32::from_le_bytes([gas_cost[0], gas_cost[1], gas_cost[2], gas_cost[3]]); - let gas = self.vmctx().gas.fetch_add(i64::from(gas_cost), Ordering::Relaxed); + let gas_cost = u32::from_le_bytes([gas_cost[0], gas_cost[1], gas_cost[2], gas_cost[3]]) as i64; + let gas = self.vmctx.write().unwrap().gas; + self.vmctx.write().unwrap().gas += gas_cost; + log::trace!( "Out of gas; program counter = {program_counter}, reverting gas: {gas} -> {new_gas} (gas cost: {gas_cost})", new_gas = gas + i64::from(gas_cost) ); self.is_program_counter_valid = true; - self.vmctx().program_counter.store(program_counter.0, Ordering::Relaxed); - self.vmctx().next_program_counter.store(program_counter.0, Ordering::Relaxed); + self.vmctx.write().unwrap().program_counter.store(program_counter.0, Ordering::Relaxed); + self.vmctx.write().unwrap() + .next_program_counter + .store(program_counter.0, Ordering::Relaxed); return Ok(Interrupt::NotEnoughGas); } else { @@ -2270,8 +2273,8 @@ impl Sandbox { }; self.is_program_counter_valid = true; - 
self.vmctx().program_counter.store(program_counter.0, Ordering::Relaxed); - self.vmctx().next_native_program_counter.store(0, Ordering::Relaxed); + self.vmctx.write().unwrap().program_counter.store(program_counter.0, Ordering::Relaxed); + self.vmctx.write().unwrap().next_native_program_counter = 0; return Ok(Interrupt::Trap); } @@ -2279,7 +2282,7 @@ impl Sandbox { if state == VMCTX_FUTEX_GUEST_ECALLI { core::sync::atomic::fence(Ordering::Acquire); - let hostcall = self.vmctx().arg.load(Ordering::Relaxed); + let hostcall = self.vmctx.write().unwrap().arg.load(Ordering::Relaxed); return Ok(Interrupt::Ecalli(hostcall)); } @@ -2307,9 +2310,8 @@ impl Sandbox { if !self.iouring_futex_wait_queued { self.count_futex_wait += 1; - let vmctx = unsafe { &*self.vmctx_mmap.as_ptr().cast::() }; iouring - .queue_futex_wait(IO_URING_JOB_FUTEX_WAIT, &vmctx.futex, VMCTX_FUTEX_BUSY) + .queue_futex_wait(IO_URING_JOB_FUTEX_WAIT, &self.vmctx.write().unwrap().futex, VMCTX_FUTEX_BUSY) .expect("internal error: io_uring queue overflow"); self.iouring_futex_wait_queued = true; } @@ -2404,20 +2406,21 @@ impl Sandbox { for _ in 0..spin_target { core::hint::spin_loop(); - if self.vmctx().futex.load(Ordering::Relaxed) != VMCTX_FUTEX_BUSY { + if self.vmctx.write().unwrap().futex.load(Ordering::Relaxed) != VMCTX_FUTEX_BUSY { continue 'outer; } } for _ in 0..yield_target { let _ = linux_raw::sys_sched_yield(); - if self.vmctx().futex.load(Ordering::Relaxed) != VMCTX_FUTEX_BUSY { + if self.vmctx.write().unwrap().futex.load(Ordering::Relaxed) != VMCTX_FUTEX_BUSY { continue 'outer; } } self.count_futex_wait += 1; - match linux_raw::sys_futex_wait(&self.vmctx().futex, VMCTX_FUTEX_BUSY, Some(Duration::from_millis(100))) { + let status = linux_raw::sys_futex_wait(&self.vmctx.write().unwrap().futex, VMCTX_FUTEX_BUSY, Some(Duration::from_millis(100))); + match status { Ok(()) => continue, Err(error) if error.errno() == linux_raw::EAGAIN || error.errno() == linux_raw::EINTR => continue, Err(error) if 
error.errno() == linux_raw::ETIMEDOUT => { @@ -2437,7 +2440,8 @@ impl Sandbox { } log::trace!("Child #{} is not running anymore: {status}", self.child.pid); - let message = get_message(self.vmctx()); + + let message = get_message(&self.vmctx.write().unwrap()); if let Some(message) = message { Err(Error::from(format!("{status}: {message}"))) } else { @@ -2457,9 +2461,11 @@ impl Sandbox { }; self.is_program_counter_valid = true; - self.vmctx().program_counter.store(program_counter.0, Ordering::Relaxed); - self.vmctx().next_program_counter.store(program_counter.0, Ordering::Relaxed); - self.vmctx().next_native_program_counter.store(regs.rip, Ordering::Relaxed); + self.vmctx.write().unwrap().program_counter.store(program_counter.0, Ordering::Relaxed); + self.vmctx.write().unwrap() + .next_program_counter + .store(program_counter.0, Ordering::Relaxed); + self.vmctx.write().unwrap().next_native_program_counter = regs.rip; for reg in Reg::ALL { use polkavm_common::regmap::NativeReg::*; @@ -2487,7 +2493,7 @@ impl Sandbox { value &= 0xffffffff; } - self.vmctx().regs[reg as usize].store(value, Ordering::Relaxed); + self.vmctx.write().unwrap().regs[reg as usize] = value; } Ok(()) @@ -2517,7 +2523,8 @@ impl Sandbox { r15 => &mut regs.r15, }; - *value = self.vmctx().regs[reg as usize].load(Ordering::Relaxed); + let vmctx = self.vmctx.write().unwrap(); + *value = vmctx.regs[reg as usize]; } linux_raw::sys_ptrace_setregs(self.child.pid, ®s)?; @@ -2527,15 +2534,12 @@ impl Sandbox { fn cancel_pagefault(&mut self) -> Result<(), Error> { log::trace!("Cancelling pending page fault..."); - - // This will cancel *our own* `futex_wait` which we've queued up with iouring. - linux_raw::sys_futex_wake_one(&self.vmctx().futex)?; - + linux_raw::sys_futex_wake_one(&self.vmctx.write().unwrap().futex)?; // Forcibly return the worker to the idle state. // // The worker's currently stuck in a page fault somewhere inside guest code, // so it can't do this by itself. 
- self.vmctx().futex.store(VMCTX_FUTEX_IDLE, Ordering::Release); + self.vmctx.write().unwrap().futex.store(VMCTX_FUTEX_IDLE, Ordering::Release); linux_raw::sys_ptrace_setregs(self.child.pid, &self.idle_regs)?; linux_raw::sys_ptrace_continue(self.child.pid, None) }