Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cheatcodes): Add vm.mockCalls to mock different return data for multiple calls #9024

Merged
merged 13 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions crates/cheatcodes/assets/cheatcodes.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions crates/cheatcodes/spec/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,14 @@ interface Vm {
#[cheatcode(group = Evm, safety = Unsafe)]
function mockCall(address callee, uint256 msgValue, bytes calldata data, bytes calldata returnData) external;

/// Mocks multiple calls to an address, returning specified data for each call.
#[cheatcode(group = Evm, safety = Unsafe)]
function mockCalls(address callee, bytes calldata data, bytes[] calldata returnData) external;

/// Mocks multiple calls to an address with a specific `msg.value`, returning specified data for each call.
#[cheatcode(group = Evm, safety = Unsafe)]
function mockCalls(address callee, uint256 msgValue, bytes calldata data, bytes[] calldata returnData) external;

/// Reverts a call to an address with specified revert data.
#[cheatcode(group = Evm, safety = Unsafe)]
function mockCallRevert(address callee, bytes calldata data, bytes calldata revertData) external;
Expand Down
38 changes: 35 additions & 3 deletions crates/cheatcodes/src/evm/mock.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{inspector::InnerEcx, Cheatcode, Cheatcodes, CheatsCtxt, Result, Vm::*};
use alloy_primitives::{Address, Bytes, U256};
use revm::{interpreter::InstructionResult, primitives::Bytecode};
use std::cmp::Ordering;
use std::{cmp::Ordering, collections::VecDeque};

/// Mocked call data.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -65,6 +65,25 @@ impl Cheatcode for mockCall_1Call {
}
}

impl Cheatcode for mockCalls_0Call {
fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
let Self { callee, data, returnData } = self;
let _ = make_acc_non_empty(callee, ccx.ecx)?;

mock_calls(ccx.state, callee, data, None, returnData, InstructionResult::Return);
Ok(Default::default())
}
}

impl Cheatcode for mockCalls_1Call {
fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
let Self { callee, msgValue, data, returnData } = self;
ccx.ecx.load_account(*callee)?;
mock_calls(ccx.state, callee, data, Some(msgValue), returnData, InstructionResult::Return);
Ok(Default::default())
}
}

impl Cheatcode for mockCallRevert_0Call {
fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
let Self { callee, data, revertData } = self;
Expand Down Expand Up @@ -94,18 +113,31 @@ impl Cheatcode for mockFunctionCall {
}
}

#[allow(clippy::ptr_arg)] // Not public API, doesn't matter
fn mock_call(
state: &mut Cheatcodes,
callee: &Address,
cdata: &Bytes,
value: Option<&U256>,
rdata: &Bytes,
ret_type: InstructionResult,
) {
mock_calls(state, callee, cdata, value, &[rdata.clone()], ret_type)
DaniPopes marked this conversation as resolved.
Show resolved Hide resolved
}

fn mock_calls(
state: &mut Cheatcodes,
callee: &Address,
cdata: &Bytes,
value: Option<&U256>,
rdata_vec: &[Bytes],
ret_type: InstructionResult,
) {
state.mocked_calls.entry(*callee).or_default().insert(
MockCallDataContext { calldata: Bytes::copy_from_slice(cdata), value: value.copied() },
MockCallReturnData { ret_type, data: Bytes::copy_from_slice(rdata) },
rdata_vec
.iter()
.map(|rdata| MockCallReturnData { ret_type, data: Bytes::copy_from_slice(rdata) })
DaniPopes marked this conversation as resolved.
Show resolved Hide resolved
.collect::<VecDeque<_>>(),
);
}

Expand Down
40 changes: 25 additions & 15 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ pub struct Cheatcodes {

/// Mocked calls
// **Note**: inner must a BTreeMap because of special `Ord` impl for `MockCallDataContext`
pub mocked_calls: HashMap<Address, BTreeMap<MockCallDataContext, MockCallReturnData>>,
pub mocked_calls: HashMap<Address, BTreeMap<MockCallDataContext, VecDeque<MockCallReturnData>>>,

/// Mocked functions. Maps target address to be mocked to pair of (calldata, mock address).
pub mocked_functions: HashMap<Address, HashMap<Bytes, Address>>,
Expand Down Expand Up @@ -889,26 +889,36 @@ where {
}

// Handle mocked calls
if let Some(mocks) = self.mocked_calls.get(&call.bytecode_address) {
if let Some(mocks) = self.mocked_calls.get_mut(&call.bytecode_address) {
let ctx =
MockCallDataContext { calldata: call.input.clone(), value: call.transfer_value() };
if let Some(return_data) = mocks.get(&ctx).or_else(|| {
mocks
.iter()

if let Some(return_data_queue) = match mocks.get_mut(&ctx) {
Some(queue) => Some(queue),
None => mocks
.iter_mut()
.find(|(mock, _)| {
call.input.get(..mock.calldata.len()) == Some(&mock.calldata[..]) &&
mock.value.map_or(true, |value| Some(value) == call.transfer_value())
})
.map(|(_, v)| v)
}) {
return Some(CallOutcome {
result: InterpreterResult {
result: return_data.ret_type,
output: return_data.data.clone(),
gas,
},
memory_offset: call.return_memory_offset.clone(),
});
.map(|(_, v)| v),
} {
if let Some(return_data) = if return_data_queue.len() == 1 {
// If the mocked calls stack has a single element in it, don't empty it
return_data_queue.front().map(|x| x.to_owned())
} else {
// Else, we pop the front element
return_data_queue.pop_front()
} {
return Some(CallOutcome {
result: InterpreterResult {
result: return_data.ret_type,
output: return_data.data,
gas,
},
memory_offset: call.return_memory_offset.clone(),
});
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions testdata/cheats/Vm.sol

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 59 additions & 0 deletions testdata/default/cheats/MockCalls.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT OR Apache-2.0
pragma solidity 0.8.18;

import "ds-test/test.sol";
import "cheats/Vm.sol";

contract MockCallsTest is DSTest {
Vm constant vm = Vm(HEVM_ADDRESS);

function testMockCallsLastShouldPersist() public {
address mockUser = vm.addr(vm.randomUint());
address mockErc20 = vm.addr(vm.randomUint());
bytes memory data = abi.encodeWithSignature("balanceOf(address)", mockUser);
bytes[] memory mocks = new bytes[](2);
mocks[0] = abi.encode(2 ether);
mocks[1] = abi.encode(7.219 ether);
vm.mockCalls(mockErc20, data, mocks);
(, bytes memory ret1) = mockErc20.call(data);
assertEq(abi.decode(ret1, (uint256)), 2 ether);
(, bytes memory ret2) = mockErc20.call(data);
assertEq(abi.decode(ret2, (uint256)), 7.219 ether);
(, bytes memory ret3) = mockErc20.call(data);
assertEq(abi.decode(ret3, (uint256)), 7.219 ether);
}

function testMockCallsWithValue() public {
address mockUser = vm.addr(vm.randomUint());
address mockErc20 = vm.addr(vm.randomUint());
bytes memory data = abi.encodeWithSignature("balanceOf(address)", mockUser);
bytes[] memory mocks = new bytes[](3);
mocks[0] = abi.encode(2 ether);
mocks[1] = abi.encode(1 ether);
mocks[2] = abi.encode(6.423 ether);
vm.mockCalls(mockErc20, 1 ether, data, mocks);
(, bytes memory ret1) = mockErc20.call{value: 1 ether}(data);
assertEq(abi.decode(ret1, (uint256)), 2 ether);
(, bytes memory ret2) = mockErc20.call{value: 1 ether}(data);
assertEq(abi.decode(ret2, (uint256)), 1 ether);
(, bytes memory ret3) = mockErc20.call{value: 1 ether}(data);
assertEq(abi.decode(ret3, (uint256)), 6.423 ether);
}

function testMockCalls() public {
address mockUser = vm.addr(vm.randomUint());
address mockErc20 = vm.addr(vm.randomUint());
bytes memory data = abi.encodeWithSignature("balanceOf(address)", mockUser);
bytes[] memory mocks = new bytes[](3);
mocks[0] = abi.encode(2 ether);
mocks[1] = abi.encode(1 ether);
mocks[2] = abi.encode(6.423 ether);
vm.mockCalls(mockErc20, data, mocks);
(, bytes memory ret1) = mockErc20.call(data);
assertEq(abi.decode(ret1, (uint256)), 2 ether);
(, bytes memory ret2) = mockErc20.call(data);
assertEq(abi.decode(ret2, (uint256)), 1 ether);
(, bytes memory ret3) = mockErc20.call(data);
assertEq(abi.decode(ret3, (uint256)), 6.423 ether);
}
}
Loading