From ce269d0c1146bd57ad88f4d49446f540524f12ba Mon Sep 17 00:00:00 2001 From: crosschainer <68580992+crosschainer@users.noreply.github.com> Date: Wed, 5 Jun 2024 11:15:32 +0200 Subject: [PATCH] fix and tests (#52) --- src/contracting/execution/executor.py | 5 + tests/integration/test_contracts/exception.py | 20 ++++ tests/unit/test_revert_on_exception.py | 101 ++++++++++++++++++ 3 files changed, 126 insertions(+) create mode 100644 tests/integration/test_contracts/exception.py create mode 100644 tests/unit/test_revert_on_exception.py diff --git a/src/contracting/execution/executor.py b/src/contracting/execution/executor.py index 166c2d21..07806238 100644 --- a/src/contracting/execution/executor.py +++ b/src/contracting/execution/executor.py @@ -51,6 +51,8 @@ def execute(self, sender, contract_name, function_name, kwargs, stamp_cost=constants.STAMPS_PER_TAU, metering=None) -> dict: + current_driver_pending_writes = deepcopy(self.driver.pending_writes) + if not self.bypass_privates: assert not function_name.startswith(constants.PRIVATE_METHOD_PREFIX), 'Private method not callable.' @@ -135,6 +137,9 @@ def execute(self, sender, contract_name, function_name, kwargs, except Exception as e: result = e status_code = 1 + + # Revert the writes if the transaction fails + driver.pending_writes = current_driver_pending_writes if auto_commit: driver.flush_cache() diff --git a/tests/integration/test_contracts/exception.py b/tests/integration/test_contracts/exception.py new file mode 100644 index 00000000..9272c233 --- /dev/null +++ b/tests/integration/test_contracts/exception.py @@ -0,0 +1,20 @@ +balances = Hash(default_value=0) + +@construct +def seed(): + balances['stu'] = 999 + balances['colin'] = 555 + +@export +def transfer(amount: int, to: str): + sender = ctx.caller + assert balances[sender] >= amount, 'Not enough coins to send!' + + balances[sender] -= amount + balances[to] += amount + + raise Exception('This is an exception') + +@export +def balance_of(account: str): + return balances[account] \ No newline at end of file diff --git a/tests/unit/test_revert_on_exception.py b/tests/unit/test_revert_on_exception.py new file mode 100644 index 00000000..0c77d4e7 --- /dev/null +++ b/tests/unit/test_revert_on_exception.py @@ -0,0 +1,101 @@ +import unittest +from contracting.storage.driver import Driver +from contracting.execution.executor import Executor +from contracting.constants import STAMPS_PER_TAU +from xian.processor import TxProcessor +from contracting.client import ContractingClient +import contracting +import random +import string +import os +import sys +from loguru import logger + +# Get the directory where the script is located +script_dir = os.path.dirname(os.path.abspath(sys.argv[0])) + +# Change the current working directory +os.chdir(script_dir) + +def submission_kwargs_for_file(f): + # Get the file name only by splitting off directories + split = f.split('/') + split = split[-1] + + # Now split off the .s + split = split.split('.') + contract_name = split[0] + + with open(f) as file: + contract_code = file.read() + + return { + 'name': f'con_{contract_name}', + 'code': contract_code, + } + +TEST_SUBMISSION_KWARGS = { + 'sender': 'stu', + 'contract_name': 'submission', + 'function_name': 'submit_contract' +} + +class MyTestCase(unittest.TestCase): + + def setUp(self): + self.c = ContractingClient() + self.tx_processor = TxProcessor(client=self.c) + # Hard load the submission contract + self.d = self.c.raw_driver + self.d.flush_full() + + with open(contracting.__path__[0] + '/contracts/submission.s.py') as f: + contract = f.read() + + self.d.set_contract(name='submission', code=contract) + + with open('../integration/test_contracts/currency.s.py') as f: + contract = f.read() + self.d.set_contract(name='currency', code=contract) + + self.c.executor.execute(**TEST_SUBMISSION_KWARGS, kwargs=submission_kwargs_for_file('../integration/test_contracts/currency.s.py'), metering=False, auto_commit=True) + + self.c.executor.execute(**TEST_SUBMISSION_KWARGS, + kwargs=submission_kwargs_for_file('../integration/test_contracts/exception.py'), + metering=False, auto_commit=True) + self.d.commit() + + def test_exception(self): + prior_balance = self.d.get('con_exception.balances:stu') + logger.debug(f"Prior balance (exception): {prior_balance}") + + + output = self.tx_processor.process_tx({ + "payload": + {'sender': 'stu', 'contract': 'con_exception', 'function': 'transfer', 'kwargs': {'amount': 100, 'to': 'colin'},"stamps_supplied":1000}, + "metadata": + {"signature":"abc"},"b_meta":{"nanos":0, + "hash":"0x0","height":0}}) + logger.debug(f"Output (exception): {output}") + + new_balance = self.d.get('con_exception.balances:stu') + logger.debug(f"New balance (exception): {new_balance}") + + self.assertEqual(prior_balance, new_balance) + + def test_non_exception(self): + prior_balance = self.d.get('con_currency.balances:stu') + + output = self.tx_processor.process_tx({ + "payload": + {'sender': 'stu', 'contract': 'con_currency', 'function': 'transfer', 'kwargs': {'amount': 100, 'to': 'colin'},"stamps_supplied":1000}, + "metadata": + {"signature":"abc"},"b_meta":{"nanos":0,"hash":"0x0","height":0}}) + + new_balance = self.d.get('con_currency.balances:stu') + logger.debug(f"New balance (non-exception): {new_balance}") + + self.assertEqual(prior_balance - 100, new_balance) + +if __name__ == '__main__': + unittest.main()