diff --git a/tests/functional/codegen/features/decorators/test_nonreentrant.py b/tests/functional/codegen/features/decorators/test_nonreentrant.py index 75a582c538..92a21cd302 100644 --- a/tests/functional/codegen/features/decorators/test_nonreentrant.py +++ b/tests/functional/codegen/features/decorators/test_nonreentrant.py @@ -2,30 +2,103 @@ from vyper.exceptions import FunctionDeclarationException - # TODO test functions in this module across all evm versions # once we have cancun support. + + def test_nonreentrant_decorator(get_contract, tx_failed): - calling_contract_code = """ -interface SpecialContract: + malicious_code = """ +interface ProtectedContract: + def protected_function(callback_address: address): nonpayable + +@external +def do_callback(): + ProtectedContract(msg.sender).protected_function(self) + """ + + protected_code = """ +interface Callbackable: + def do_callback(): nonpayable + +@external +@nonreentrant +def protected_function(c: Callbackable): + c.do_callback() + +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) + + with tx_failed(): + contract.protected_function(malicious.address) + + +def test_nonreentrant_view_function(get_contract, tx_failed): + malicious_code = """ +interface ProtectedContract: + def protected_function(): nonpayable + def protected_view_fn() -> uint256: view + +@external +def do_callback() -> uint256: + return ProtectedContract(msg.sender).protected_view_fn() + """ + + protected_code = """ +interface Callbackable: + def do_callback(): nonpayable + +@external +@nonreentrant +def protected_function(c: Callbackable): + c.do_callback() + +@external +@nonreentrant +@view +def protected_view_fn() -> uint256: + return 10 + +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) + + with tx_failed(): + contract.protected_function(malicious.address) + + +def test_multi_function_nonreentrant(get_contract, tx_failed): + malicious_code = """ +interface ProtectedContract: def unprotected_function(val: String[100], do_callback: bool): nonpayable def protected_function(val: String[100], do_callback: bool): nonpayable def special_value() -> String[100]: nonpayable @external def updated(): - SpecialContract(msg.sender).unprotected_function('surprise!', False) + ProtectedContract(msg.sender).unprotected_function('surprise!', False) @external def updated_protected(): # This should fail. - SpecialContract(msg.sender).protected_function('surprise protected!', False) + ProtectedContract(msg.sender).protected_function('surprise protected!', False) """ - reentrant_code = """ + protected_code = """ interface Callback: def updated(): nonpayable def updated_protected(): nonpayable + interface Self: def protected_function(val: String[100], do_callback: bool) -> uint256: nonpayable def protected_function2(val: String[100], do_callback: bool) -> uint256: nonpayable @@ -82,37 +155,42 @@ def unprotected_function(val: String[100], do_callback: bool): if do_callback: self.callback.updated() + +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) - reentrant_contract = get_contract(reentrant_code) - calling_contract = get_contract(calling_contract_code) - - reentrant_contract.set_callback(calling_contract.address, transact={}) - assert reentrant_contract.callback() == calling_contract.address + contract.set_callback(malicious.address, transact={}) + assert contract.callback() == malicious.address # Test unprotected function. - reentrant_contract.unprotected_function("some value", True, transact={}) - assert reentrant_contract.special_value() == "surprise!" + contract.unprotected_function("some value", True, transact={}) + assert contract.special_value() == "surprise!" # Test protected function. - reentrant_contract.protected_function("some value", False, transact={}) - assert reentrant_contract.special_value() == "some value" - assert reentrant_contract.protected_view_fn() == "some value" + contract.protected_function("some value", False, transact={}) + assert contract.special_value() == "some value" + assert contract.protected_view_fn() == "some value" with tx_failed(): - reentrant_contract.protected_function("zzz value", True, transact={}) + contract.protected_function("zzz value", True, transact={}) - reentrant_contract.protected_function2("another value", False, transact={}) - assert reentrant_contract.special_value() == "another value" + contract.protected_function2("another value", False, transact={}) + assert contract.special_value() == "another value" with tx_failed(): - reentrant_contract.protected_function2("zzz value", True, transact={}) + contract.protected_function2("zzz value", True, transact={}) - reentrant_contract.protected_function3("another value", False, transact={}) - assert reentrant_contract.special_value() == "another value" + contract.protected_function3("another value", False, transact={}) + assert contract.special_value() == "another value" with tx_failed(): - reentrant_contract.protected_function3("zzz value", True, transact={}) + contract.protected_function3("zzz value", True, transact={}) def test_nonreentrant_decorator_for_default(w3, get_contract, tx_failed):