Skip to content

Commit

Permalink
remove IncreaseAllowance & add shared nonce (#1)
Browse files Browse the repository at this point in the history
* channel: Include shared Nonce object, incrementing the nonce of an account while the transaction is broadcast. Implement sync.Mutex, locking the nonce incrementals.
In erc20_depositor.go, the depositing process is locked during the "Approve" function call with ```lockKey```. The function ```(d *ERC20Depositor) Deposit``` is renamed to `````(d *ERC20Depositor) DepositOnly`

* client: To avoid timeouts caused by the added locking mechanism during ERC20 deposits, ```context.WithTimeout``` and ```twoPartyTestTimeout``` has been increased in fund_test.go and payment_test.go.

* client/test: Timeouts have been increased to avoid premature timeout errors due to the locking mechanism above.

---------

Signed-off-by: Ilja von Hoessle <[email protected]>
Signed-off-by: sophia1ch <[email protected]>
  • Loading branch information
sophia1ch committed Jan 9, 2024
1 parent 5bbe33a commit b985d8a
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 39 deletions.
1 change: 1 addition & 0 deletions channel/conclude.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ const (
// - it searches for a past concluded event by calling `isConcluded`
// - if found, channel is already concluded and success is returned
// - if none found, conclude/concludeFinal is called on the adjudicator
//
// - it waits for a Concluded event from the blockchain.
func (a *Adjudicator) ensureConcluded(ctx context.Context, req channel.AdjudicatorReq, subStates channel.StateMap) error {
// Check whether it is already concluded.
Expand Down
34 changes: 27 additions & 7 deletions channel/contractbackend.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ const (
// create a TxTimedoutError with additional context.
var errTxTimedOut = errors.New("")

// SharedExpected Nonce is a map of each expected next nonce of all clients.
var (
SharedExpectedNonces map[ChainID]map[common.Address]uint64
SharedNonceMtx map[ChainID]map[common.Address]*sync.Mutex
)

// SharedMutex controls the reads and writes on the nonceMtx and ecpectedNextNonce of the ContractBackend.
var SharedMutex = &sync.Mutex{}

// ContractInterface provides all functions needed by an ethereum backend.
// Both test.SimulatedBackend and ethclient.Client implement this interface.
type ContractInterface interface {
Expand All @@ -63,7 +72,7 @@ type Transactor interface {
type ContractBackend struct {
ContractInterface
tr Transactor
nonceMtx *sync.Mutex
nonceMtx map[common.Address]*sync.Mutex
expectedNextNonce map[common.Address]uint64
txFinalityDepth uint64
chainID ChainID
Expand All @@ -73,11 +82,24 @@ type ContractBackend struct {
// txFinalityDepth defines in how many consecutive blocks a TX has to be
// included to be considered final. Must be at least 1.
func NewContractBackend(cf ContractInterface, chainID ChainID, tr Transactor, txFinalityDepth uint64) ContractBackend {
// Check if the shared maps are initialized, if not, initialize them.
if SharedExpectedNonces == nil {
SharedExpectedNonces = make(map[ChainID]map[common.Address]uint64)
}
if SharedNonceMtx == nil {
SharedNonceMtx = make(map[ChainID]map[common.Address]*sync.Mutex)
}

// Check if the specific chainID entry exists in the shared maps, if not, create it.
if _, exists := SharedExpectedNonces[chainID]; !exists {
SharedExpectedNonces[chainID] = make(map[common.Address]uint64)
SharedNonceMtx[chainID] = make(map[common.Address]*sync.Mutex)
}
return ContractBackend{
ContractInterface: cf,
tr: tr,
expectedNextNonce: make(map[common.Address]uint64),
nonceMtx: &sync.Mutex{},
expectedNextNonce: SharedExpectedNonces[chainID],
nonceMtx: SharedNonceMtx[chainID],
txFinalityDepth: txFinalityDepth,
chainID: chainID,
}
Expand Down Expand Up @@ -165,10 +187,7 @@ func (c *ContractBackend) nonce(ctx context.Context, sender common.Address) (uin
err = cherrors.CheckIsChainNotReachableError(err)
return 0, errors.WithMessage(err, "fetching nonce")
}

// Look up expected next nonce locally.
c.nonceMtx.Lock()
defer c.nonceMtx.Unlock()
SharedMutex.Lock()
expectedNextNonce, found := c.expectedNextNonce[sender]
if !found {
c.expectedNextNonce[sender] = 0
Expand All @@ -181,6 +200,7 @@ func (c *ContractBackend) nonce(ctx context.Context, sender common.Address) (uin

// Update local expectation.
c.expectedNextNonce[sender] = nonce + 1
SharedMutex.Unlock()
return nonce, nil
}

Expand Down
135 changes: 118 additions & 17 deletions channel/erc20_depositor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@ package channel

import (
"context"
"fmt"
"math/big"
"sync"

"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/pkg/errors"
"perun.network/go-perun/log"

"github.com/perun-network/perun-eth-backend/bindings/assetholdererc20"
"github.com/perun-network/perun-eth-backend/bindings/peruntoken"
Expand All @@ -40,42 +45,138 @@ const ERC20DepositorTXGasLimit = 100000
// Return value of ERC20Depositor.NumTx.
const erc20DepositorNumTx = 2

// Keep track of the increase allowance and deposit processes.
var mu sync.Mutex
var locks = make(map[string]*sync.Mutex)

// DepositResult is created to keep track of the returned values.
type DepositResult struct {
Transactions types.Transactions
Error error
}

// Create key from account address and asset to only lock the process when hub deposits the same asset at the same time.
func lockKey(account common.Address, asset common.Address) string {
return fmt.Sprintf("%s-%s", account.Hex(), asset.Hex())
}

// Retrieves Lock for specific key.
func handleLock(lockKey string) *sync.Mutex {
mu.Lock()
defer mu.Unlock()

if lock, exists := locks[lockKey]; exists {
return lock
}

lock := &sync.Mutex{}
locks[lockKey] = lock
return lock
}

// Locks the lock argument, runs the given function and then unlocks the lock argument.
func lockAndUnlock(lock *sync.Mutex, fn func()) {
mu.Lock()
defer mu.Unlock()
lock.Lock()
defer lock.Unlock()
fn()
}

// NewERC20Depositor creates a new ERC20Depositor.
func NewERC20Depositor(token common.Address) *ERC20Depositor {
return &ERC20Depositor{Token: token}
}

// Deposit deposits ERC20 tokens into the ERC20 AssetHolder specified at the
// request's asset address.
// Deposit approves the value to be swapped and calls DepositOnly.
//
//nolint:funlen
func (d *ERC20Depositor) Deposit(ctx context.Context, req DepositReq) (types.Transactions, error) {
// Bind a `AssetHolderERC20` instance.
assetholder, err := assetholdererc20.NewAssetholdererc20(req.Asset.EthAddress(), req.CB)
if err != nil {
return nil, errors.Wrapf(err, "binding AssetHolderERC20 contract at: %x", req.Asset)
}
lockKey := lockKey(req.Account.Address, req.Asset.EthAddress())
lock := handleLock(lockKey)

// Bind an `ERC20` instance.
token, err := peruntoken.NewPeruntoken(d.Token, req.CB)
if err != nil {
return nil, errors.Wrapf(err, "binding ERC20 contract at: %x", d.Token)
}
// Increase the allowance.
opts, err := req.CB.NewTransactor(ctx, ERC20DepositorTXGasLimit, req.Account)
if err != nil {
return nil, errors.WithMessagef(err, "creating transactor for asset: %x", req.Asset)
callOpts := bind.CallOpts{
Pending: false,
Context: ctx,
}
tx1, err := token.IncreaseAllowance(opts, req.Asset.EthAddress(), req.Balance)
// variables for the return value.
var depResult DepositResult
var approvalReceived bool
var tx1 *types.Transaction
var err1 error
lockAndUnlock(lock, func() {
allowance, err := token.Allowance(&callOpts, req.Account.Address, req.Asset.EthAddress())
if err != nil {
depResult.Transactions = nil
depResult.Error = errors.WithMessagef(err, "could not get Allowance for asset: %x", req.Asset)
}
result := new(big.Int).Add(req.Balance, allowance)

// Increase the allowance.
opts, err := req.CB.NewTransactor(ctx, ERC20DepositorTXGasLimit, req.Account)
if err != nil {
depResult.Transactions = nil
depResult.Error = errors.WithMessagef(err, "creating transactor for asset: %x", req.Asset)
}
// Create a channel for receiving PeruntokenApproval events
eventSink := make(chan *peruntoken.PeruntokenApproval)

// Create a channel for receiving the Approval event
eventReceived := make(chan bool)

// Watch for Approval events and send them to the eventSink
subscription, err := token.WatchApproval(&bind.WatchOpts{Start: nil, Context: ctx}, eventSink, []common.Address{req.Account.Address}, []common.Address{req.Asset.EthAddress()})
if err != nil {
depResult.Transactions = nil
depResult.Error = errors.WithMessagef(err, "Cannot listen for event")
}
tx1, err1 = token.Approve(opts, req.Asset.EthAddress(), result)
if err1 != nil {
err = cherrors.CheckIsChainNotReachableError(err)
depResult.Transactions = nil
depResult.Error = errors.WithMessagef(err, "increasing allowance for asset: %x", req.Asset)
}

go func() {
select {
case event := <-eventSink:
log.Printf("Received Approval event: Owner: %s, Spender: %s, Value: %s\n", event.Owner.Hex(), event.Spender.Hex(), event.Value.String())
eventReceived <- true
case err := <-subscription.Err():
log.Println("Subscription error:", err)
}
}()
approvalReceived = <-eventReceived
})
if approvalReceived {
tx2, err := d.DepositOnly(ctx, req)
depResult.Transactions = []*types.Transaction{tx1, tx2}
depResult.Error = errors.WithMessage(err, "AssetHolderERC20 depositing")
}
return depResult.Transactions, depResult.Error
}

// DepositOnly deposits ERC20 tokens into the ERC20 AssetHolder specified at the
// requests asset address.
func (d *ERC20Depositor) DepositOnly(ctx context.Context, req DepositReq) (*types.Transaction, error) {
// Bind a `AssetHolderERC20` instance.
assetholder, err := assetholdererc20.NewAssetholdererc20(req.Asset.EthAddress(), req.CB)
if err != nil {
err = cherrors.CheckIsChainNotReachableError(err)
return nil, errors.WithMessagef(err, "increasing allowance for asset: %x", req.Asset)
return nil, errors.Wrapf(err, "binding AssetHolderERC20 contract at: %x", req.Asset)
}
// Deposit.
opts, err = req.CB.NewTransactor(ctx, ERC20DepositorTXGasLimit, req.Account)
opts, err := req.CB.NewTransactor(ctx, ERC20DepositorTXGasLimit, req.Account)
if err != nil {
return nil, errors.WithMessagef(err, "creating transactor for asset: %x", req.Asset)
}

tx2, err := assetholder.Deposit(opts, req.FundingID, req.Balance)
err = cherrors.CheckIsChainNotReachableError(err)
return []*types.Transaction{tx1, tx2}, errors.WithMessage(err, "AssetHolderERC20 depositing")
return tx2, err
}

// NumTX returns 2 since it does IncreaseAllowance and Deposit.
Expand Down
4 changes: 2 additions & 2 deletions channel/funder.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@ func (f *Funder) Fund(ctx context.Context, request channel.FundingReq) error {
nonFundingErrg := perror.NewGatherer()
for _, err := range perror.Causes(errg.Wait()) {
if channel.IsAssetFundingError(err) && err != nil {
fudingErr, ok := err.(*channel.AssetFundingError)
fundingErr, ok := err.(*channel.AssetFundingError)
if !ok {
return fmt.Errorf("wrong type: expected %T, got %T", &channel.AssetFundingError{}, err)
}
fundingErrs = append(fundingErrs, fudingErr)
fundingErrs = append(fundingErrs, fundingErr)
} else if err != nil {
nonFundingErrg.Add(err)
}
Expand Down
13 changes: 5 additions & 8 deletions channel/withdraw.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,19 @@ import (
"perun.network/go-perun/log"
)

// Withdraw ensures that a channel has been concluded and the final outcome
// Withdraw ensures that a channel has been concluded and the final outcome.
// withdrawn from the asset holders.
func (a *Adjudicator) Withdraw(ctx context.Context, req channel.AdjudicatorReq, subStates channel.StateMap) error {
if err := a.ensureConcluded(ctx, req, subStates); err != nil {
return errors.WithMessage(err, "ensure Concluded")
}

if err := a.checkConcludedState(ctx, req, subStates); err != nil {
return errors.WithMessage(err, "check concluded state")
}

return errors.WithMessage(a.ensureWithdrawn(ctx, req), "ensure Withdrawn")
}

// ensureWithdrawn ensures that the channel has been withdrawn from the asset.
func (a *Adjudicator) ensureWithdrawn(ctx context.Context, req channel.AdjudicatorReq) error {
g, ctx := errgroup.WithContext(ctx)

Expand All @@ -75,7 +74,7 @@ func (a *Adjudicator) ensureWithdrawn(ctx context.Context, req channel.Adjudicat
}
defer sub.Close()

// Check for past event
// Check for past event.
if err := sub.ReadPast(ctx, events); err != nil {
return errors.WithMessage(err, "reading past events")
}
Expand All @@ -90,7 +89,7 @@ func (a *Adjudicator) ensureWithdrawn(ctx context.Context, req channel.Adjudicat
return errors.WithMessage(err, "withdrawing assets failed")
}

// Wait for event
// Wait for event.
go func() {
subErr <- sub.Read(ctx, events)
}()
Expand Down Expand Up @@ -146,13 +145,11 @@ func (a *Adjudicator) callAssetWithdraw(ctx context.Context, request channel.Adj
if err != nil {
return nil, errors.WithMessagef(err, "creating transactor for asset %d", asset.assetIndex)
}

tx, err := asset.Withdraw(trans, auth, sig)
if err != nil {
err = cherrors.CheckIsChainNotReachableError(err)
return nil, errors.WithMessagef(err, "withdrawing asset %d", asset.assetIndex)
return nil, errors.WithMessagef(err, "withdrawing asset %d with transaction nonce %d", asset.assetIndex, trans.Nonce)
}
log.Debugf("Sent transaction %v", tx.Hash().Hex())
return tx, nil
}()
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion client/fund_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
func TestFundRecovery(t *testing.T) {
rng := test.Prng(t)

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

ctest.TestFundRecovery(
Expand Down
2 changes: 1 addition & 1 deletion client/payment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import (
)

const (
twoPartyTestTimeout = 10 * time.Second
twoPartyTestTimeout = 60 * time.Second
TxFinalityDepth = 3
)

Expand Down
6 changes: 3 additions & 3 deletions client/test/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ import (

const (
// DefaultTimeout is the default timeout for client tests.
DefaultTimeout = 5 * time.Second
DefaultTimeout = 20 * time.Second
// BlockInterval is the default block interval for the simulated chain.
BlockInterval = 100 * time.Millisecond
BlockInterval = 200 * time.Millisecond
// challenge duration in blocks that is used by MakeRoleSetups.
challengeDurationBlocks = 60
challengeDurationBlocks = 90
)

// MakeRoleSetups creates a two party client test setup with the provided names.
Expand Down

0 comments on commit b985d8a

Please sign in to comment.