diff --git a/Cargo.lock b/Cargo.lock index d14463a..f97e7b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -361,6 +361,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" name = "motsu" version = "0.3.0" dependencies = [ + "alloy-sol-types", "const-hex", "dashmap", "motsu-proc", diff --git a/crates/motsu-proc/src/test.rs b/crates/motsu-proc/src/test.rs index 2e88961..2e2425d 100644 --- a/crates/motsu-proc/src/test.rs +++ b/crates/motsu-proc/src/test.rs @@ -3,7 +3,7 @@ use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, FnArg}; -/// Defines a unit test that provides access to Stylus execution context. +/// Defines a unit test that provides access to Stylus' execution context. /// /// For more information see [`crate::test`]. pub(crate) fn test(_attr: &TokenStream, input: TokenStream) -> TokenStream { @@ -21,49 +21,31 @@ pub(crate) fn test(_attr: &TokenStream, input: TokenStream) -> TokenStream { } // Whether 1 or none contracts will be declared. - let arg_binding_and_ty = match fn_args - .into_iter() - .map(|arg| { - let FnArg::Typed(arg) = arg else { - error!(@arg, "unexpected receiver argument in test signature"); - }; - let contract_arg_binding = &arg.pat; - let contract_ty = &arg.ty; - Ok((contract_arg_binding, contract_ty)) - }) - .collect::, _>>() - { - Ok(res) => res, - Err(err) => return err.to_compile_error().into(), - }; + let contract_declarations = fn_args.into_iter().map(|arg| { + let FnArg::Typed(arg) = arg else { + error!(arg, "unexpected receiver argument in test signature"); + }; + let contract_arg_binding = &arg.pat; + let contract_ty = &arg.ty; - let contract_arg_defs = - arg_binding_and_ty.iter().map(|(arg_binding, contract_ty)| { - // Test case assumes, that contract's variable has `&mut` reference - // to contract's type. - quote! { - #arg_binding: &mut #contract_ty - } - }); - - let contract_args = - arg_binding_and_ty.iter().map(|(_arg_binding, contract_ty)| { - // Pass mutable reference to the contract. - quote! { - &mut <#contract_ty>::default() - } - }); + // Test case assumes, that contract's variable has `&mut` reference + // to contract's type. + quote! { + let mut #contract_arg_binding = <#contract_ty>::default(); + let #contract_arg_binding = &mut #contract_arg_binding; + } + }); - // Declare test case closure. - // Pass mut ref to the test closure and call it. - // Reset storage for the test context and return test's output. + // Output full testcase function. + // Declare contract. + // And in the end, reset storage for test context. quote! { #( #attrs )* #[test] fn #fn_name() #fn_return_type { use ::motsu::prelude::DefaultStorage; - let test = | #( #contract_arg_defs ),* | #fn_block; - let res = test( #( #contract_args ),* ); + #( #contract_declarations )* + let res = #fn_block; ::motsu::prelude::Context::current().reset_storage(); res } diff --git a/crates/motsu/Cargo.toml b/crates/motsu/Cargo.toml index 5bfbe22..a11620e 100644 --- a/crates/motsu/Cargo.toml +++ b/crates/motsu/Cargo.toml @@ -9,6 +9,7 @@ repository.workspace = true version = "0.3.0" [dependencies] +alloy-sol-types.workspace = true const-hex.workspace = true once_cell.workspace = true tiny-keccak.workspace = true diff --git a/crates/motsu/src/context.rs b/crates/motsu/src/context.rs deleted file mode 100644 index 381d504..0000000 --- a/crates/motsu/src/context.rs +++ /dev/null @@ -1,109 +0,0 @@ -//! Unit-testing context for Stylus contracts. - -use std::{collections::HashMap, ptr}; - -use dashmap::DashMap; -use once_cell::sync::Lazy; -use stylus_sdk::{alloy_primitives::uint, prelude::StorageType}; - -use crate::prelude::{Bytes32, WORD_BYTES}; - -/// Context of stylus unit tests associated with the current test thread. -#[allow(clippy::module_name_repetitions)] -pub struct Context { - thread_name: ThreadName, -} - -impl Context { - /// Get test context associated with the current test thread. - #[must_use] - pub fn current() -> Self { - Self { thread_name: ThreadName::current() } - } - - /// Get the value at `key` in storage. - pub(crate) fn get_bytes(self, key: &Bytes32) -> Bytes32 { - let storage = STORAGE.entry(self.thread_name).or_default(); - storage.contract_data.get(key).copied().unwrap_or_default() - } - - /// Get the raw value at `key` in storage and write it to `value`. - pub(crate) unsafe fn get_bytes_raw(self, key: *const u8, value: *mut u8) { - let key = read_bytes32(key); - - write_bytes32(value, self.get_bytes(&key)); - } - - /// Set the value at `key` in storage to `value`. - pub(crate) fn set_bytes(self, key: Bytes32, value: Bytes32) { - let mut storage = STORAGE.entry(self.thread_name).or_default(); - storage.contract_data.insert(key, value); - } - - /// Set the raw value at `key` in storage to `value`. - pub(crate) unsafe fn set_bytes_raw(self, key: *const u8, value: *const u8) { - let (key, value) = (read_bytes32(key), read_bytes32(value)); - self.set_bytes(key, value); - } - - /// Clears storage, removing all key-value pairs associated with the current - /// test thread. - pub fn reset_storage(self) { - STORAGE.remove(&self.thread_name); - } -} - -/// Storage mock: A global mutable key-value store. -/// Allows concurrent access. -/// -/// The key is the name of the test thread, and the value is the storage of the -/// test case. -static STORAGE: Lazy> = - Lazy::new(DashMap::new); - -/// Test thread name metadata. -#[derive(Clone, Eq, PartialEq, Hash)] -struct ThreadName(String); - -impl ThreadName { - /// Get the name of the current test thread. - fn current() -> Self { - let current_thread_name = std::thread::current() - .name() - .expect("should retrieve current thread name") - .to_string(); - Self(current_thread_name) - } -} - -/// Storage for unit test's mock data. -#[derive(Default)] -struct MockStorage { - /// Contract's mock data storage. - contract_data: HashMap, -} - -/// Read the word from location pointed by `ptr`. -unsafe fn read_bytes32(ptr: *const u8) -> Bytes32 { - let mut res = Bytes32::default(); - ptr::copy(ptr, res.as_mut_ptr(), WORD_BYTES); - res -} - -/// Write the word `bytes` to the location pointed by `ptr`. -unsafe fn write_bytes32(ptr: *mut u8, bytes: Bytes32) { - ptr::copy(bytes.as_ptr(), ptr, WORD_BYTES); -} - -/// Initializes fields of contract storage and child contract storages with -/// default values. -pub trait DefaultStorage: StorageType { - /// Initializes fields of contract storage and child contract storages with - /// default values. - #[must_use] - fn default() -> Self { - unsafe { Self::new(uint!(0_U256), 0) } - } -} - -impl DefaultStorage for ST {} diff --git a/crates/motsu/src/context/environment.rs b/crates/motsu/src/context/environment.rs new file mode 100644 index 0000000..52b1617 --- /dev/null +++ b/crates/motsu/src/context/environment.rs @@ -0,0 +1,83 @@ +//! Module with unit test EVM environment for Stylus contracts. + +/// Block Timestamp - Epoch timestamp: 1st January 2025 `00::00::00`. +const BLOCK_TIMESTAMP: u64 = 1_735_689_600; +/// Arbitrum's CHAID ID. +const CHAIN_ID: u64 = 42161; + +/// Dummy contract address set for tests. +const CONTRACT_ADDRESS: &[u8; 42] = + b"0xdCE82b5f92C98F27F116F70491a487EFFDb6a2a9"; + +/// Externally Owned Account (EOA) code hash. +const EOA_CODEHASH: &[u8; 66] = + b"0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470"; + +/// Dummy msg sender set for tests. +const MSG_SENDER: &[u8; 42] = b"0xDeaDbeefdEAdbeefdEadbEEFdeadbeEFdEaDbeeF"; + +pub(crate) struct Environment { + account_codehash: [u8; 66], + block_timestamp: u64, + chain_id: u64, + contract_address: [u8; 42], + events: Vec>, + msg_sender: [u8; 42], +} + +impl Default for Environment { + /// Creates default environment for a test case. + fn default() -> Environment { + Self { + account_codehash: *EOA_CODEHASH, + block_timestamp: BLOCK_TIMESTAMP, + chain_id: CHAIN_ID, + contract_address: *CONTRACT_ADDRESS, + events: Vec::new(), + msg_sender: *MSG_SENDER, + } + } +} + +impl Environment { + /// Gets the code hash of the account at the given address. + pub(crate) fn account_codehash(&self) -> [u8; 66] { + self.account_codehash + } + + /// Gets a bounded estimate of the Unix timestamp at which the Sequencer + /// sequenced the transaction. + pub(crate) fn block_timestamp(&self) -> u64 { + self.block_timestamp + } + + /// Gets the chain ID of the current chain. + pub(crate) fn chain_id(&self) -> u64 { + self.chain_id + } + + /// Gets the address of the current program. + pub(crate) fn contract_address(&self) -> [u8; 42] { + self.contract_address + } + + /// Gets the address of the account that called the program. + pub(crate) fn msg_sender(&self) -> [u8; 42] { + self.msg_sender + } + + /// Stores emitted event. + pub(crate) fn store_event(&mut self, event: &[u8]) { + self.events.push(Vec::from(event)); + } + + /// Removes all the stored events. + pub(crate) fn clear_events(&mut self) { + self.events.clear(); + } + + /// Gets all emitted events. + pub(crate) fn events(&self) -> Vec> { + self.events.clone() + } +} diff --git a/crates/motsu/src/context/mod.rs b/crates/motsu/src/context/mod.rs new file mode 100644 index 0000000..a46c032 --- /dev/null +++ b/crates/motsu/src/context/mod.rs @@ -0,0 +1,188 @@ +//! Unit-testing context for Stylus contracts. + +use std::{collections::HashMap, ptr}; + +use dashmap::DashMap; +use once_cell::sync::Lazy; +use stylus_sdk::{alloy_primitives::uint, prelude::StorageType}; + +use crate::prelude::{Bytes32, WORD_BYTES}; + +mod environment; + +use environment::Environment; + +/// Context of stylus unit tests associated with the current test thread. +#[allow(clippy::module_name_repetitions)] +pub struct Context { + thread_name: ThreadName, +} + +impl Context { + /// Get test context associated with the current test thread. + #[must_use] + pub fn current() -> Self { + Self { thread_name: ThreadName::current() } + } + + /// Get the value at `key` in storage. + pub(crate) fn get_bytes(self, key: &Bytes32) -> Bytes32 { + let context = EVM.entry(self.thread_name).or_default(); + context.storage.contract_data.get(key).copied().unwrap_or_default() + } + + /// Get the raw value at `key` in storage and write it to `value`. + pub(crate) unsafe fn get_bytes_raw(self, key: *const u8, value: *mut u8) { + let key = read_bytes32(key); + + write_bytes32(value, self.get_bytes(&key)); + } + + /// Set the value at `key` in storage to `value`. + pub(crate) fn set_bytes(self, key: Bytes32, value: Bytes32) { + let mut context = EVM.entry(self.thread_name).or_default(); + context.storage.contract_data.insert(key, value); + } + + /// Set the raw value at `key` in storage to `value`. + pub(crate) unsafe fn set_bytes_raw(self, key: *const u8, value: *const u8) { + let (key, value) = (read_bytes32(key), read_bytes32(value)); + self.set_bytes(key, value); + } + + /// Clears storage, removing all key-value pairs associated with the current + /// test thread. + pub fn reset_storage(self) { + EVM.remove(&self.thread_name); + } + + /// Gets the code hash of the account at the given address. + pub fn account_codehash(self) -> [u8; 66] { + let context = EVM.entry(self.thread_name).or_default(); + context.environment.account_codehash() + } + + /// Gets a bounded estimate of the Unix timestamp at which the Sequencer + /// sequenced the transaction. + pub fn block_timestamp(self) -> u64 { + let context = EVM.entry(self.thread_name).or_default(); + context.environment.block_timestamp() + } + + /// Gets the chain ID of the current chain. + pub fn chain_id(self) -> u64 { + let context = EVM.entry(self.thread_name).or_default(); + context.environment.chain_id() + } + + /// Gets the address of the current program. + pub fn contract_address(self) -> [u8; 42] { + let context = EVM.entry(self.thread_name).or_default(); + context.environment.contract_address() + } + + /// Emits an EVM log with the given number of topics and data, the first + /// bytes of which should be the 32-byte-aligned topic data. + /// + /// Data contains `topics` amount of topics and then encoded event in `data` + /// buffer. + pub(crate) unsafe fn emit_log( + self, + data: *const u8, + len: usize, + topics: usize, + ) { + // https://github.com/OffchainLabs/stylus-sdk-rs/blob/v0.6.0/stylus-sdk/src/evm.rs#L38-L52 + let buffer = read_bytes(data, len); + let encoded_event: Vec = + buffer.clone().into_iter().skip(topics * WORD_BYTES).collect(); + panic!("Data: {:x?}, len: {:x?}, topics: {:x?}, Log: {:x?}, encoded_event: {:x?}", data, len, topics, buffer, encoded_event); + let mut context = EVM.entry(self.thread_name).or_default(); + context.environment.store_event(&encoded_event); + } + + /// Gets the address of the account that called the program. + pub fn msg_sender(self) -> [u8; 42] { + let context = EVM.entry(self.thread_name).or_default(); + context.environment.msg_sender() + } + + /// Removes all events for a test case. + pub fn clear_events(self) { + let mut context = EVM.entry(self.thread_name).or_default(); + context.environment.clear_events(); + } + + /// Gets all emitted events for a test case. + pub fn events(self) -> Vec> { + let context = EVM.entry(self.thread_name).or_default(); + context.environment.events() + } +} + +#[derive(Default)] +struct TestCase { + storage: MockStorage, + environment: Environment, +} + +/// A global mutable key-value store mockig EVM behaviour. +/// Allows concurrent access. +/// +/// The key is the name of the test thread, +/// and the value is the context of the test case. +static EVM: Lazy> = Lazy::new(DashMap::new); + +/// Test thread name metadata. +#[derive(Clone, Eq, PartialEq, Hash)] +struct ThreadName(String); + +impl ThreadName { + /// Get the name of the current test thread. + fn current() -> Self { + let current_thread_name = std::thread::current() + .name() + .expect("should retrieve current thread name") + .to_string(); + Self(current_thread_name) + } +} + +/// Storage for unit test's mock data. +#[derive(Default)] +struct MockStorage { + /// Contract's mock data storage. + contract_data: HashMap, +} + +/// Read data from location pointed by `ptr`. +unsafe fn read_bytes(ptr: *const u8, len: usize) -> Vec { + let mut res = Vec::with_capacity(len); + ptr::copy(ptr, res.as_mut_ptr(), len); + res +} + +/// Read the word from location pointed by `ptr`. +unsafe fn read_bytes32(ptr: *const u8) -> Bytes32 { + let mut res = Bytes32::default(); + ptr::copy(ptr, res.as_mut_ptr(), WORD_BYTES); + res +} + +/// Write the word `bytes` to the location pointed by `ptr`. +unsafe fn write_bytes32(ptr: *mut u8, bytes: Bytes32) { + ptr::copy(bytes.as_ptr(), ptr, WORD_BYTES); +} + +/// Initializes fields of contract storage and child contract storages with +/// default values. +pub trait DefaultStorage: StorageType { + /// Initializes fields of contract storage and child contract storages with + /// default values. + #[must_use] + fn default() -> Self { + unsafe { Self::new(uint!(0_U256), 0) } + } +} + +impl DefaultStorage for ST {} diff --git a/crates/motsu/src/event.rs b/crates/motsu/src/event.rs new file mode 100644 index 0000000..3a1e7a7 --- /dev/null +++ b/crates/motsu/src/event.rs @@ -0,0 +1,19 @@ +use alloy_sol_types::SolEvent; + +use crate::context::Context; + +/// Asserts that the `expected` event was emitted in a test case. +pub fn emits_event(expected: E) -> bool +where + E: SolEvent, +{ + let expected = expected.encode_data(); + let events = Context::current().events(); + panic!("Expected: {:x?}, events: {:x?}", expected, events); + events.into_iter().rev().any(|event| expected == event) +} + +/// Removes all emitted events for a test case. +pub fn clear_events() { + Context::current().clear_events(); +} diff --git a/crates/motsu/src/lib.rs b/crates/motsu/src/lib.rs index 7636b75..7816935 100644 --- a/crates/motsu/src/lib.rs +++ b/crates/motsu/src/lib.rs @@ -14,7 +14,7 @@ //! //! Note that we require contracts to implement //! `stylus_sdk::prelude::StorageType`. This trait is typically implemented by -//! default with `stylus_proc::sol_storage` or `stylus_proc::storage` macros. +//! default with `stylus_proc::sol_storage` macro. //! //! ```rust //! #[cfg(test)] @@ -44,7 +44,9 @@ //! //! [test_attribute]: crate::test mod context; +mod event; pub mod prelude; mod shims; +pub use event::{clear_events, emits_event}; pub use motsu_proc::test; diff --git a/crates/motsu/src/shims.rs b/crates/motsu/src/shims.rs index 162f9e2..e131bff 100644 --- a/crates/motsu/src/shims.rs +++ b/crates/motsu/src/shims.rs @@ -116,24 +116,9 @@ pub fn storage_flush_cache(_: bool) { // No-op: we don't use the cache in our unit-tests. } -/// Dummy msg sender set for tests. -pub const MSG_SENDER: &[u8; 42] = b"0xDeaDbeefdEAdbeefdEadbEEFdeadbeEFdEaDbeeF"; - -/// Dummy contract address set for tests. -pub const CONTRACT_ADDRESS: &[u8; 42] = - b"0xdCE82b5f92C98F27F116F70491a487EFFDb6a2a9"; - -/// Arbitrum's CHAID ID. -pub const CHAIN_ID: u64 = 42161; - -/// Externally Owned Account (EOA) code hash. -pub const EOA_CODEHASH: &[u8; 66] = - b"0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470"; - -/// Gets the address of the account that called the program. -/// -/// For normal L2-to-L2 transactions the semantics are equivalent to that of the -/// EVM's [`CALLER`] opcode, including in cases arising from [`DELEGATE_CALL`]. +/// Gets the address of the account that called the program. For normal +/// L2-to-L2 transactions the semantics are equivalent to that of the EVM's +/// [`CALLER`] opcode, including in cases arising from [`DELEGATE_CALL`]. /// /// For L1-to-L2 retryable ticket transactions, the top-level sender's address /// will be aliased. See [`Retryable Ticket Address Aliasing`][aliasing] for @@ -148,7 +133,8 @@ pub const EOA_CODEHASH: &[u8; 66] = /// May panic if fails to parse `MSG_SENDER` as an address. #[no_mangle] pub unsafe extern "C" fn msg_sender(sender: *mut u8) { - let addr = const_hex::const_decode_to_array::<20>(MSG_SENDER).unwrap(); + let msg_sender = Context::current().msg_sender(); + let addr = const_hex::const_decode_to_array::<20>(&msg_sender).unwrap(); std::ptr::copy(addr.as_ptr(), sender, 20); } @@ -162,8 +148,9 @@ pub unsafe extern "C" fn msg_sender(sender: *mut u8) { /// May panic if fails to parse `CONTRACT_ADDRESS` as an address. #[no_mangle] pub unsafe extern "C" fn contract_address(address: *mut u8) { + let contract_address = Context::current().contract_address(); let addr = - const_hex::const_decode_to_array::<20>(CONTRACT_ADDRESS).unwrap(); + const_hex::const_decode_to_array::<20>(&contract_address).unwrap(); std::ptr::copy(addr.as_ptr(), address, 20); } @@ -173,7 +160,7 @@ pub unsafe extern "C" fn contract_address(address: *mut u8) { /// [`CHAINID`]: https://www.evm.codes/#46 #[no_mangle] pub unsafe extern "C" fn chainid() -> u64 { - CHAIN_ID + Context::current().chain_id() } /// Emits an EVM log with the given number of topics and data, the first bytes @@ -189,12 +176,11 @@ pub unsafe extern "C" fn chainid() -> u64 { /// [`LOG3`]: https://www.evm.codes/#a3 /// [`LOG4`]: https://www.evm.codes/#a4 #[no_mangle] -pub unsafe extern "C" fn emit_log(_: *const u8, _: usize, _: usize) { - // No-op: we don't check for events in our unit-tests. +pub unsafe extern "C" fn emit_log(data: *const u8, len: usize, topics: usize) { + Context::current().emit_log(data, len, topics); } /// Gets the code hash of the account at the given address. -/// /// The semantics are equivalent to that of the EVM's [`EXT_CODEHASH`] opcode. /// Note that the code hash of an account without code will be the empty hash /// `keccak("") = @@ -207,8 +193,9 @@ pub unsafe extern "C" fn emit_log(_: *const u8, _: usize, _: usize) { /// May panic if fails to parse `ACCOUNT_CODEHASH` as a keccack hash. #[no_mangle] pub unsafe extern "C" fn account_codehash(_address: *const u8, dest: *mut u8) { + let account_codehash = Context::current().account_codehash(); let account_codehash = - const_hex::const_decode_to_array::<32>(EOA_CODEHASH).unwrap(); + const_hex::const_decode_to_array::<32>(&account_codehash).unwrap(); std::ptr::copy(account_codehash.as_ptr(), dest, 32); } @@ -341,6 +328,5 @@ pub unsafe extern "C" fn delegate_call_contract( /// [`Block Numbers and Time`]: https://developer.arbitrum.io/time #[no_mangle] pub unsafe extern "C" fn block_timestamp() -> u64 { - // Epoch timestamp: 1st January 2025 00::00::00 - 1_735_689_600 + Context::current().block_timestamp() }