Skip to content

Commit

Permalink
refactor(contract): simplify join logic and payment handling (#2183)
Browse files Browse the repository at this point in the history
Streamlined `joinSpace` and `joinSpaceWithReferral` methods by
refactoring and centralizing payment checks. Improved logic for
validating and rejecting memberships, added support for overpayment
refunds, and adjusted protocol fee calculations. Updated tests to
reflect these changes and removed redundant methods for clarity and
efficiency.
  • Loading branch information
shuhuiluo authored Feb 4, 2025
1 parent 0eaaa08 commit ce6dffb
Show file tree
Hide file tree
Showing 20 changed files with 190 additions and 209 deletions.
16 changes: 5 additions & 11 deletions contracts/src/spaces/facets/dispatcher/DispatcherBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,9 @@ abstract contract DispatcherBase is IDispatcherBase {
return ds.transactionData[transactionId];
}

function _captureValue(bytes32 transactionId, uint256 value) internal {
if (value == 0) CustomRevert.revertWith(Dispatcher__InvalidValue.selector);
if (msg.value != value)
CustomRevert.revertWith(Dispatcher__InvalidValue.selector);

function _captureValue(bytes32 transactionId) internal {
DispatcherStorage.Layout storage ds = DispatcherStorage.layout();
ds.transactionBalance[transactionId] += value;
ds.transactionBalance[transactionId] += msg.value;
}

function _releaseCapturedValue(
Expand Down Expand Up @@ -75,10 +71,10 @@ abstract contract DispatcherBase is IDispatcherBase {
function _registerTransaction(
address sender,
bytes memory data
) internal returns (bytes32) {
) internal returns (bytes32 transactionId) {
bytes32 keyHash = keccak256(abi.encodePacked(sender, block.number));

bytes32 transactionId = _makeDispatchId(
transactionId = _makeDispatchId(
keyHash,
_makeDispatchInputSeed(keyHash, sender, _useDispatchNonce(keyHash))
);
Expand All @@ -90,9 +86,7 @@ abstract contract DispatcherBase is IDispatcherBase {

_captureData(transactionId, data);
if (msg.value != 0) {
_captureValue(transactionId, msg.value);
_captureValue(transactionId);
}

return transactionId;
}
}
2 changes: 0 additions & 2 deletions contracts/src/spaces/facets/dispatcher/IDispatcher.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,5 @@ pragma solidity ^0.8.23;
// contracts

interface IDispatcherBase {
error Dispatcher__InvalidValue();
error Dispatcher__InvalidCaller();
error Dispatcher__TransactionAlreadyExists();
}
13 changes: 4 additions & 9 deletions contracts/src/spaces/facets/membership/MembershipBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,13 @@ abstract contract MembershipBase is IMembershipBase {
address payer,
uint256 membershipPrice
) internal returns (uint256 protocolFee) {
IPlatformRequirements platform = _getPlatformRequirements();

address platformRecipient = platform.getFeeRecipient();
protocolFee = _getProtocolFee(membershipPrice);

// transfer the platform fee to the platform fee recipient
CurrencyTransfer.transferCurrency(
_getMembershipCurrency(),
payer, // from
platformRecipient, // to
_getPlatformRequirements().getFeeRecipient(), // to
protocolFee
);
}
Expand Down Expand Up @@ -155,11 +152,11 @@ abstract contract MembershipBase is IMembershipBase {
/// @dev Makes it virtual to allow other pricing strategies
function _getMembershipPrice(
uint256 totalSupply
) internal view virtual returns (uint256) {
) internal view virtual returns (uint256 membershipPrice) {
// get free allocation
uint256 freeAllocation = _getMembershipFreeAllocation();

uint256 membershipPrice = IMembershipPricing(_getPricingModule()).getPrice(
membershipPrice = IMembershipPricing(_getPricingModule()).getPrice(
freeAllocation,
totalSupply
);
Expand All @@ -169,9 +166,7 @@ abstract contract MembershipBase is IMembershipBase {
uint256 minPrice = platform.getMembershipMinPrice();
uint256 fixedFee = platform.getMembershipFee();

if (membershipPrice < minPrice) return fixedFee;

return membershipPrice;
if (membershipPrice < minPrice) membershipPrice = fixedFee;
}

function _setMembershipRenewalPrice(
Expand Down
3 changes: 1 addition & 2 deletions contracts/src/spaces/facets/membership/MembershipFacet.sol
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ contract MembershipFacet is

/// @inheritdoc IMembership
function joinSpace(address receiver) external payable nonReentrant {
ReferralTypes memory emptyReferral;
_joinSpaceWithReferral(receiver, emptyReferral);
_joinSpace(receiver);
}

/// @inheritdoc IMembership
Expand Down
143 changes: 82 additions & 61 deletions contracts/src/spaces/facets/membership/join/MembershipJoin.sol
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,39 @@ abstract contract MembershipJoin is
return abi.encode(selector, sender, receiver, referralData);
}

/// @notice Handles the process of joining a space
/// @param receiver The address that will receive the membership token
function _joinSpace(address receiver) internal {
_validateJoinSpace(receiver);

bool shouldCharge = _shouldChargeForJoinSpace();
if (shouldCharge) _validatePayment();

bytes4 selector = IMembership.joinSpace.selector;

bytes32 transactionId = _registerTransaction(
receiver,
_encodeJoinSpaceData(selector, msg.sender, receiver, "")
);

(bool isEntitled, bool isCrosschainPending) = _checkEntitlement(
receiver,
msg.sender,
transactionId
);

if (!isCrosschainPending) {
if (isEntitled) {
if (shouldCharge) _chargeForJoinSpace(transactionId);

_refundBalance(transactionId, receiver);
_issueToken(receiver);
} else {
_rejectMembership(transactionId, receiver);
}
}
}

/// @notice Handles the process of joining a space with a referral
/// @param receiver The address that will receive the membership token
/// @param referral The referral information
Expand All @@ -67,59 +100,50 @@ abstract contract MembershipJoin is
ReferralTypes memory referral
) internal {
_validateJoinSpace(receiver);
_validatePayment();

bool shouldCharge = _shouldChargeForJoinSpace();
if (shouldCharge) _validatePayment();

_validateUserReferral(receiver, referral);
bool isNotReferral = _isNotReferral(referral);
address sender = msg.sender;

bytes memory referralData = isNotReferral
? bytes("")
: abi.encode(referral);
bytes memory referralData = abi.encode(referral);

bytes4 selector = isNotReferral
? IMembership.joinSpace.selector
: IMembership.joinSpaceWithReferral.selector;
bytes4 selector = IMembership.joinSpaceWithReferral.selector;

bytes32 transactionId = _registerTransaction(
receiver,
_encodeJoinSpaceData(selector, sender, receiver, referralData)
_encodeJoinSpaceData(selector, msg.sender, receiver, referralData)
);

(bool isEntitled, bool isCrosschainPending) = _checkEntitlement(
receiver,
sender,
msg.sender,
transactionId
);

if (!isCrosschainPending) {
if (isEntitled) {
bool shouldCharge = _shouldChargeForJoinSpace();
if (shouldCharge) {
if (isNotReferral) {
_chargeForJoinSpace(transactionId);
} else {
_chargeForJoinSpaceWithReferral(transactionId);
}
} else {
_refundBalance(transactionId, sender);
}
if (shouldCharge) _chargeForJoinSpaceWithReferral(transactionId);

_refundBalance(transactionId, receiver);
_issueToken(receiver);
} else {
_captureData(transactionId, "");
_refundBalance(transactionId, sender);
emit MembershipTokenRejected(receiver);
_rejectMembership(transactionId, receiver);
}
}
}

function _getRequiredAmount() internal view returns (uint256) {
function _rejectMembership(bytes32 transactionId, address receiver) internal {
_captureData(transactionId, "");
_refundBalance(transactionId, receiver);
emit MembershipTokenRejected(receiver);
}

function _getRequiredAmount(uint256 price) internal view returns (uint256) {
// Check if there are any prepaid memberships available
uint256 prepaidSupply = _getPrepaidSupply();
if (prepaidSupply > 0) return 0; // If prepaid memberships exist, no payment is required

// Get the current membership price based on total supply
uint256 price = _getMembershipPrice(_totalSupply());
if (price == 0) return 0; // If the price is zero, no payment is required

// Calculate the protocol fee
Expand All @@ -130,11 +154,11 @@ abstract contract MembershipJoin is
}

function _validatePayment() internal view {
if (msg.value > 0) {
uint256 requiredAmount = _getRequiredAmount();
if (msg.value != requiredAmount)
CustomRevert.revertWith(Membership__InvalidPayment.selector);
}
// Get the current membership price based on total supply
uint256 membershipPrice = _getMembershipPrice(_totalSupply());
uint256 requiredAmount = _getRequiredAmount(membershipPrice);
if (msg.value < requiredAmount)
CustomRevert.revertWith(Membership__InsufficientPayment.selector);
}

function _validateUserReferral(
Expand All @@ -148,15 +172,6 @@ abstract contract MembershipJoin is
}
}

function _isNotReferral(
ReferralTypes memory referral
) internal pure returns (bool) {
return
referral.partner == address(0) &&
referral.userReferral == address(0) &&
bytes(referral.referralCode).length == 0;
}

/// @notice Checks if a user is entitled to join the space and handles the entitlement process
/// @dev This function checks both local and crosschain entitlements
/// @param receiver The address of the user trying to join the space
Expand Down Expand Up @@ -225,9 +240,8 @@ abstract contract MembershipJoin is
/// @notice Processes the charge for joining a space without referral
/// @param transactionId The unique identifier for this join transaction
function _chargeForJoinSpace(bytes32 transactionId) internal {
uint256 payment = _getCapturedValue(transactionId);
if (payment == 0)
CustomRevert.revertWith(Membership__InsufficientPayment.selector);
uint256 membershipPrice = _getMembershipPrice(_totalSupply());
uint256 paymentRequired = _getRequiredAmount(membershipPrice);

(bytes4 selector, address sender, address receiver, ) = abi.decode(
_getCapturedData(transactionId),
Expand All @@ -238,25 +252,24 @@ abstract contract MembershipJoin is
CustomRevert.revertWith(Membership__InvalidTransactionType.selector);
}

uint256 protocolFee = _collectProtocolFee(sender, payment);
uint256 remainingDue = payment - protocolFee;
uint256 protocolFee = _collectProtocolFee(sender, membershipPrice);
uint256 ownerProceeds = paymentRequired - protocolFee;

_afterChargeForJoinSpace(
transactionId,
sender,
receiver,
payment,
remainingDue,
paymentRequired,
ownerProceeds,
protocolFee
);
}

/// @notice Processes the charge for joining a space with referral
/// @param transactionId The unique identifier for this join transaction
function _chargeForJoinSpaceWithReferral(bytes32 transactionId) internal {
uint256 payment = _getCapturedValue(transactionId);
if (payment == 0)
CustomRevert.revertWith(Membership__InsufficientPayment.selector);
uint256 membershipPrice = _getMembershipPrice(_totalSupply());
uint256 paymentRequired = _getRequiredAmount(membershipPrice);

(
bytes4 selector,
Expand All @@ -274,25 +287,32 @@ abstract contract MembershipJoin is

ReferralTypes memory referral = abi.decode(referralData, (ReferralTypes));

uint256 protocolFee = _collectProtocolFee(sender, payment);
uint256 protocolFee = _collectProtocolFee(sender, membershipPrice);

uint256 partnerFee = _collectPartnerFee(sender, referral.partner, payment);
uint256 partnerFee = _collectPartnerFee(
sender,
referral.partner,
membershipPrice
);

uint256 referralFee = _collectReferralCodeFee(
sender,
referral.userReferral,
referral.referralCode,
payment
membershipPrice
);

uint256 remainingDue = payment - protocolFee - partnerFee - referralFee;
uint256 ownerProceeds = paymentRequired -
protocolFee -
partnerFee -
referralFee;

_afterChargeForJoinSpace(
transactionId,
sender,
receiver,
payment,
remainingDue,
paymentRequired,
ownerProceeds,
protocolFee
);
}
Expand All @@ -301,13 +321,14 @@ abstract contract MembershipJoin is
bytes32 transactionId,
address payer,
address receiver,
uint256 payment,
uint256 remainingDue,
uint256 paymentRequired,
uint256 ownerProceeds,
uint256 protocolFee
) internal {
if (remainingDue != 0) _transferIn(payer, remainingDue);
// account for owner's proceeds
if (ownerProceeds != 0) _transferIn(payer, ownerProceeds);

_releaseCapturedValue(transactionId, payment);
_releaseCapturedValue(transactionId, paymentRequired);
_captureData(transactionId, "");

// calculate points and credit them
Expand Down
Loading

0 comments on commit ce6dffb

Please sign in to comment.