Skip to content

Commit

Permalink
feat!: Methods without #[payable] cannot accept assets (#800)
Browse files Browse the repository at this point in the history
Closes #742

Forwarding CallParameters with an amount > 0 to a method that is not annotated as payable results in a AssetsForwardedToNonPayableMethod error.

BREAKING CHANGE: `.call_params()` now returns a Result
---------

Co-authored-by: Halil Beglerović <[email protected]>
Co-authored-by: Ahmed Sagdati <[email protected]>
  • Loading branch information
3 people authored Feb 7, 2023
1 parent 0198c43 commit 2cad5b6
Show file tree
Hide file tree
Showing 18 changed files with 137 additions and 18 deletions.
8 changes: 8 additions & 0 deletions docs/src/calling-contracts/call-params.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ Then, in Rust, after setting up and deploying the above contract, you can config
{{#include ../../../examples/contracts/src/lib.rs:call_parameters}}
```

`call_params` returns a result to ensure you don't forward assets to a contract method that isn't payable. In the following example, we try to forward an amount of 100 of the base asset to `non_payable`. As its name suggests, `non_payable` isn't annotated with `#[payable]` in the contract code. Passing `CallParameters` with an amount other than 0 leads to an `InvalidCallParameters` error:

```rust,ignore
{{#include ../../../packages/fuels/tests/contracts.rs:non_payable_params}}
```

> **Note:** forwarding gas to a contract call is always possible, regardless of the contract method being non-payable.
You can also use `CallParameters::default()` to use the default values:

```rust,ignore
Expand Down
6 changes: 3 additions & 3 deletions examples/contracts/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,14 @@ mod tests {
let response = contract_methods
.get_msg_amount() // Our contract method.
.tx_params(tx_params) // Chain the tx params setting method.
.call_params(call_params) // Chain the call params setting method.
.call_params(call_params)? // Chain the call params setting method.
.call() // Perform the contract call.
.await?;
// ANCHOR_END: call_parameters
// ANCHOR: call_parameters_default
let response = contract_methods
.initialize_counter(42)
.call_params(CallParameters::default())
.call_params(CallParameters::default())?
.call()
.await?;

Expand Down Expand Up @@ -586,7 +586,7 @@ mod tests {
let response = contract_methods
.get_msg_amount() // Our contract method.
.tx_params(tx_params) // Chain the tx params setting method.
.call_params(call_params) // Chain the call params setting method.
.call_params(call_params)? // Chain the call params setting method.
.call() // Perform the contract call.
.await?;
// ANCHOR_END: call_params_gas
Expand Down
4 changes: 2 additions & 2 deletions examples/cookbook/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ mod tests {
let call_params = CallParameters::new(Some(deposit_amount), Some(base_asset_id), None);
contract_methods
.deposit(wallet.address().into())
.call_params(call_params)
.call_params(call_params)?
.append_variable_outputs(1)
.call()
.await?;
Expand All @@ -66,7 +66,7 @@ mod tests {
let call_params = CallParameters::new(Some(lp_token_balance), Some(lp_asset_id), None);
contract_methods
.withdraw(wallet.address().into())
.call_params(call_params)
.call_params(call_params)?
.append_variable_outputs(1)
.call()
.await?;
Expand Down
2 changes: 2 additions & 0 deletions examples/rust_bindings/src/rust_bindings_formatted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub mod abigen_bindings {
),
&[Tokenizable::into_token(value)],
log_decoder,
false,
)
.expect("method not found (this should never happen)")
}
Expand All @@ -81,6 +82,7 @@ pub mod abigen_bindings {
),
&[Tokenizable::into_token(value)],
log_decoder,
false,
)
.expect("method not found (this should never happen)")
}
Expand Down
16 changes: 14 additions & 2 deletions packages/fuels-code-gen/src/program_bindings/abi_types.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashMap;

use fuel_abi_types::program_abi::{
ABIFunction, LoggedType, ProgramABI, TypeApplication, TypeDeclaration,
ABIFunction, Attribute, LoggedType, ProgramABI, TypeApplication, TypeDeclaration,
};

use crate::error::{error, Result};
Expand Down Expand Up @@ -63,13 +63,15 @@ pub(crate) struct FullABIFunction {
name: String,
inputs: Vec<FullTypeApplication>,
output: FullTypeApplication,
attributes: Vec<Attribute>,
}

impl FullABIFunction {
pub(crate) fn new(
name: String,
inputs: Vec<FullTypeApplication>,
output: FullTypeApplication,
attributes: Vec<Attribute>,
) -> Result<Self> {
if name.is_empty() {
Err(error!("FullABIFunction's name cannot be empty!"))
Expand All @@ -78,6 +80,7 @@ impl FullABIFunction {
name,
inputs,
output,
attributes,
})
}
}
Expand All @@ -94,6 +97,10 @@ impl FullABIFunction {
&self.output
}

pub(crate) fn is_payable(&self) -> bool {
self.attributes.iter().any(|attr| attr.name == "payable")
}

pub(crate) fn from_counterpart(
abi_function: &ABIFunction,
types: &HashMap<usize, TypeDeclaration>,
Expand All @@ -104,10 +111,15 @@ impl FullABIFunction {
.map(|input| FullTypeApplication::from_counterpart(input, types))
.collect();

let attributes = abi_function
.attributes
.as_ref()
.map_or(vec![], Clone::clone);
FullABIFunction::new(
abi_function.name.clone(),
inputs,
FullTypeApplication::from_counterpart(&abi_function.output, types),
attributes,
)
}
}
Expand Down Expand Up @@ -227,7 +239,7 @@ mod tests {
type_arguments: vec![],
};

let err = FullABIFunction::new("".to_string(), vec![], fn_output)
let err = FullABIFunction::new("".to_string(), vec![], fn_output, vec![])
.expect_err("Should have failed.");

assert_eq!(err.to_string(), "FullABIFunction's name cannot be empty!");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ pub(crate) fn expand_fn(

let fn_selector = generator.fn_selector();
let arg_tokens = generator.tokenized_args();
let is_payable = abi_fun.is_payable();
let body = quote! {
let provider = self.wallet.get_provider().expect("Provider not set up");
::fuels::programs::contract::Contract::method_hash(
Expand All @@ -155,7 +156,8 @@ pub(crate) fn expand_fn(
&self.wallet,
#fn_selector,
&#arg_tokens,
self.log_decoder.clone()
self.log_decoder.clone(),
#is_payable,
)
.expect("method not found (this should never happen)")
};
Expand Down Expand Up @@ -349,7 +351,8 @@ mod tests {
::fuels::types::traits::Tokenizable::into_token(s_1),
::fuels::types::traits::Tokenizable::into_token(s_2)
],
self.log_decoder.clone()
self.log_decoder.clone(),
false,
)
.expect("method not found (this should never happen)")
}
Expand Down Expand Up @@ -409,7 +412,8 @@ mod tests {
&[<bool as ::fuels::types::traits::Parameterize>::param_type()]
),
&[::fuels::types::traits::Tokenizable::into_token(bimbam)],
self.log_decoder.clone()
self.log_decoder.clone(),
false,
)
.expect("method not found (this should never happen)")
}
Expand Down Expand Up @@ -525,7 +529,8 @@ mod tests {
&[::fuels::types::traits::Tokenizable::into_token(
the_only_allowed_input
)],
self.log_decoder.clone()
self.log_decoder.clone(),
false,
)
.expect("method not found (this should never happen)")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ mod tests {
}],
}];

FullABIFunction::new("test_function".to_string(), fn_inputs, fn_output)
FullABIFunction::new("test_function".to_string(), fn_inputs, fn_output, vec![])
.expect("Hand crafted function known to be correct")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ mod tests {
},
type_arguments: vec![],
},
vec![],
)
.expect("hand-crafted, should not fail!")
}
Expand Down
2 changes: 2 additions & 0 deletions packages/fuels-programs/src/call_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ mod test {
external_contracts: Default::default(),
output_param: ParamType::Unit,
message_outputs: None,
is_payable: false,
custom_assets: Default::default(),
}
}
Expand Down Expand Up @@ -399,6 +400,7 @@ mod test {
message_outputs: None,
external_contracts: vec![],
output_param: ParamType::Unit,
is_payable: false,
custom_assets: Default::default(),
})
.collect();
Expand Down
14 changes: 12 additions & 2 deletions packages/fuels-programs/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ impl Contract {
signature: Selector,
args: &[Token],
log_decoder: LogDecoder,
is_payable: bool,
) -> Result<ContractCallHandler<D>> {
let encoded_selector = signature;

Expand All @@ -129,6 +130,7 @@ impl Contract {
message_outputs: None,
external_contracts: vec![],
output_param: D::param_type(),
is_payable,
custom_assets: Default::default(),
};

Expand Down Expand Up @@ -362,6 +364,7 @@ pub struct ContractCall {
pub message_outputs: Option<Vec<Output>>,
pub external_contracts: Vec<Bech32ContractId>,
pub output_param: ParamType,
pub is_payable: bool,
pub custom_assets: HashMap<(AssetId, Option<Bech32Address>), u64>,
}

Expand Down Expand Up @@ -591,6 +594,10 @@ where
self
}

pub fn is_payable(&self) -> bool {
self.contract_call.is_payable
}

/// Sets the transaction parameters for a given transaction.
/// Note that this is a builder method, i.e. use it as a chain:
Expand All @@ -610,9 +617,12 @@ where
/// let params = CallParameters { amount: 1, asset_id: BASE_ASSET_ID };
/// my_contract_instance.my_method(...).call_params(params).call()
/// ```
pub fn call_params(mut self, params: CallParameters) -> Self {
pub fn call_params(mut self, params: CallParameters) -> Result<Self> {
if !self.is_payable() && params.amount > 0 {
return Err(Error::AssetsForwardedToNonPayableMethod);
}
self.contract_call.call_parameters = params;
self
Ok(self)
}

/// Appends `num` [`fuel_tx::Output::Variable`]s to the transaction.
Expand Down
2 changes: 2 additions & 0 deletions packages/fuels-types/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub enum Error {
ValidationError(#[from] CheckError),
#[error("Revert transaction error: {}, receipts: {:?}", .0, .1)]
RevertTransactionError(String, Vec<Receipt>),
#[error("Tried to forward assets to a contract method that is not payable.")]
AssetsForwardedToNonPayableMethod,
}

pub type Result<T> = std::result::Result<T, Error>;
Expand Down
44 changes: 44 additions & 0 deletions packages/fuels/tests/contracts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1068,3 +1068,47 @@ async fn test_deploy_error_messages() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn test_payable_annotation() -> Result<()> {
setup_contract_test!(
Wallets("wallet"),
Abigen(
name = "TestContract",
abi = "packages/fuels/tests/contracts/payable_annotation"
),
Deploy(
name = "contract_instance",
contract = "TestContract",
wallet = "wallet"
),
);

let contract_methods = contract_instance.methods();

let response = contract_methods
.payable()
.call_params(CallParameters::new(Some(100), None, Some(20000)))?
.call()
.await?;

assert_eq!(response.value, 42);

// ANCHOR: non_payable_params
let err = contract_methods
.non_payable()
.call_params(CallParameters::new(Some(100), None, None))
.expect_err("Should return call params error.");

assert!(matches!(err, Error::AssetsForwardedToNonPayableMethod));
// ANCHOR_END: non_payable_params */
let response = contract_methods
.non_payable()
.call_params(CallParameters::new(None, None, Some(20000)))?
.call()
.await?;

assert_eq!(response.value, 42);

Ok(())
}
2 changes: 2 additions & 0 deletions packages/fuels/tests/contracts/contract_test/src/main.sw
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ abi TestContract {
fn array_of_structs(p: [Person; 2]) -> [Person; 2];
fn array_of_enums(p: [State; 2]) -> [State; 2];
fn get_array(p: [u64; 2]) -> [u64; 2];
#[payable]
fn get_msg_amount() -> u64;
fn new() -> u64;
}
Expand All @@ -39,6 +40,7 @@ const COUNTER_KEY = 0x0000000000000000000000000000000000000000000000000000000000

impl TestContract for Contract {
// ANCHOR: msg_amount
#[payable]
fn get_msg_amount() -> u64 {
msg_amount()
}
Expand Down
4 changes: 4 additions & 0 deletions packages/fuels/tests/contracts/liquidity_pool/src/main.sw
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@ use std::{
};

abi LiquidityPool {
#[payable]
fn deposit(recipient: Address);
#[payable]
fn withdraw(recipient: Address);
}

const BASE_TOKEN: b256 = 0x9ae5b658754e096e4d681c548daf46354495a437cc61492599e33fc64dcdc30c;

impl LiquidityPool for Contract {
#[payable]
fn deposit(recipient: Address) {
assert(ContractId::from(BASE_TOKEN) == msg_asset_id());
assert(0 < msg_amount());
Expand All @@ -31,6 +34,7 @@ impl LiquidityPool for Contract {
mint_to_address(amount_to_mint, recipient);
}

#[payable]
fn withdraw(recipient: Address) {
assert(contract_id() == msg_asset_id());
assert(0 < msg_amount());
Expand Down
7 changes: 7 additions & 0 deletions packages/fuels/tests/contracts/payable_annotation/Forc.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[project]
authors = ["Fuel Labs <[email protected]>"]
entry = "main.sw"
license = "Apache-2.0"
name = "payable_annotation"

[dependencies]
18 changes: 18 additions & 0 deletions packages/fuels/tests/contracts/payable_annotation/src/main.sw
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
contract;

abi TestContract {
#[payable]
fn payable() -> u64;
fn non_payable() -> u64;
}

impl TestContract for Contract {
#[payable]
fn payable() -> u64 {
42
}

fn non_payable() -> u64 {
42
}
}
Loading

0 comments on commit 2cad5b6

Please sign in to comment.