Skip to content

Commit

Permalink
fix and tests (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
crosschainer authored Jun 5, 2024
1 parent ddea489 commit ce269d0
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/contracting/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def execute(self, sender, contract_name, function_name, kwargs,
stamp_cost=constants.STAMPS_PER_TAU,
metering=None) -> dict:

current_driver_pending_writes = deepcopy(self.driver.pending_writes)

if not self.bypass_privates:
assert not function_name.startswith(constants.PRIVATE_METHOD_PREFIX), 'Private method not callable.'

Expand Down Expand Up @@ -135,6 +137,9 @@ def execute(self, sender, contract_name, function_name, kwargs,
except Exception as e:
result = e
status_code = 1

# Revert the writes if the transaction fails
driver.pending_writes = current_driver_pending_writes

if auto_commit:
driver.flush_cache()
Expand Down
20 changes: 20 additions & 0 deletions tests/integration/test_contracts/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
balances = Hash(default_value=0)

@construct
def seed():
balances['stu'] = 999
balances['colin'] = 555

@export
def transfer(amount: int, to: str):
sender = ctx.caller
assert balances[sender] >= amount, 'Not enough coins to send!'

balances[sender] -= amount
balances[to] += amount

raise Exception('This is an exception')

@export
def balance_of(account: str):
return balances[account]
101 changes: 101 additions & 0 deletions tests/unit/test_revert_on_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import unittest
from contracting.storage.driver import Driver
from contracting.execution.executor import Executor
from contracting.constants import STAMPS_PER_TAU
from xian.processor import TxProcessor
from contracting.client import ContractingClient
import contracting
import random
import string
import os
import sys
from loguru import logger

# Get the directory where the script is located
script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))

# Change the current working directory
os.chdir(script_dir)

def submission_kwargs_for_file(f):
# Get the file name only by splitting off directories
split = f.split('/')
split = split[-1]

# Now split off the .s
split = split.split('.')
contract_name = split[0]

with open(f) as file:
contract_code = file.read()

return {
'name': f'con_{contract_name}',
'code': contract_code,
}

TEST_SUBMISSION_KWARGS = {
'sender': 'stu',
'contract_name': 'submission',
'function_name': 'submit_contract'
}

class MyTestCase(unittest.TestCase):

def setUp(self):
self.c = ContractingClient()
self.tx_processor = TxProcessor(client=self.c)
# Hard load the submission contract
self.d = self.c.raw_driver
self.d.flush_full()

with open(contracting.__path__[0] + '/contracts/submission.s.py') as f:
contract = f.read()

self.d.set_contract(name='submission', code=contract)

with open('../integration/test_contracts/currency.s.py') as f:
contract = f.read()
self.d.set_contract(name='currency', code=contract)

self.c.executor.execute(**TEST_SUBMISSION_KWARGS, kwargs=submission_kwargs_for_file('../integration/test_contracts/currency.s.py'), metering=False, auto_commit=True)

self.c.executor.execute(**TEST_SUBMISSION_KWARGS,
kwargs=submission_kwargs_for_file('../integration/test_contracts/exception.py'),
metering=False, auto_commit=True)
self.d.commit()

def test_exception(self):
prior_balance = self.d.get('con_exception.balances:stu')
logger.debug(f"Prior balance (exception): {prior_balance}")


output = self.tx_processor.process_tx({
"payload":
{'sender': 'stu', 'contract': 'con_exception', 'function': 'transfer', 'kwargs': {'amount': 100, 'to': 'colin'},"stamps_supplied":1000},
"metadata":
{"signature":"abc"},"b_meta":{"nanos":0,
"hash":"0x0","height":0}})
logger.debug(f"Output (exception): {output}")

new_balance = self.d.get('con_exception.balances:stu')
logger.debug(f"New balance (exception): {new_balance}")

self.assertEqual(prior_balance, new_balance)

def test_non_exception(self):
prior_balance = self.d.get('con_currency.balances:stu')

output = self.tx_processor.process_tx({
"payload":
{'sender': 'stu', 'contract': 'con_currency', 'function': 'transfer', 'kwargs': {'amount': 100, 'to': 'colin'},"stamps_supplied":1000},
"metadata":
{"signature":"abc"},"b_meta":{"nanos":0,"hash":"0x0","height":0}})

new_balance = self.d.get('con_currency.balances:stu')
logger.debug(f"New balance (non-exception): {new_balance}")

self.assertEqual(prior_balance - 100, new_balance)

if __name__ == '__main__':
unittest.main()

0 comments on commit ce269d0

Please sign in to comment.