Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(contracts): make MultiPayment contract upgradable and emit payment events #858

Merged
merged 13 commits into from
Feb 13, 2025
29 changes: 0 additions & 29 deletions contracts/src/multi-payment/MultiPayment.sol

This file was deleted.

59 changes: 59 additions & 0 deletions contracts/src/multi-payment/MultiPaymentV1.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// SPDX-License-Identifier: GNU GENERAL PUBLIC LICENSE
pragma solidity ^0.8.27;

import {UUPSUpgradeable} from "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol";
import {OwnableUpgradeable} from "@openzeppelin/contracts-upgradeable/access/OwnableUpgradeable.sol";

contract MultiPaymentV1 is UUPSUpgradeable, OwnableUpgradeable {
error RecipientsAndAmountsMismatch();
error InvalidValue();

event Payment(address indexed recipient, uint256 amount, bool success);

// Initializers
function initialize() public initializer {
__Ownable_init(msg.sender);
}

// Overrides
function _authorizeUpgrade(address newImplementation) internal override onlyOwner {}

function version() external pure returns (uint256) {
return 1;
}

function pay(address payable[] calldata recipients, uint256[] calldata amounts) external payable {
if (recipients.length != amounts.length) {
revert RecipientsAndAmountsMismatch();
}

// Ensure value sent is equal to the total amount to send
uint256 total = 0;
for (uint256 i = 0; i < amounts.length; i++) {
total += amounts[i];
}
if (msg.value != total) {
revert InvalidValue();
}

uint256 numRecipients = recipients.length;
oXtxNt9U marked this conversation as resolved.
Show resolved Hide resolved
if (numRecipients == 0) {
return;
}

for (uint256 i = 0; i < recipients.length; i++) {
(bool sent,) = recipients[i].call{value: amounts[i], gas: 5000}("");
if (sent) {
total -= amounts[i];
}

emit Payment(recipients[i], amounts[i], sent);
}

// Refund any remaining value due to partial payments
if (total > 0) {
(bool success,) = msg.sender.call{value: total}("");
require(success, "Refund failed");
}
}
}
46 changes: 46 additions & 0 deletions contracts/test/multi-payment/MultiPayment-Proxy.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// SPDX-License-Identifier: GNU GENERAL PUBLIC LICENSE
pragma solidity ^0.8.13;

import {Test, console} from "@forge-std/Test.sol";
import {MultiPaymentV1} from "@contracts/multi-payment/MultiPaymentV1.sol";
import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import {Initializable} from "@openzeppelin/contracts/proxy/utils/Initializable.sol";

contract MultiPaymentV2Test is MultiPaymentV1 {
function versionv2() external pure returns (uint256) {
return 2;
}
}

contract ProxyTest is Test {
MultiPaymentV1 public multiPayment;

function setUp() public {
bytes memory data = abi.encode(MultiPaymentV1.initialize.selector);
address proxy = address(new ERC1967Proxy(address(new MultiPaymentV1()), data));
multiPayment = MultiPaymentV1(proxy);
}

function test_initialize_should_revert() public {
vm.expectRevert(Initializable.InvalidInitialization.selector);
multiPayment.initialize();
}

function test_should_have_valid_UPGRADE_INTERFACE_VERSION() public view {
assertEq(multiPayment.UPGRADE_INTERFACE_VERSION(), "5.0.0");
}

function test_proxy_should_update() public {
assertEq(multiPayment.version(), 1);
assertEq(multiPayment.UPGRADE_INTERFACE_VERSION(), "5.0.0");
multiPayment.upgradeToAndCall(address(new MultiPaymentV2Test()), bytes(""));

// Cast proxy to new contract
MultiPaymentV2Test multiPaymentNew = MultiPaymentV2Test(address(multiPayment));
assertEq(multiPaymentNew.versionv2(), 2);

// Should keep old data
vm.expectRevert(Initializable.InvalidInitialization.selector);
multiPaymentNew.initialize();
}
}
194 changes: 165 additions & 29 deletions contracts/test/multi-payment/MultiPayment.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,27 @@
pragma solidity ^0.8.13;

import {Test, console} from "@forge-std/Test.sol";
import {MultiPayment} from "@contracts/multi-payment/MultiPayment.sol";
import {MultiPaymentV1} from "@contracts/multi-payment/MultiPaymentV1.sol";
import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import {OwnableUpgradeable} from "@openzeppelin/contracts-upgradeable/access/OwnableUpgradeable.sol";

contract RejectPayments {
fallback() external payable {
revert("Recipient always reverts");
}

receive() external payable {
revert("Direct payments are not accepted");
revert("Recipient always reverts");
}
}

contract MultiPaymentTest is Test {
MultiPayment public multiPayment;
MultiPaymentV1 public multiPayment;

function setUp() public {
multiPayment = new MultiPayment();
bytes memory data = abi.encode(MultiPaymentV1.initialize.selector);
address proxy = address(new ERC1967Proxy(address(new MultiPaymentV1()), data));
multiPayment = MultiPaymentV1(proxy);
}

function test_pay_pass_with_zero_payment() public {
Expand Down Expand Up @@ -108,6 +116,130 @@ contract MultiPaymentTest is Test {
assertEq(sender.balance, 40 ether);
}

function test_pay_pass_with_partial_success() public {
address payable sender = payable(address(9999));
vm.deal(sender, 100 ether);
assertEq(sender.balance, 100 ether);

vm.startPrank(sender);

address payable recipient1 = payable(address(1));

RejectPayments rejectPayments = new RejectPayments();
address payable recipient2 = payable(address(rejectPayments));
assertEq(recipient2.balance, 0);

address payable recipient3 = payable(address(3));

address payable[] memory recipients = new address payable[](3);
recipients[0] = recipient1;
recipients[1] = recipient2;
recipients[2] = recipient3;

uint256[] memory amounts = new uint256[](3);
amounts[0] = 10 ether;
amounts[1] = 20 ether;
amounts[2] = 30 ether;

// Act
multiPayment.pay{value: 60 ether}(recipients, amounts);

// Assert
assertEq(recipient1.balance, 10 ether);
assertEq(recipient2.balance, 0 ether); // failed
assertEq(recipient3.balance, 30 ether);
assertEq(sender.balance, 60 ether); // refunded 20 ether

vm.stopPrank();
}

function test_pay_emitted_events() public {
address payable sender = payable(address(this));
vm.deal(sender, 100 ether);
assertEq(sender.balance, 100 ether);

address payable recipient1 = payable(address(1));
address payable recipient2 = payable(address(2));
address payable recipient3 = payable(address(3));

address payable[] memory recipients = new address payable[](3);
recipients[0] = recipient1;
recipients[1] = recipient2;
recipients[2] = recipient3;

uint256[] memory amounts = new uint256[](3);
amounts[0] = 10 ether;
amounts[1] = 20 ether;
amounts[2] = 30 ether;

// Events
vm.expectEmit();
emit MultiPaymentV1.Payment(recipient1, 10 ether, true);

vm.expectEmit();
emit MultiPaymentV1.Payment(recipient2, 20 ether, true);

vm.expectEmit();
emit MultiPaymentV1.Payment(recipient3, 30 ether, true);

// Act
multiPayment.pay{value: 60 ether}(recipients, amounts);

// Assert
assertEq(recipient1.balance, 10 ether);
assertEq(recipient2.balance, 20 ether);
assertEq(recipient3.balance, 30 ether);
assertEq(sender.balance, 40 ether);
}

function test_pay_emitted_events_with_reverts() public {
address payable sender = payable(address(9999));
vm.deal(sender, 100 ether);
assertEq(sender.balance, 100 ether);

vm.startPrank(sender);

address payable recipient1 = payable(address(1));

RejectPayments rejectPayments = new RejectPayments();
address payable recipient2 = payable(address(rejectPayments));
assertEq(recipient2.balance, 0);

address payable recipient3 = payable(address(3));

address payable[] memory recipients = new address payable[](3);
recipients[0] = recipient1;
recipients[1] = recipient2;
recipients[2] = recipient3;

uint256[] memory amounts = new uint256[](3);
amounts[0] = 10 ether;
amounts[1] = 20 ether;
amounts[2] = 30 ether;

// Force recipient2 to reject payment

vm.expectEmit();
emit MultiPaymentV1.Payment(recipient1, 10 ether, true);

vm.expectEmit();
emit MultiPaymentV1.Payment(recipient2, 20 ether, false);

vm.expectEmit();
emit MultiPaymentV1.Payment(recipient3, 30 ether, true);

// Act
multiPayment.pay{value: 60 ether}(recipients, amounts);

// Assert
assertEq(recipient1.balance, 10 ether);
assertEq(recipient2.balance, 0 ether);
assertEq(recipient3.balance, 30 ether);
assertEq(sender.balance, 60 ether);

vm.stopPrank();
oXtxNt9U marked this conversation as resolved.
Show resolved Hide resolved
}

function test_pay_pass_with_multiple_payments_same_address() public {
address payable sender = payable(address(this));
vm.deal(sender, 100 ether);
Expand Down Expand Up @@ -135,13 +267,15 @@ contract MultiPaymentTest is Test {
}

function test_pay_pass_with_multiple_payments_large() public {
uint256 payments = 10000;
uint256 payments = 100;
// 21 000 000
oXtxNt9U marked this conversation as resolved.
Show resolved Hide resolved
address payable[] memory recipients = new address payable[](payments);
uint256[] memory amounts = new uint256[](payments);

uint256 total = 0;
for (uint256 i = 0; i < payments; i++) {
recipients[i] = payable(address(uint160(i + 10))); // For some reason address(9) reverts // TODO: Check why
// Low addresses are reserved by foundry (Cheat Code Addresses) and cause side effects when used
recipients[i] = payable(address(uint160(1000 + i)));
amounts[i] = 1;
total += 1;
}
Expand Down Expand Up @@ -175,7 +309,7 @@ contract MultiPaymentTest is Test {
amounts[1] = 60 ether;

// Act
vm.expectRevert(MultiPayment.RecipientsAndAmountsMismatch.selector);
vm.expectRevert(MultiPaymentV1.RecipientsAndAmountsMismatch.selector);
multiPayment.pay{value: 100 ether}(recipients, amounts);
}

Expand All @@ -194,15 +328,17 @@ contract MultiPaymentTest is Test {
amounts[0] = 40 ether;

// Act
vm.expectRevert(MultiPayment.InvalidValue.selector);
vm.expectRevert(MultiPaymentV1.InvalidValue.selector);
multiPayment.pay{value: 50 ether}(recipients, amounts);
}

function test_pay_fail_with_failed_to_send_ether() public {
address payable sender = payable(address(this));
function test_pay_refund_when_failed_to_send_ether() public {
address payable sender = payable(address(999));
vm.deal(sender, 100 ether);
assertEq(sender.balance, 100 ether);

vm.startPrank(sender);

RejectPayments rejectPayments = new RejectPayments();
address payable recipient = payable(address(rejectPayments));
assertEq(recipient.balance, 0);
Expand All @@ -214,30 +350,30 @@ contract MultiPaymentTest is Test {
amounts[0] = 40 ether;

// Act
recipient = payable(address(0)); // Force recipient to be address(0)
vm.expectRevert(MultiPayment.FailedToSendEther.selector);
multiPayment.pay{value: 40 ether}(recipients, amounts);
}

// Test disabled, because of foundy updates. Check:
// https://book.getfoundry.sh/cheatcodes/expect-revert#description
assertEq(sender.balance, 100 ether);

// function test_pay_fail_if_no_enough_balance() public {
// address payable sender = payable(address(this));
// vm.deal(sender, 100 ether);
// assertEq(sender.balance, 100 ether);
vm.stopPrank();
}

// address payable recipient = payable(address(1));
// assertEq(recipient.balance, 0);
/// forge-config: default.allow_internal_expect_revert = true
function test_pay_fail_if_no_enough_balance() public {
address payable sender = payable(address(this));
vm.deal(sender, 100 ether);
assertEq(sender.balance, 100 ether);

address payable recipient = payable(address(1));
assertEq(recipient.balance, 0);

// address payable[] memory recipients = new address payable[](1);
// recipients[0] = recipient;
address payable[] memory recipients = new address payable[](1);
recipients[0] = recipient;

// uint256[] memory amounts = new uint256[](1);
// amounts[0] = 10 ether;
uint256[] memory amounts = new uint256[](1);
amounts[0] = 10 ether;

// // Act
// vm.expectRevert();
// multiPayment.pay{value: 110 ether}(recipients, amounts);
// }
// Act
vm.expectRevert();
multiPayment.pay{value: 110 ether}(recipients, amounts);
}
}
Loading
Loading