diff --git a/crates/cheatcodes/assets/cheatcodes.json b/crates/cheatcodes/assets/cheatcodes.json index 4517f075e7fb..3a16e2a5b74d 100644 --- a/crates/cheatcodes/assets/cheatcodes.json +++ b/crates/cheatcodes/assets/cheatcodes.json @@ -6531,6 +6531,46 @@ "status": "stable", "safety": "unsafe" }, + { + "func": { + "id": "prank_2", + "description": "Sets the *next* delegate call's `msg.sender` to be the input address.", + "declaration": "function prank(address msgSender, bool delegateCall) external;", + "visibility": "external", + "mutability": "", + "signature": "prank(address,bool)", + "selector": "0xa7f8bf5c", + "selectorBytes": [ + 167, + 248, + 191, + 92 + ] + }, + "group": "evm", + "status": "stable", + "safety": "unsafe" + }, + { + "func": { + "id": "prank_3", + "description": "Sets the *next* delegate call's `msg.sender` to be the input address, and the `tx.origin` to be the second input.", + "declaration": "function prank(address msgSender, address txOrigin, bool delegateCall) external;", + "visibility": "external", + "mutability": "", + "signature": "prank(address,address,bool)", + "selector": "0x7d73d042", + "selectorBytes": [ + 125, + 115, + 208, + 66 + ] + }, + "group": "evm", + "status": "stable", + "safety": "unsafe" + }, { "func": { "id": "prevrandao_0", @@ -8251,6 +8291,46 @@ "status": "stable", "safety": "unsafe" }, + { + "func": { + "id": "startPrank_2", + "description": "Sets all subsequent delegate calls' `msg.sender` to be the input address until `stopPrank` is called.", + "declaration": "function startPrank(address msgSender, bool delegateCall) external;", + "visibility": "external", + "mutability": "", + "signature": "startPrank(address,bool)", + "selector": "0x1cc0b435", + "selectorBytes": [ + 28, + 192, + 180, + 53 + ] + }, + "group": "evm", + "status": "stable", + "safety": "unsafe" + }, + { + "func": { + "id": "startPrank_3", + "description": "Sets all subsequent delegate calls' `msg.sender` to be the input address until `stopPrank` is called, and the `tx.origin` to be the second input.", + "declaration": "function startPrank(address msgSender, address txOrigin, bool delegateCall) external;", + "visibility": "external", + "mutability": "", + "signature": "startPrank(address,address,bool)", + "selector": "0x4eb859b5", + "selectorBytes": [ + 78, + 184, + 89, + 181 + ] + }, + "group": "evm", + "status": "stable", + "safety": "unsafe" + }, { "func": { "id": "startStateDiffRecording", diff --git a/crates/cheatcodes/spec/src/vm.rs b/crates/cheatcodes/spec/src/vm.rs index 980bab066a3a..5abe51442e27 100644 --- a/crates/cheatcodes/spec/src/vm.rs +++ b/crates/cheatcodes/spec/src/vm.rs @@ -491,6 +491,22 @@ interface Vm { #[cheatcode(group = Evm, safety = Unsafe)] function startPrank(address msgSender, address txOrigin) external; + /// Sets the *next* delegate call's `msg.sender` to be the input address. + #[cheatcode(group = Evm, safety = Unsafe)] + function prank(address msgSender, bool delegateCall) external; + + /// Sets all subsequent delegate calls' `msg.sender` to be the input address until `stopPrank` is called. + #[cheatcode(group = Evm, safety = Unsafe)] + function startPrank(address msgSender, bool delegateCall) external; + + /// Sets the *next* delegate call's `msg.sender` to be the input address, and the `tx.origin` to be the second input. + #[cheatcode(group = Evm, safety = Unsafe)] + function prank(address msgSender, address txOrigin, bool delegateCall) external; + + /// Sets all subsequent delegate calls' `msg.sender` to be the input address until `stopPrank` is called, and the `tx.origin` to be the second input. + #[cheatcode(group = Evm, safety = Unsafe)] + function startPrank(address msgSender, address txOrigin, bool delegateCall) external; + /// Resets subsequent calls' `msg.sender` to be `address(this)`. #[cheatcode(group = Evm, safety = Unsafe)] function stopPrank() external; diff --git a/crates/cheatcodes/src/evm/prank.rs b/crates/cheatcodes/src/evm/prank.rs index fe5418b3157f..a607d953f76d 100644 --- a/crates/cheatcodes/src/evm/prank.rs +++ b/crates/cheatcodes/src/evm/prank.rs @@ -16,6 +16,8 @@ pub struct Prank { pub depth: u64, /// Whether the prank stops by itself after the next call pub single_call: bool, + /// Whether the prank should be be applied to delegate call + pub delegate_call: bool, /// Whether the prank has been used yet (false if unused) pub used: bool, } @@ -29,8 +31,18 @@ impl Prank { new_origin: Option
, depth: u64, single_call: bool, + delegate_call: bool, ) -> Self { - Self { prank_caller, prank_origin, new_caller, new_origin, depth, single_call, used: false } + Self { + prank_caller, + prank_origin, + new_caller, + new_origin, + depth, + single_call, + delegate_call, + used: false, + } } /// Apply the prank by setting `used` to true iff it is false @@ -47,28 +59,56 @@ impl Prank { impl Cheatcode for prank_0Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self { msgSender } = self; - prank(ccx, msgSender, None, true) + prank(ccx, msgSender, None, true, false) } } impl Cheatcode for startPrank_0Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self { msgSender } = self; - prank(ccx, msgSender, None, false) + prank(ccx, msgSender, None, false, false) } } impl Cheatcode for prank_1Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self { msgSender, txOrigin } = self; - prank(ccx, msgSender, Some(txOrigin), true) + prank(ccx, msgSender, Some(txOrigin), true, false) } } impl Cheatcode for startPrank_1Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self { msgSender, txOrigin } = self; - prank(ccx, msgSender, Some(txOrigin), false) + prank(ccx, msgSender, Some(txOrigin), false, false) + } +} + +impl Cheatcode for prank_2Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { msgSender, delegateCall } = self; + prank(ccx, msgSender, None, true, *delegateCall) + } +} + +impl Cheatcode for startPrank_2Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { msgSender, delegateCall } = self; + prank(ccx, msgSender, None, false, *delegateCall) + } +} + +impl Cheatcode for prank_3Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { msgSender, txOrigin, delegateCall } = self; + prank(ccx, msgSender, Some(txOrigin), true, *delegateCall) + } +} + +impl Cheatcode for startPrank_3Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { msgSender, txOrigin, delegateCall } = self; + prank(ccx, msgSender, Some(txOrigin), false, *delegateCall) } } @@ -85,6 +125,7 @@ fn prank( new_caller: &Address, new_origin: Option<&Address>, single_call: bool, + delegate_call: bool, ) -> Result { let prank = Prank::new( ccx.caller, @@ -93,8 +134,15 @@ fn prank( new_origin.copied(), ccx.ecx.journaled_state.depth(), single_call, + delegate_call, ); + // Ensure that code exists at `msg.sender` if delegate calling. + if delegate_call { + let code = ccx.code(*new_caller)?; + ensure!(!code.is_empty(), "cannot `prank` delegate call from an EOA"); + } + if let Some(Prank { used, single_call: current_single_call, .. }) = ccx.state.prank { ensure!(used, "cannot overwrite a prank until it is applied at least once"); // This case can only fail if the user calls `vm.startPrank` and then `vm.prank` later on. diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs index f5238d810c8b..c0046123e448 100644 --- a/crates/cheatcodes/src/inspector.rs +++ b/crates/cheatcodes/src/inspector.rs @@ -37,7 +37,7 @@ use itertools::Itertools; use rand::{rngs::StdRng, Rng, SeedableRng}; use revm::{ interpreter::{ - opcode as op, CallInputs, CallOutcome, CallScheme, CreateInputs, CreateOutcome, + opcode as op, CallInputs, CallOutcome, CallScheme, CallValue, CreateInputs, CreateOutcome, EOFCreateInputs, EOFCreateKind, Gas, InstructionResult, Interpreter, InterpreterAction, InterpreterResult, }, @@ -833,6 +833,19 @@ impl Cheatcodes { // Apply our prank if let Some(prank) = &self.prank { + // Apply delegate call. call.caller will not equal prank.prank_caller + if let CallScheme::DelegateCall = call.scheme { + if prank.delegate_call { + call.target_address = prank.new_caller; + call.caller = prank.new_caller; + let acc = ecx.journaled_state.account(prank.new_caller); + call.value = CallValue::Apparent(acc.info.balance); + if let Some(new_origin) = prank.new_origin { + ecx.env.tx.caller = new_origin; + } + } + } + if ecx.journaled_state.depth() >= prank.depth && call.caller == prank.prank_caller { let mut prank_applied = false; @@ -946,7 +959,7 @@ impl Cheatcodes { initialized = false; old_balance = U256::ZERO; } - let kind = match call.scheme { + let kind: Vm::AccountAccessKind = match call.scheme { CallScheme::Call => crate::Vm::AccountAccessKind::Call, CallScheme::CallCode => crate::Vm::AccountAccessKind::CallCode, CallScheme::DelegateCall => crate::Vm::AccountAccessKind::DelegateCall, diff --git a/testdata/cheats/Vm.sol b/testdata/cheats/Vm.sol index 5b6750237add..092554cc7142 100644 --- a/testdata/cheats/Vm.sol +++ b/testdata/cheats/Vm.sol @@ -322,6 +322,8 @@ interface Vm { function pauseTracing() external view; function prank(address msgSender) external; function prank(address msgSender, address txOrigin) external; + function prank(address msgSender, bool delegateCall) external; + function prank(address msgSender, address txOrigin, bool delegateCall) external; function prevrandao(bytes32 newPrevrandao) external; function prevrandao(uint256 newPrevrandao) external; function projectRoot() external view returns (string memory path); @@ -408,6 +410,8 @@ interface Vm { function startMappingRecording() external; function startPrank(address msgSender) external; function startPrank(address msgSender, address txOrigin) external; + function startPrank(address msgSender, bool delegateCall) external; + function startPrank(address msgSender, address txOrigin, bool delegateCall) external; function startStateDiffRecording() external; function stopAndReturnStateDiff() external returns (AccountAccess[] memory accountAccesses); function stopBroadcast() external; diff --git a/testdata/default/cheats/Prank.t.sol b/testdata/default/cheats/Prank.t.sol index f7dd9b714f80..37ff52504574 100644 --- a/testdata/default/cheats/Prank.t.sol +++ b/testdata/default/cheats/Prank.t.sol @@ -85,9 +85,114 @@ contract NestedPranker { } } +contract ImplementationTest { + uint256 public num; + address public sender; + + function assertCorrectCaller(address expectedSender) public { + require(msg.sender == expectedSender); + } + + function assertCorrectOrigin(address expectedOrigin) public { + require(tx.origin == expectedOrigin); + } + + function setNum(uint256 _num) public { + num = _num; + } +} + +contract ProxyTest { + uint256 public num; + address public sender; +} + contract PrankTest is DSTest { Vm constant vm = Vm(HEVM_ADDRESS); + function testPrankDelegateCallPrank2() public { + ProxyTest proxy = new ProxyTest(); + ImplementationTest impl = new ImplementationTest(); + vm.prank(address(proxy), true); + + // Assert correct `msg.sender` + (bool success,) = + address(impl).delegatecall(abi.encodeWithSignature("assertCorrectCaller(address)", address(proxy))); + require(success, "prank2: delegate call failed assertCorrectCaller"); + + // Assert storage updates + uint256 num = 42; + vm.prank(address(proxy), true); + (bool successTwo,) = address(impl).delegatecall(abi.encodeWithSignature("setNum(uint256)", num)); + require(successTwo, "prank2: delegate call failed setNum"); + require(proxy.num() == num, "prank2: proxy's storage was not set correctly"); + vm.stopPrank(); + } + + function testPrankDelegateCallStartPrank2() public { + ProxyTest proxy = new ProxyTest(); + ImplementationTest impl = new ImplementationTest(); + vm.startPrank(address(proxy), true); + + // Assert correct `msg.sender` + (bool success,) = + address(impl).delegatecall(abi.encodeWithSignature("assertCorrectCaller(address)", address(proxy))); + require(success, "startPrank2: delegate call failed assertCorrectCaller"); + + // Assert storage updates + uint256 num = 42; + (bool successTwo,) = address(impl).delegatecall(abi.encodeWithSignature("setNum(uint256)", num)); + require(successTwo, "startPrank2: delegate call failed setNum"); + require(proxy.num() == num, "startPrank2: proxy's storage was not set correctly"); + vm.stopPrank(); + } + + function testPrankDelegateCallPrank3(address origin) public { + ProxyTest proxy = new ProxyTest(); + ImplementationTest impl = new ImplementationTest(); + vm.prank(address(proxy), origin, true); + + // Assert correct `msg.sender` + (bool success,) = + address(impl).delegatecall(abi.encodeWithSignature("assertCorrectCaller(address)", address(proxy))); + require(success, "prank3: delegate call failed assertCorrectCaller"); + + // Assert correct `tx.origin` + vm.prank(address(proxy), origin, true); + (bool successTwo,) = address(impl).delegatecall(abi.encodeWithSignature("assertCorrectOrigin(address)", origin)); + require(successTwo, "prank3: delegate call failed assertCorrectOrigin"); + + // Assert storage updates + uint256 num = 42; + vm.prank(address(proxy), address(origin), true); + (bool successThree,) = address(impl).delegatecall(abi.encodeWithSignature("setNum(uint256)", num)); + require(successThree, "prank3: delegate call failed setNum"); + require(proxy.num() == num, "prank3: proxy's storage was not set correctly"); + vm.stopPrank(); + } + + function testPrankDelegateCallStartPrank3(address origin) public { + ProxyTest proxy = new ProxyTest(); + ImplementationTest impl = new ImplementationTest(); + vm.startPrank(address(proxy), origin, true); + + // Assert correct `msg.sender` + (bool success,) = + address(impl).delegatecall(abi.encodeWithSignature("assertCorrectCaller(address)", address(proxy))); + require(success, "startPrank3: delegate call failed assertCorrectCaller"); + + // Assert correct `tx.origin` + (bool successTwo,) = address(impl).delegatecall(abi.encodeWithSignature("assertCorrectOrigin(address)", origin)); + require(successTwo, "startPrank3: delegate call failed assertCorrectOrigin"); + + // Assert storage updates + uint256 num = 42; + (bool successThree,) = address(impl).delegatecall(abi.encodeWithSignature("setNum(uint256)", num)); + require(successThree, "startPrank3: delegate call failed setNum"); + require(proxy.num() == num, "startPrank3: proxy's storage was not set correctly"); + vm.stopPrank(); + } + function testPrankSender(address sender) public { // Perform the prank Victim victim = new Victim();