Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(auth): audit issues with unordered txs #23392

Merged
merged 13 commits into from
Jan 31, 2025
Merged
17 changes: 17 additions & 0 deletions types/mempool/mempool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"math/rand"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
Expand Down Expand Up @@ -55,6 +56,21 @@ type testTx struct {
address sdk.AccAddress
// useful for debugging
strAddress string
unordered bool
timeout *time.Time
}

// GetTimeoutTimeStamp implements types.TxWithUnordered.
func (tx testTx) GetTimeoutTimeStamp() time.Time {
if tx.timeout == nil {
return time.Time{}
}
return *tx.timeout
}

// GetUnordered implements types.TxWithUnordered.
func (tx testTx) GetUnordered() bool {
return tx.unordered
}

func (tx testTx) GetSigners() ([][]byte, error) { panic("not implemented") }
Expand All @@ -73,6 +89,7 @@ func (tx testTx) GetSignaturesV2() (res []txsigning.SignatureV2, err error) {

var (
_ sdk.Tx = (*testTx)(nil)
_ sdk.TxWithUnordered = (*testTx)(nil)
_ signing.SigVerifiableTx = (*testTx)(nil)
_ cryptotypes.PubKey = (*testPubKey)(nil)
)
Expand Down
20 changes: 10 additions & 10 deletions types/mempool/priority_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,13 @@ func (mp *PriorityNonceMempool[C]) Insert(ctx context.Context, tx sdk.Tx) error
priority := mp.cfg.TxPriority.GetTxPriority(ctx, tx)
nonce := sig.Sequence

// if it's an unordered tx, we use the gas instead of the nonce
// if it's an unordered tx, we use the timeout timestamp instead of the nonce
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
gasLimit, err := unordered.GetGasLimit()
nonce = gasLimit
if err != nil {
return err
timestamp := unordered.GetTimeoutTimeStamp().Unix()
if timestamp < 0 {
return errors.New("invalid timestamp value")
}
nonce = uint64(timestamp)
}

key := txMeta[C]{nonce: nonce, priority: priority, sender: sender}
Expand Down Expand Up @@ -469,13 +469,13 @@ func (mp *PriorityNonceMempool[C]) Remove(tx sdk.Tx) error {
sender := sig.Signer.String()
nonce := sig.Sequence

// if it's an unordered tx, we use the gas instead of the nonce
// if it's an unordered tx, we use the timeout timestamp instead of the nonce
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
gasLimit, err := unordered.GetGasLimit()
nonce = gasLimit
if err != nil {
return err
timestamp := unordered.GetTimeoutTimeStamp().Unix()
if timestamp < 0 {
return errors.New("invalid timestamp value")
}
nonce = uint64(timestamp)
}

scoreKey := txMeta[C]{nonce: nonce, sender: sender}
Expand Down
37 changes: 37 additions & 0 deletions types/mempool/priority_nonce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -970,3 +970,40 @@ func TestNextSenderTx_TxReplacement(t *testing.T) {
iter := mp.Select(ctx, nil)
require.Equal(t, txs[3], iter.Tx())
}

func TestPriorityNonceMempool_UnorderedTx(t *testing.T) {
ctx := sdk.NewContext(nil, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2)
sa := accounts[0].Address
sb := accounts[1].Address

mp := mempool.DefaultPriorityMempool()

now := time.Now()
oneHour := now.Add(1 * time.Hour)
thirtyMin := now.Add(30 * time.Minute)
twoHours := now.Add(2 * time.Hour)
fifteenMin := now.Add(15 * time.Minute)

txs := []testTx{
{id: 1, priority: 0, address: sa, timeout: &thirtyMin, unordered: true},
{id: 0, priority: 0, address: sa, timeout: &oneHour, unordered: true},
{id: 3, priority: 0, address: sb, timeout: &fifteenMin, unordered: true},
{id: 2, priority: 0, address: sb, timeout: &twoHours, unordered: true},
}

for _, tx := range txs {
c := ctx.WithPriority(tx.priority)
require.NoError(t, mp.Insert(c, tx))
}

require.Equal(t, 4, mp.CountTx())

orderedTxs := fetchTxs(mp.Select(ctx, nil), 100000)
require.Equal(t, len(txs), len(orderedTxs))

// check order
for i, tx := range orderedTxs {
require.Equal(t, txs[i].id, tx.(testTx).id)
}
}
28 changes: 14 additions & 14 deletions types/mempool/sender_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,21 @@ func (snm *SenderNonceMempool) Insert(_ context.Context, tx sdk.Tx) error {
sender := sdk.AccAddress(sig.PubKey.Address()).String()
nonce := sig.Sequence

// if it's an unordered tx, we use the timeout timestamp instead of the nonce
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
timestamp := unordered.GetTimeoutTimeStamp().Unix()
if timestamp < 0 {
return errors.New("invalid timestamp value")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was in the old code before but as you touched it ( ;-) ), it would make sense to have this block moved before snm.senders is set in L145. We should not add elements to the object before the error cases are handled

}
nonce = uint64(timestamp)
}

senderTxs, found := snm.senders[sender]
if !found {
senderTxs = skiplist.New(skiplist.Uint64)
snm.senders[sender] = senderTxs
}

// if it's an unordered tx, we use the gas instead of the nonce
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
gasLimit, err := unordered.GetGasLimit()
nonce = gasLimit
if err != nil {
return err
}
}

senderTxs.Set(nonce, tx)

key := txKey{nonce: nonce, address: sender}
Expand Down Expand Up @@ -236,13 +236,13 @@ func (snm *SenderNonceMempool) Remove(tx sdk.Tx) error {
sender := sdk.AccAddress(sig.PubKey.Address()).String()
nonce := sig.Sequence

// if it's an unordered tx, we use the gas instead of the nonce
// if it's an unordered tx, we use the timeout timestamp instead of the nonce
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
gasLimit, err := unordered.GetGasLimit()
nonce = gasLimit
if err != nil {
return err
timestamp := unordered.GetTimeoutTimeStamp().Unix()
if timestamp < 0 {
return errors.New("invalid timestamp value")
}
nonce = uint64(timestamp)
}

senderTxs, found := snm.senders[sender]
Expand Down
65 changes: 65 additions & 0 deletions types/mempool/sender_nonce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"math/rand"
"testing"
"time"

"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -192,3 +193,67 @@ func (s *MempoolTestSuite) TestTxNotFoundOnSender() {
err = mp.Remove(tx)
require.Equal(t, mempool.ErrTxNotFound, err)
}

func (s *MempoolTestSuite) TestUnorderedTx() {
t := s.T()

ctx := sdk.NewContext(nil, false, log.NewNopLogger())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to just use context.Background() for this test

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need an sdk.Context here, because below we call ctx.WithPriority when calling mp.Insert

accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2)
sa := accounts[0].Address
sb := accounts[1].Address

mp := mempool.NewSenderNonceMempool(mempool.SenderNonceMaxTxOpt(5000))

now := time.Now()
oneHour := now.Add(1 * time.Hour)
thirtyMin := now.Add(30 * time.Minute)
twoHours := now.Add(2 * time.Hour)
fifteenMin := now.Add(15 * time.Minute)

txs := []testTx{
{id: 0, address: sa, timeout: &oneHour, unordered: true},
{id: 1, address: sa, timeout: &thirtyMin, unordered: true},
{id: 2, address: sb, timeout: &twoHours, unordered: true},
{id: 3, address: sb, timeout: &fifteenMin, unordered: true},
}

for _, tx := range txs {
c := ctx.WithPriority(tx.priority)
require.NoError(t, mp.Insert(c, tx))
}

require.Equal(t, 4, mp.CountTx())

orderedTxs := fetchTxs(mp.Select(ctx, nil), 100000)
require.Equal(t, len(txs), len(orderedTxs))

// Because the sender is selected randomly it can be any of these options
acceptableOptions := [][]int{
{3, 1, 2, 0},
{3, 1, 0, 2},
{3, 2, 1, 0},
{1, 3, 0, 2},
{1, 3, 2, 0},
{1, 0, 3, 2},
}

orderedTxsIds := make([]int, len(orderedTxs))
for i, tx := range orderedTxs {
orderedTxsIds[i] = tx.(testTx).id
}

anyAcceptableOrder := false
for _, option := range acceptableOptions {
for i, tx := range orderedTxs {
if tx.(testTx).id != txs[option[i]].id {
break
}

if i == len(orderedTxs)-1 {
anyAcceptableOrder = true
}
}
}

require.True(t, anyAcceptableOrder, "expected any of %v but got %v", acceptableOptions, orderedTxsIds)
}
32 changes: 32 additions & 0 deletions x/auth/ante/ante_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"strings"
"testing"
"time"

"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
Expand Down Expand Up @@ -1384,3 +1385,34 @@ func TestAnteHandlerReCheck(t *testing.T) {
_, err = suite.anteHandler(suite.ctx, tx, false)
require.NotNil(t, err, "antehandler on recheck did not fail once feePayer no longer has sufficient funds")
}

func TestAnteHandlerUnorderedTx(t *testing.T) {
suite := SetupTestSuite(t, false)
accs := suite.CreateTestAccounts(1)
msg := testdata.NewTestMsg(accs[0].acc.GetAddress())

// First send a normal sequential tx with sequence 0
suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), authtypes.FeeCollectorName, testdata.NewTestFeeAmount()).Return(nil).AnyTimes()

privs, accNums, accSeqs := []cryptotypes.PrivKey{accs[0].priv}, []uint64{1000}, []uint64{0}
_, err := suite.DeliverMsgs(t, privs, []sdk.Msg{msg}, testdata.NewTestFeeAmount(), testdata.NewTestGasLimit(), accNums, accSeqs, suite.ctx.ChainID(), false)
require.NoError(t, err)
Comment on lines +1394 to +1399
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add assertion for account sequence.

The test should verify that the account sequence is incremented after the first transaction.

 	_, err := suite.DeliverMsgs(t, privs, []sdk.Msg{msg}, testdata.NewTestFeeAmount(), testdata.NewTestGasLimit(), accNums, accSeqs, suite.ctx.ChainID(), false)
 	require.NoError(t, err)
+	// Verify sequence is incremented
+	acc := suite.accountKeeper.GetAccount(suite.ctx, accs[0].acc.GetAddress())
+	require.Equal(t, uint64(1), acc.GetSequence())
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// First send a normal sequential tx with sequence 0
suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), authtypes.FeeCollectorName, testdata.NewTestFeeAmount()).Return(nil).AnyTimes()
privs, accNums, accSeqs := []cryptotypes.PrivKey{accs[0].priv}, []uint64{1000}, []uint64{0}
_, err := suite.DeliverMsgs(t, privs, []sdk.Msg{msg}, testdata.NewTestFeeAmount(), testdata.NewTestGasLimit(), accNums, accSeqs, suite.ctx.ChainID(), false)
require.NoError(t, err)
// First send a normal sequential tx with sequence 0
suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), authtypes.FeeCollectorName, testdata.NewTestFeeAmount()).Return(nil).AnyTimes()
privs, accNums, accSeqs := []cryptotypes.PrivKey{accs[0].priv}, []uint64{1000}, []uint64{0}
_, err := suite.DeliverMsgs(t, privs, []sdk.Msg{msg}, testdata.NewTestFeeAmount(), testdata.NewTestGasLimit(), accNums, accSeqs, suite.ctx.ChainID(), false)
require.NoError(t, err)
// Verify sequence is incremented
acc := suite.accountKeeper.GetAccount(suite.ctx, accs[0].acc.GetAddress())
require.Equal(t, uint64(1), acc.GetSequence())


// we try to send another tx with the same sequence, it will fail
_, err = suite.DeliverMsgs(t, privs, []sdk.Msg{msg}, testdata.NewTestFeeAmount(), testdata.NewTestGasLimit(), accNums, accSeqs, suite.ctx.ChainID(), false)
require.Error(t, err)
Comment on lines +1401 to +1403
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance error assertion.

The error check should verify the specific error type for sequence mismatch.

-	require.Error(t, err)
+	require.ErrorIs(t, err, sdkerrors.ErrWrongSequence)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// we try to send another tx with the same sequence, it will fail
_, err = suite.DeliverMsgs(t, privs, []sdk.Msg{msg}, testdata.NewTestFeeAmount(), testdata.NewTestGasLimit(), accNums, accSeqs, suite.ctx.ChainID(), false)
require.Error(t, err)
// we try to send another tx with the same sequence, it will fail
_, err = suite.DeliverMsgs(t, privs, []sdk.Msg{msg}, testdata.NewTestFeeAmount(), testdata.NewTestGasLimit(), accNums, accSeqs, suite.ctx.ChainID(), false)
require.ErrorIs(t, err, sdkerrors.ErrWrongSequence)


// now we'll still use the same sequence but because it's unordered, it will be ignored and accepted anyway
msgs := []sdk.Msg{msg}
require.NoError(t, suite.txBuilder.SetMsgs(msgs...))
suite.txBuilder.SetFeeAmount(testdata.NewTestFeeAmount())
suite.txBuilder.SetGasLimit(testdata.NewTestGasLimit())

tx, txErr := suite.CreateTestUnorderedTx(suite.ctx, privs, accNums, accSeqs, suite.ctx.ChainID(), apisigning.SignMode_SIGN_MODE_DIRECT, true, time.Now().Add(time.Minute))
require.NoError(t, txErr)
txBytes, err := suite.clientCtx.TxConfig.TxEncoder()(tx)
bytesCtx := suite.ctx.WithTxBytes(txBytes)
require.NoError(t, err)
_, err = suite.anteHandler(bytesCtx, tx, false)
require.NoError(t, err)
Comment on lines +1411 to +1417
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add timeout validation test cases.

The test should include cases for expired and future timeouts to ensure proper validation.

+	// Test expired timeout
+	expiredTx, _ := suite.CreateTestUnorderedTx(suite.ctx, privs, accNums, accSeqs, suite.ctx.ChainID(), apisigning.SignMode_SIGN_MODE_DIRECT, true, time.Now().Add(-time.Minute))
+	_, err = suite.anteHandler(bytesCtx, expiredTx, false)
+	require.ErrorIs(t, err, sdkerrors.ErrTxTimeoutHeight)
+
+	// Test far future timeout
+	futureTx, _ := suite.CreateTestUnorderedTx(suite.ctx, privs, accNums, accSeqs, suite.ctx.ChainID(), apisigning.SignMode_SIGN_MODE_DIRECT, true, time.Now().Add(24*time.Hour))
+	_, err = suite.anteHandler(bytesCtx, futureTx, false)
+	require.ErrorIs(t, err, sdkerrors.ErrInvalidTimeout)

Committable suggestion skipped: line range outside the PR's diff.

}
22 changes: 14 additions & 8 deletions x/auth/ante/sigverify.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,18 +320,24 @@ func (svd SigVerificationDecorator) consumeSignatureGas(
// verifySig will verify the signature of the provided signer account.
func (svd SigVerificationDecorator) verifySig(ctx context.Context, tx sdk.Tx, acc sdk.AccountI, sig signing.SignatureV2, newlyCreated bool) error {
execMode := svd.ak.GetEnvironment().TransactionService.ExecMode(ctx)
if execMode == transaction.ExecModeCheck {
if sig.Sequence < acc.GetSequence() {
unorderedTx, ok := tx.(sdk.TxWithUnordered)
isUnordered := ok && unorderedTx.GetUnordered()

// only check sequence if the tx is not unordered
if !isUnordered {
julienrbrt marked this conversation as resolved.
Show resolved Hide resolved
if execMode == transaction.ExecModeCheck {
if sig.Sequence < acc.GetSequence() {
return errorsmod.Wrapf(
sdkerrors.ErrWrongSequence,
"account sequence mismatch: expected higher than or equal to %d, got %d", acc.GetSequence(), sig.Sequence,
)
}
} else if sig.Sequence != acc.GetSequence() {
return errorsmod.Wrapf(
sdkerrors.ErrWrongSequence,
"account sequence mismatch, expected higher than or equal to %d, got %d", acc.GetSequence(), sig.Sequence,
"account sequence mismatch: expected %d, got %d", acc.GetSequence(), sig.Sequence,
)
}
} else if sig.Sequence != acc.GetSequence() {
return errorsmod.Wrapf(
sdkerrors.ErrWrongSequence,
"account sequence mismatch: expected %d, got %d", acc.GetSequence(), sig.Sequence,
)
}

// we're in simulation mode, or in ReCheckTx, or context is not
Expand Down
62 changes: 62 additions & 0 deletions x/auth/ante/testutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ante_test
import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
Expand Down Expand Up @@ -241,6 +242,67 @@ func (suite *AnteTestSuite) RunTestCase(t *testing.T, tc TestCase, args TestCase
}
}

func (suite *AnteTestSuite) CreateTestUnorderedTx(
ctx sdk.Context, privs []cryptotypes.PrivKey,
accNums, accSeqs []uint64,
chainID string, signMode apisigning.SignMode,
unordered bool, unorderedTimeout time.Time,
) (xauthsigning.Tx, error) {
suite.txBuilder.SetUnordered(unordered)
suite.txBuilder.SetTimeoutTimestamp(unorderedTimeout)

// First round: we gather all the signer infos. We use the "set empty
// signature" hack to do that.
var sigsV2 []signing.SignatureV2
for i, priv := range privs {
sigV2 := signing.SignatureV2{
PubKey: priv.PubKey(),
Data: &signing.SingleSignatureData{
SignMode: signMode,
Signature: nil,
},
Sequence: accSeqs[i],
}

sigsV2 = append(sigsV2, sigV2)
}
err := suite.txBuilder.SetSignatures(sigsV2...)
if err != nil {
return nil, err
}

// Second round: all signer infos are set, so each signer can sign.
sigsV2 = []signing.SignatureV2{}
for i, priv := range privs {
anyPk, err := codectypes.NewAnyWithValue(priv.PubKey())
if err != nil {
return nil, err
}

signerData := txsigning.SignerData{
Address: sdk.AccAddress(priv.PubKey().Address()).String(),
ChainID: chainID,
AccountNumber: accNums[i],
Sequence: accSeqs[i],
PubKey: &anypb.Any{TypeUrl: anyPk.TypeUrl, Value: anyPk.Value},
}
sigV2, err := tx.SignWithPrivKey(
ctx, signMode, signerData,
suite.txBuilder, priv, suite.clientCtx.TxConfig, accSeqs[i])
if err != nil {
return nil, err
}

sigsV2 = append(sigsV2, sigV2)
}
err = suite.txBuilder.SetSignatures(sigsV2...)
if err != nil {
return nil, err
}

return suite.txBuilder.GetTx(), nil
}

// CreateTestTx is a helper function to create a tx given multiple inputs.
func (suite *AnteTestSuite) CreateTestTx(
ctx sdk.Context, privs []cryptotypes.PrivKey,
Expand Down
Loading
Loading