diff --git a/src/contracting/execution/executor.py b/src/contracting/execution/executor.py index 7d1f1ff6..3ff28bbc 100644 --- a/src/contracting/execution/executor.py +++ b/src/contracting/execution/executor.py @@ -52,6 +52,7 @@ def execute(self, sender, contract_name, function_name, kwargs, metering=None) -> dict: current_driver_pending_writes = deepcopy(self.driver.pending_writes) + self.driver.clear_transaction_writes() if not self.bypass_privates: assert not function_name.startswith(constants.PRIVATE_METHOD_PREFIX), 'Private method not callable.' @@ -128,6 +129,7 @@ def execute(self, sender, contract_name, function_name, kwargs, enable_restricted_imports() runtime.rt.set_up(stmps=stamps * 1000, meter=metering) result = func(**kwargs) + transaction_writes = deepcopy(driver.transaction_writes) runtime.rt.tracer.stop() disable_restricted_imports() @@ -140,11 +142,13 @@ def execute(self, sender, contract_name, function_name, kwargs, # Revert the writes if the transaction fails driver.pending_writes = current_driver_pending_writes + transaction_writes = {} if auto_commit: driver.flush_cache() finally: + driver.clear_transaction_writes() runtime.rt.tracer.stop() #runtime.rt.tracer.stop() @@ -172,6 +176,7 @@ def execute(self, sender, contract_name, function_name, kwargs, balance = max(balance - to_deduct, 0) driver.set(balances_key, balance) + transaction_writes[balances_key] = balance if auto_commit: driver.commit() @@ -184,7 +189,7 @@ def execute(self, sender, contract_name, function_name, kwargs, 'status_code': status_code, 'result': result, 'stamps_used': stamps_used, - 'writes': deepcopy(driver.pending_writes), + 'writes': transaction_writes, 'reads': driver.pending_reads } diff --git a/src/contracting/storage/driver.py b/src/contracting/storage/driver.py index 760c6578..7763be2b 100644 --- a/src/contracting/storage/driver.py +++ b/src/contracting/storage/driver.py @@ -33,6 +33,7 @@ def __init__(self, bypass_cache=False, storage_home=constants.STORAGE_HOME): self.pending_deltas = {} self.pending_writes = {} self.pending_reads = {} + self.transaction_writes = {} self.cache = TTLCache(maxsize=1000, ttl=6*3600) self.bypass_cache = bypass_cache self.contract_state = storage_home.joinpath("contract_state") @@ -76,7 +77,7 @@ def get(self, key: str, save: bool = True): """ Get a value from the cache, pending reads, or disk. If save is True, the value will be saved to pending_reads. - """ + """ # Parse the key to get the filename and group value = self.find(key) if save and self.pending_reads.get(key) is None: @@ -86,13 +87,16 @@ def get(self, key: str, save: bool = True): return value - def set(self, key, value): + def set(self, key, value, is_txn_write=False): rt.deduct_write(*encode_kv(key, value)) if self.pending_reads.get(key) is None: self.get(key) if type(value) in [decimal.Decimal, float]: value = ContractingDecimal(str(value)) self.pending_writes[key] = value + if is_txn_write: + self.transaction_writes[key] = value + def find(self, key: str): """ @@ -288,6 +292,7 @@ def flush_cache(self): self.pending_writes.clear() self.pending_reads.clear() self.pending_deltas.clear() + self.transaction_writes.clear() self.cache.clear() def flush_disk(self): @@ -417,3 +422,10 @@ def get_run_state(self): run_state[full_key] = value return run_state + + + def clear_transaction_writes(self): + """ + Clear the transaction-specific writes. + """ + self.transaction_writes.clear() diff --git a/src/contracting/storage/orm.py b/src/contracting/storage/orm.py index 318f393d..83a7f938 100644 --- a/src/contracting/storage/orm.py +++ b/src/contracting/storage/orm.py @@ -26,7 +26,7 @@ def set(self, value): assert isinstance(value, self._type), (f'Wrong type passed to variable! ' f'Expected {self._type}, got {type(value)}.') - self._driver.set(self._key, value) + self._driver.set(self._key, value, True) def get(self): return self._driver.get(self._key) @@ -39,7 +39,7 @@ def __init__(self, contract, name, driver: Driver = driver, default_value=None): self._default_value = default_value def _set(self, key, value): - self._driver.set(f'{self._key}{self._delimiter}{key}', value) + self._driver.set(f'{self._key}{self._delimiter}{key}', value, True) def _get(self, item): value = self._driver.get(f'{self._key}{self._delimiter}{item}') diff --git a/tests/integration/test_executor_transaction_writes.py b/tests/integration/test_executor_transaction_writes.py new file mode 100644 index 00000000..49e7f04c --- /dev/null +++ b/tests/integration/test_executor_transaction_writes.py @@ -0,0 +1,58 @@ +import importlib +from unittest import TestCase +from contracting.stdlib.bridge.time import Datetime +from contracting.client import ContractingClient +from contracting.storage.driver import Driver + + +class TestTransactionWrites(TestCase): + def setUp(self): + self.c = ContractingClient() + self.c.flush() + + with open("./test_contracts/currency.s.py") as f: + contract = f.read() + + self.c.submit(contract, name="currency") + + self.c.executor.driver.commit() + + def tearDown(self): + self.c.raw_driver.flush_full() + + def test_transfers(self): + self.c.set_var( + contract="currency", variable="balances", arguments=["bill"], value=200 + ) + res3 = self.c.executor.execute( + contract_name="currency", + function_name="transfer", + kwargs={"to": "someone", "amount": 100}, + stamps=1000, + sender="bill", + ) + self.assertEquals(res3["writes"], self.c.executor.driver.pending_writes) + res2 = self.c.executor.execute( + contract_name="currency", + function_name="transfer", + kwargs={"to": "someone", "amount": 100}, + stamps=1000, + sender="bill", + ) + + self.assertEquals(res2["writes"], self.c.executor.driver.pending_writes) + # This operation will raise an exception, so will not make any writes. + res3 = self.c.executor.execute( + contract_name="currency", + function_name="transfer", + kwargs={"to": "someone", "amount": 100}, + stamps=1000, + sender="bill", + ) + self.assertEquals(res3["writes"], {}) + + +if __name__ == "__main__": + import unittest + + unittest.main() diff --git a/tests/unit/test_new_driver.py b/tests/unit/test_new_driver.py index c5daa1a2..c5683ff1 100644 --- a/tests/unit/test_new_driver.py +++ b/tests/unit/test_new_driver.py @@ -105,6 +105,24 @@ def test_get_all_contract_state(self): contract_state = self.driver.get_all_contract_state() self.assertIn(key, contract_state) self.assertEqual(contract_state[key], value) + + def test_transaction_writes(self): + key = 'test_key' + value = 'test_value' + self.driver.set(key, value) + # self.driver.commit() + transaction_writes = self.driver.transaction_writes + self.assertIn(key, transaction_writes) + self.assertEqual(transaction_writes[key], value) + + def test_clear_transaction_writes(self): + key = 'test_key' + value = 'test_value' + self.driver.set(key, value) + # self.driver.commit() + self.driver.clear_transaction_writes() + transaction_writes = self.driver.transaction_writes + self.assertNotIn(key, transaction_writes) def test_get_run_state(self): # We can't test this function here since we are not running a real blockchain.