Skip to content

Commit

Permalink
Implementation of ERC1155 checks in go, including xchain IT tests. (#958
Browse files Browse the repository at this point in the history
)
  • Loading branch information
clemire authored Sep 4, 2024
1 parent a6f64dd commit 9133f8d
Show file tree
Hide file tree
Showing 7 changed files with 2,655 additions and 58 deletions.
976 changes: 974 additions & 2 deletions core/contracts/base/channels.go

Large diffs are not rendered by default.

1,235 changes: 1,235 additions & 0 deletions core/contracts/base/deploy/mock_erc1155.go

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions core/contracts/types/test_util/entitlements.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,38 @@ func Erc721Check(chainId uint64, contractAddress common.Address, threshold uint6
}
}

func Erc1155Check(
chainId uint64,
contractAddress common.Address,
threshold uint64,
tokenId uint64,
) base.IRuleEntitlementBaseRuleDataV2 {
params := contract_types.ERC1155Params{
Threshold: new(big.Int).SetUint64(threshold),
TokenId: new(big.Int).SetUint64(tokenId),
}
encodedParams, err := params.AbiEncode()
if err != nil {
panic(err)
}
return base.IRuleEntitlementBaseRuleDataV2{
Operations: []base.IRuleEntitlementBaseOperation{
{
OpType: uint8(contract_types.CHECK),
Index: 0,
},
},
CheckOperations: []base.IRuleEntitlementBaseCheckOperationV2{
{
OpType: uint8(contract_types.ERC1155),
ChainId: new(big.Int).SetUint64(chainId),
ContractAddress: contractAddress,
Params: encodedParams,
},
},
}
}

func Erc20Check(chainId uint64, contractAddress common.Address, threshold uint64) base.IRuleEntitlementBaseRuleData {
return base.IRuleEntitlementBaseRuleData{
Operations: []base.IRuleEntitlementBaseOperation{
Expand Down
179 changes: 138 additions & 41 deletions core/xchain/entitlement/check_operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,76 +13,131 @@ import (
"github.com/river-build/river/core/contracts/base"
"github.com/river-build/river/core/contracts/types"
"github.com/river-build/river/core/node/dlog"
"github.com/river-build/river/core/xchain/bindings/erc1155"
"github.com/river-build/river/core/xchain/bindings/erc20"
"github.com/river-build/river/core/xchain/bindings/erc721"
)

func (e *Evaluator) evaluateCheckOperation(
ctx context.Context,
op *types.CheckOperation,
linkedWallets []common.Address,
) (bool, error) {
defer prometheus.NewTimer(e.evalHistrogram.WithLabelValues(op.CheckType.String())).ObserveDuration()
func checkThresholdParam(threshold *big.Int) error {
if threshold == nil {
return fmt.Errorf("threshold is nil")
}
if threshold.Sign() <= 0 {
return fmt.Errorf(
"threshold %s is nonpositive",
threshold,
)
}
return nil
}

if op.CheckType == types.MOCK {
return e.evaluateMockOperation(ctx, op)
} else if op.CheckType == types.CheckNONE {
return false, fmt.Errorf("unknown operation")
func checkTokenIdParam(tokenId *big.Int) error {
if tokenId == nil {
return fmt.Errorf("token ID is nil")
}
if tokenId.Sign() < 0 {
return fmt.Errorf("token ID %s is negative", tokenId)
}
return nil
}

// Sanity checks
log := dlog.FromCtx(ctx).With("function", "evaluateCheckOperation")
func validateCheckOperation(ctx context.Context, op *types.CheckOperation) error {
// Validation for each of the following fields is applied to relevant check types.
// 1. Chain ID is not nil
// 2. Contract address is not nil
// 3. Threshold is positive
// 4. Token ID is non-negative
log := dlog.FromCtx(ctx).With("function", "validateCheckOperation")
if op.ChainID == nil {
log.Error("Entitlement check: chain ID is nil for operation", "operation", op.CheckType.String())
return false, fmt.Errorf("evaluateCheckOperation: Chain ID is nil for operation %s", op.CheckType)
return fmt.Errorf("validateCheckOperation: chain ID is nil for operation %s", op.CheckType)
}

zeroAddress := common.Address{}
if op.CheckType != types.NATIVE_COIN_BALANCE && op.ContractAddress == zeroAddress {
log.Error("Entitlement check: contract address is nil for operation", "operation", op.CheckType.String())
return false, fmt.Errorf(
"evaluateCheckOperation: Contract address is nil for operation %s",
return fmt.Errorf(
"validateCheckOperation: contract address is nil for operation %s",
op.CheckType,
)
}

if op.CheckType == types.ERC20 || op.CheckType == types.ERC721 || op.CheckType == types.ERC1155 ||
op.CheckType == types.NATIVE_COIN_BALANCE {
if op.CheckType == types.ERC20 || op.CheckType == types.ERC721 || op.CheckType == types.NATIVE_COIN_BALANCE {
params, err := types.DecodeThresholdParams(op.Params)
if err != nil {
log.Error(
"evaluateCheckOperation: failed to decode threshold params",
"validateCheckOperation: failed to decode threshold params",
"error",
err,
"params",
op.Params,
"operation",
op.CheckType.String(),
)
return false, err
return fmt.Errorf("validateCheckOperation: failed to decode threshold params, %w", err)
}
if params.Threshold == nil {
log.Error("Entitlement check: threshold is nil for operation", "operation", op.CheckType.String())
return false, fmt.Errorf(
"evaluateCheckOperation: Threshold is nil for operation %s",
op.CheckType,
if err := checkThresholdParam(params.Threshold); err != nil {
// Wrap the error with the operation type
err = fmt.Errorf("validateCheckOperation: %w for operation %s", err, op.CheckType)
log.Error(
"Entitlement check: invalid threshold for operation",
"operation",
op.CheckType.String(),
"error",
err,
)
return err
}
if params.Threshold.Sign() <= 0 {
} else if op.CheckType == types.ERC1155 {
params, err := types.DecodeERC1155Params(op.Params)
if err != nil {
log.Error("validateCheckOperation: failed to decode ERC1155 params", "error", err)
return fmt.Errorf("validateCheckOperation: failed to decode ERC1155 params, %w", err)
}
if err := checkTokenIdParam(params.TokenId); err != nil {
// Wrap the error with the operation type
err = fmt.Errorf("validateCheckOperation: %w for operation %s", err, op.CheckType)
log.Error(
"Entitlement check: threshold is nonpositive for operation",
"Entitlement check: invalid token ID for operation",
"operation",
op.CheckType.String(),
"threshold",
params.Threshold.String(),
"error",
err,
)
return false, fmt.Errorf(
"evaluateCheckOperation: Threshold %s is nonpositive for operation %s",
params.Threshold,
op.CheckType,
return err
}
if err := checkThresholdParam(params.Threshold); err != nil {
// Wrap the error with the operation type
err = fmt.Errorf("validateCheckOperation: %w for operation %s", err, op.CheckType)
log.Error(
"Entitlement check: invalid threshold for operation",
"operation",
op.CheckType.String(),
"error",
err,
)
return err
}
}
return nil
}

func (e *Evaluator) evaluateCheckOperation(
ctx context.Context,
op *types.CheckOperation,
linkedWallets []common.Address,
) (bool, error) {
defer prometheus.NewTimer(e.evalHistrogram.WithLabelValues(op.CheckType.String())).ObserveDuration()

if op.CheckType == types.MOCK {
return e.evaluateMockOperation(ctx, op)
} else if op.CheckType == types.CheckNONE {
return false, fmt.Errorf("unknown operation")
}

if err := validateCheckOperation(ctx, op); err != nil {
return false, err
}

switch op.CheckType {
case types.ISENTITLED:
Expand All @@ -92,7 +147,7 @@ func (e *Evaluator) evaluateCheckOperation(
case types.ERC721:
return e.evaluateErc721Operation(ctx, op, linkedWallets)
case types.ERC1155:
return e.evaluateErc1155Operation(ctx, op)
return e.evaluateErc1155Operation(ctx, op, linkedWallets)
case types.NATIVE_COIN_BALANCE:
return e.evaluateNativeCoinBalanceOperation(ctx, op, linkedWallets)
case types.CheckNONE:
Expand Down Expand Up @@ -327,12 +382,6 @@ func (e *Evaluator) evaluateErc721Operation(

// Accumulate the total balance across evaluated wallets
total.Add(total, tokenBalance)
// log.Info("Retrieved ERC721 token balance for wallet",
// "balance", tokenBalance.String(),
// "total", total.String(),
// "threshold", op.Threshold.String(),
// "wallet", wallet,
// )

// Iteratively check if the total balance of evaluated wallets is greater than or equal to the threshold
// Note threshold is always positive and total is non-negative.
Expand All @@ -343,8 +392,56 @@ func (e *Evaluator) evaluateErc721Operation(
return false, err
}

func (e *Evaluator) evaluateErc1155Operation(ctx context.Context,
func (e *Evaluator) evaluateErc1155Operation(
ctx context.Context,
op *types.CheckOperation,
linkedWallets []common.Address,
) (bool, error) {
return false, fmt.Errorf("ERC1155 not implemented")
log := dlog.FromCtx(ctx).With("function", "evaluateErc1155Operation")

client, err := e.clients.Get(op.ChainID.Uint64())
if err != nil {
log.Error("Chain ID not found", "chainID", op.ChainID)
return false, fmt.Errorf("evaluateErc1155Operation: Chain ID %v not found", op.ChainID)
}

collection, err := erc1155.NewErc1155Caller(op.ContractAddress, client)
if err != nil {
log.Error("Failed to instantiate an ERC1155 contract",
"err", err,
"contractAddress", op.ContractAddress,
)
return false, err
}

// Decode the ERC1155 params
params, err := types.DecodeERC1155Params(op.Params)
if err != nil {
log.Error("evaluateErc1155Operation: failed to decode erc1155 params", "error", err)
return false, fmt.Errorf("evaluateErc1155Operation: failed to decode erc1155 params, %w", err)
}

total := big.NewInt(0)
for _, wallet := range linkedWallets {
tokenBalance, err := collection.BalanceOf(&bind.CallOpts{Context: ctx}, wallet, params.TokenId)
if err != nil {
log.Error("Failed to retrieve ERC1155 token balance",
"error", err,
"contractAddress", op.ContractAddress,
"wallet", wallet,
"tokenId", params.TokenId.String(),
)
return false, err
}

// Accumulate the total balance across evaluated wallets
total.Add(total, tokenBalance)

// Iteratively check if the total balance of evaluated wallets is greater than or equal to the threshold
// Note threshold is always positive and total is non-negative.
if total.Cmp(params.Threshold) >= 0 {
return true, nil
}
}
return false, err
}
Loading

0 comments on commit 9133f8d

Please sign in to comment.