Skip to content

Commit

Permalink
add more reentrancy tests
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Feb 12, 2024
1 parent 59804e0 commit 16c3cb4
Showing 1 changed file with 101 additions and 23 deletions.
124 changes: 101 additions & 23 deletions tests/functional/codegen/features/decorators/test_nonreentrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,103 @@

from vyper.exceptions import FunctionDeclarationException


# TODO test functions in this module across all evm versions
# once we have cancun support.


def test_nonreentrant_decorator(get_contract, tx_failed):
calling_contract_code = """
interface SpecialContract:
malicious_code = """
interface ProtectedContract:
def protected_function(callback_address: address): nonpayable
@external
def do_callback():
ProtectedContract(msg.sender).protected_function(self)
"""

protected_code = """
interface Callbackable:
def do_callback(): nonpayable
@external
@nonreentrant
def protected_function(c: Callbackable):
c.do_callback()
# add a default function so we know the callback didn't fail for any reason
# besides nonreentrancy
@external
def __default__():
pass
"""
contract = get_contract(protected_code)
malicious = get_contract(malicious_code)

with tx_failed():
contract.protected_function(malicious.address)


def test_nonreentrant_view_function(get_contract, tx_failed):
malicious_code = """
interface ProtectedContract:
def protected_function(): nonpayable
def protected_view_fn() -> uint256: view
@external
def do_callback() -> uint256:
return ProtectedContract(msg.sender).protected_view_fn()
"""

protected_code = """
interface Callbackable:
def do_callback(): nonpayable
@external
@nonreentrant
def protected_function(c: Callbackable):
c.do_callback()
@external
@nonreentrant
@view
def protected_view_fn() -> uint256:
return 10
# add a default function so we know the callback didn't fail for any reason
# besides nonreentrancy
@external
def __default__():
pass
"""
contract = get_contract(protected_code)
malicious = get_contract(malicious_code)

with tx_failed():
contract.protected_function(malicious.address)


def test_multi_function_nonreentrant(get_contract, tx_failed):
malicious_code = """
interface ProtectedContract:
def unprotected_function(val: String[100], do_callback: bool): nonpayable
def protected_function(val: String[100], do_callback: bool): nonpayable
def special_value() -> String[100]: nonpayable
@external
def updated():
SpecialContract(msg.sender).unprotected_function('surprise!', False)
ProtectedContract(msg.sender).unprotected_function('surprise!', False)
@external
def updated_protected():
# This should fail.
SpecialContract(msg.sender).protected_function('surprise protected!', False)
ProtectedContract(msg.sender).protected_function('surprise protected!', False)
"""

reentrant_code = """
protected_code = """
interface Callback:
def updated(): nonpayable
def updated_protected(): nonpayable
interface Self:
def protected_function(val: String[100], do_callback: bool) -> uint256: nonpayable
def protected_function2(val: String[100], do_callback: bool) -> uint256: nonpayable
Expand Down Expand Up @@ -82,37 +155,42 @@ def unprotected_function(val: String[100], do_callback: bool):
if do_callback:
self.callback.updated()
# add a default function so we know the callback didn't fail for any reason
# besides nonreentrancy
@external
def __default__():
pass
"""
contract = get_contract(protected_code)
malicious = get_contract(malicious_code)

reentrant_contract = get_contract(reentrant_code)
calling_contract = get_contract(calling_contract_code)

reentrant_contract.set_callback(calling_contract.address, transact={})
assert reentrant_contract.callback() == calling_contract.address
contract.set_callback(malicious.address, transact={})
assert contract.callback() == malicious.address

# Test unprotected function.
reentrant_contract.unprotected_function("some value", True, transact={})
assert reentrant_contract.special_value() == "surprise!"
contract.unprotected_function("some value", True, transact={})
assert contract.special_value() == "surprise!"

# Test protected function.
reentrant_contract.protected_function("some value", False, transact={})
assert reentrant_contract.special_value() == "some value"
assert reentrant_contract.protected_view_fn() == "some value"
contract.protected_function("some value", False, transact={})
assert contract.special_value() == "some value"
assert contract.protected_view_fn() == "some value"

with tx_failed():
reentrant_contract.protected_function("zzz value", True, transact={})
contract.protected_function("zzz value", True, transact={})

reentrant_contract.protected_function2("another value", False, transact={})
assert reentrant_contract.special_value() == "another value"
contract.protected_function2("another value", False, transact={})
assert contract.special_value() == "another value"

with tx_failed():
reentrant_contract.protected_function2("zzz value", True, transact={})
contract.protected_function2("zzz value", True, transact={})

reentrant_contract.protected_function3("another value", False, transact={})
assert reentrant_contract.special_value() == "another value"
contract.protected_function3("another value", False, transact={})
assert contract.special_value() == "another value"

with tx_failed():
reentrant_contract.protected_function3("zzz value", True, transact={})
contract.protected_function3("zzz value", True, transact={})


def test_nonreentrant_decorator_for_default(w3, get_contract, tx_failed):
Expand Down

0 comments on commit 16c3cb4

Please sign in to comment.