Skip to content

Commit

Permalink
Merge pull request #86 from xian-network/driver-tx-specific-writes
Browse files Browse the repository at this point in the history
driver - transaction specific writes
  • Loading branch information
duelingbenjos authored Oct 8, 2024
2 parents f3c3f27 + ac80f4c commit e8e78ac
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 5 deletions.
7 changes: 6 additions & 1 deletion src/contracting/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def execute(self, sender, contract_name, function_name, kwargs,
metering=None) -> dict:

current_driver_pending_writes = deepcopy(self.driver.pending_writes)
self.driver.clear_transaction_writes()

if not self.bypass_privates:
assert not function_name.startswith(constants.PRIVATE_METHOD_PREFIX), 'Private method not callable.'
Expand Down Expand Up @@ -128,6 +129,7 @@ def execute(self, sender, contract_name, function_name, kwargs,
enable_restricted_imports()
runtime.rt.set_up(stmps=stamps * 1000, meter=metering)
result = func(**kwargs)
transaction_writes = deepcopy(driver.transaction_writes)
runtime.rt.tracer.stop()
disable_restricted_imports()

Expand All @@ -140,11 +142,13 @@ def execute(self, sender, contract_name, function_name, kwargs,

# Revert the writes if the transaction fails
driver.pending_writes = current_driver_pending_writes
transaction_writes = {}

if auto_commit:
driver.flush_cache()

finally:
driver.clear_transaction_writes()
runtime.rt.tracer.stop()

#runtime.rt.tracer.stop()
Expand Down Expand Up @@ -172,6 +176,7 @@ def execute(self, sender, contract_name, function_name, kwargs,
balance = max(balance - to_deduct, 0)

driver.set(balances_key, balance)
transaction_writes[balances_key] = balance

if auto_commit:
driver.commit()
Expand All @@ -184,7 +189,7 @@ def execute(self, sender, contract_name, function_name, kwargs,
'status_code': status_code,
'result': result,
'stamps_used': stamps_used,
'writes': deepcopy(driver.pending_writes),
'writes': transaction_writes,
'reads': driver.pending_reads
}

Expand Down
16 changes: 14 additions & 2 deletions src/contracting/storage/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, bypass_cache=False, storage_home=constants.STORAGE_HOME):
self.pending_deltas = {}
self.pending_writes = {}
self.pending_reads = {}
self.transaction_writes = {}
self.cache = TTLCache(maxsize=1000, ttl=6*3600)
self.bypass_cache = bypass_cache
self.contract_state = storage_home.joinpath("contract_state")
Expand Down Expand Up @@ -76,7 +77,7 @@ def get(self, key: str, save: bool = True):
"""
Get a value from the cache, pending reads, or disk. If save is True,
the value will be saved to pending_reads.
"""
"""
# Parse the key to get the filename and group
value = self.find(key)
if save and self.pending_reads.get(key) is None:
Expand All @@ -86,13 +87,16 @@ def get(self, key: str, save: bool = True):
return value


def set(self, key, value):
def set(self, key, value, is_txn_write=False):
rt.deduct_write(*encode_kv(key, value))
if self.pending_reads.get(key) is None:
self.get(key)
if type(value) in [decimal.Decimal, float]:
value = ContractingDecimal(str(value))
self.pending_writes[key] = value
if is_txn_write:
self.transaction_writes[key] = value


def find(self, key: str):
"""
Expand Down Expand Up @@ -288,6 +292,7 @@ def flush_cache(self):
self.pending_writes.clear()
self.pending_reads.clear()
self.pending_deltas.clear()
self.transaction_writes.clear()
self.cache.clear()

def flush_disk(self):
Expand Down Expand Up @@ -417,3 +422,10 @@ def get_run_state(self):
run_state[full_key] = value

return run_state


def clear_transaction_writes(self):
"""
Clear the transaction-specific writes.
"""
self.transaction_writes.clear()
4 changes: 2 additions & 2 deletions src/contracting/storage/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def set(self, value):
assert isinstance(value, self._type), (f'Wrong type passed to variable! '
f'Expected {self._type}, got {type(value)}.')

self._driver.set(self._key, value)
self._driver.set(self._key, value, True)

def get(self):
return self._driver.get(self._key)
Expand All @@ -39,7 +39,7 @@ def __init__(self, contract, name, driver: Driver = driver, default_value=None):
self._default_value = default_value

def _set(self, key, value):
self._driver.set(f'{self._key}{self._delimiter}{key}', value)
self._driver.set(f'{self._key}{self._delimiter}{key}', value, True)

def _get(self, item):
value = self._driver.get(f'{self._key}{self._delimiter}{item}')
Expand Down
58 changes: 58 additions & 0 deletions tests/integration/test_executor_transaction_writes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import importlib
from unittest import TestCase
from contracting.stdlib.bridge.time import Datetime
from contracting.client import ContractingClient
from contracting.storage.driver import Driver


class TestTransactionWrites(TestCase):
def setUp(self):
self.c = ContractingClient()
self.c.flush()

with open("./test_contracts/currency.s.py") as f:
contract = f.read()

self.c.submit(contract, name="currency")

self.c.executor.driver.commit()

def tearDown(self):
self.c.raw_driver.flush_full()

def test_transfers(self):
self.c.set_var(
contract="currency", variable="balances", arguments=["bill"], value=200
)
res3 = self.c.executor.execute(
contract_name="currency",
function_name="transfer",
kwargs={"to": "someone", "amount": 100},
stamps=1000,
sender="bill",
)
self.assertEquals(res3["writes"], self.c.executor.driver.pending_writes)
res2 = self.c.executor.execute(
contract_name="currency",
function_name="transfer",
kwargs={"to": "someone", "amount": 100},
stamps=1000,
sender="bill",
)

self.assertEquals(res2["writes"], self.c.executor.driver.pending_writes)
# This operation will raise an exception, so will not make any writes.
res3 = self.c.executor.execute(
contract_name="currency",
function_name="transfer",
kwargs={"to": "someone", "amount": 100},
stamps=1000,
sender="bill",
)
self.assertEquals(res3["writes"], {})


if __name__ == "__main__":
import unittest

unittest.main()
18 changes: 18 additions & 0 deletions tests/unit/test_new_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,24 @@ def test_get_all_contract_state(self):
contract_state = self.driver.get_all_contract_state()
self.assertIn(key, contract_state)
self.assertEqual(contract_state[key], value)

def test_transaction_writes(self):
key = 'test_key'
value = 'test_value'
self.driver.set(key, value)
# self.driver.commit()
transaction_writes = self.driver.transaction_writes
self.assertIn(key, transaction_writes)
self.assertEqual(transaction_writes[key], value)

def test_clear_transaction_writes(self):
key = 'test_key'
value = 'test_value'
self.driver.set(key, value)
# self.driver.commit()
self.driver.clear_transaction_writes()
transaction_writes = self.driver.transaction_writes
self.assertNotIn(key, transaction_writes)

def test_get_run_state(self):
# We can't test this function here since we are not running a real blockchain.
Expand Down

0 comments on commit e8e78ac

Please sign in to comment.