From 49e60d94e283255d7276b2dfe1faec869759204a Mon Sep 17 00:00:00 2001 From: Wondertan Date: Thu, 12 Dec 2024 06:10:32 +0100 Subject: [PATCH] perf(shwap): cache Both Row sides --- share/eds/rsmt2d.go | 7 +-- share/shwap/row.go | 97 +++++++++++++++++++++++++++-------------- share/shwap/row_test.go | 52 +++++++++------------- 3 files changed, 86 insertions(+), 70 deletions(-) diff --git a/share/eds/rsmt2d.go b/share/eds/rsmt2d.go index 6c244de700..1d806324b5 100644 --- a/share/eds/rsmt2d.go +++ b/share/eds/rsmt2d.go @@ -99,12 +99,7 @@ func (eds *Rsmt2D) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int // HalfRow constructs a new shwap.Row from an Extended Data Square based on the specified index and // side. func (eds *Rsmt2D) HalfRow(idx int, side shwap.RowSide) (shwap.Row, error) { - shares := eds.ExtendedDataSquare.Row(uint(idx)) - sh, err := libshare.FromBytes(shares) - if err != nil { - return shwap.Row{}, fmt.Errorf("while converting shares from bytes: %w", err) - } - return shwap.RowFromShares(sh, side), nil + return shwap.RowFromEDS(eds.ExtendedDataSquare, idx, side) } // RowNamespaceData returns data for the given namespace and row index. diff --git a/share/shwap/row.go b/share/shwap/row.go index ba2c0d0b82..7057680573 100644 --- a/share/shwap/row.go +++ b/share/shwap/row.go @@ -6,6 +6,7 @@ import ( "github.com/celestiaorg/celestia-app/v3/pkg/wrapper" libshare "github.com/celestiaorg/go-square/v2/share" + "github.com/celestiaorg/rsmt2d" "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/shwap/pb" @@ -20,33 +21,42 @@ type RowSide int const ( Left RowSide = iota // Left side of the row. Right // Right side of the row. + Both // Both sides of the row. ) // Row represents a portion of a row in an EDS, either left or right half. type Row struct { - halfShares []libshare.Share // halfShares holds the shares of either the left or right half of a row. - side RowSide // side indicates whether the row half is left or right. + shares []libshare.Share // holds the shares of Left or Right or Both sides of the row. + side RowSide // side indicates which side the shares belong to. } // NewRow creates a new Row with the specified shares and side. -func NewRow(halfShares []libshare.Share, side RowSide) Row { +func NewRow(shares []libshare.Share, side RowSide) Row { return Row{ - halfShares: halfShares, - side: side, + shares: shares, + side: side, } } -// RowFromShares constructs a new Row from an Extended Data Square based on the specified index and -// side. -func RowFromShares(shares []libshare.Share, side RowSide) Row { - var halfShares []libshare.Share - if side == Right { - halfShares = shares[len(shares)/2:] // Take the right half of the shares. - } else { - halfShares = shares[:len(shares)/2] // Take the left half of the shares. +// RowFromEDS constructs a new Row from an EDS based on the specified row index and side. +func RowFromEDS(eds *rsmt2d.ExtendedDataSquare, rowIdx int, side RowSide) (Row, error) { + rowBytes := eds.Row(uint(rowIdx)) + shares, err := libshare.FromBytes(rowBytes) + if err != nil { + return Row{}, fmt.Errorf("while converting shares from bytes: %w", err) + } + + switch side { + case Both: + case Left: + shares = shares[:len(shares)/2] + case Right: + shares = shares[len(shares)/2:] + default: + return Row{}, fmt.Errorf("invalid RowSide: %d", side) } - return NewRow(halfShares, side) + return NewRow(shares, side), nil } // RowFromProto converts a protobuf Row to a Row structure. @@ -56,20 +66,22 @@ func RowFromProto(r *pb.Row) (Row, error) { return Row{}, err } return Row{ - halfShares: shrs, - side: sideFromProto(r.GetHalfSide()), + shares: shrs, + side: sideFromProto(r.GetHalfSide()), }, nil } // Shares reconstructs the complete row shares from the half provided, using RSMT2D for data // recovery if needed. -func (r Row) Shares() ([]libshare.Share, error) { - shares := make([]libshare.Share, len(r.halfShares)*2) - offset := 0 - if r.side == Right { - offset = len(r.halfShares) // Position the halfShares in the second half if it's the right side. +// It caches the reconstructed shares for future use and converts Row to Both side. +func (r *Row) Shares() ([]libshare.Share, error) { + if r.side == Both { + return r.shares, nil } - for i, share := range r.halfShares { + + shares := make([]libshare.Share, len(r.shares)*2) + offset := len(r.shares) * int(r.side) + for i, share := range r.shares { shares[i+offset] = share } @@ -77,33 +89,52 @@ func (r Row) Shares() ([]libshare.Share, error) { if err != nil { return nil, err } - return libshare.FromBytes(rowShares) + + r.shares, err = libshare.FromBytes(rowShares) + if err != nil { + return nil, err + } + + r.side = Both + return r.shares, nil } // ToProto converts the Row to its protobuf representation. func (r Row) ToProto() *pb.Row { + if r.side == Both { + // we don't need to send the whole row over the wire + // so if we have both sides, we can save bandwidth and send the left half only + return &pb.Row{ + SharesHalf: SharesToProto(r.shares[:len(r.shares)/2]), + HalfSide: pb.Row_LEFT, + } + } + return &pb.Row{ - SharesHalf: SharesToProto(r.halfShares), + SharesHalf: SharesToProto(r.shares), HalfSide: r.side.ToProto(), } } // IsEmpty reports whether the Row is empty, i.e. doesn't contain any shares. func (r Row) IsEmpty() bool { - return r.halfShares == nil + return len(r.shares) == 0 } // Verify checks if the row's shares match the expected number from the root data and validates // the side of the row. -func (r Row) Verify(roots *share.AxisRoots, idx int) error { - if len(r.halfShares) == 0 { - return fmt.Errorf("empty half row") +func (r *Row) Verify(roots *share.AxisRoots, idx int) error { + if len(r.shares) == 0 { + return fmt.Errorf("empt row") + } + expectedShares := len(roots.RowRoots) + if r.side != Both { + expectedShares /= 2 } - expectedShares := len(roots.RowRoots) / 2 - if len(r.halfShares) != expectedShares { - return fmt.Errorf("shares size doesn't match root size: %d != %d", len(r.halfShares), expectedShares) + if len(r.shares) != expectedShares { + return fmt.Errorf("shares size doesn't match root size: %d != %d", len(r.shares), expectedShares) } - if r.side != Left && r.side != Right { + if r.side != Left && r.side != Right && r.side != Both { return fmt.Errorf("invalid RowSide: %d", r.side) } @@ -115,7 +146,7 @@ func (r Row) Verify(roots *share.AxisRoots, idx int) error { // verifyInclusion verifies the integrity of the row's shares against the provided root hash for the // given row index. -func (r Row) verifyInclusion(roots *share.AxisRoots, idx int) error { +func (r *Row) verifyInclusion(roots *share.AxisRoots, idx int) error { shrs, err := r.Shares() if err != nil { return fmt.Errorf("while extending shares: %w", err) diff --git a/share/shwap/row_test.go b/share/shwap/row_test.go index 9249ea87f4..16bce3893b 100644 --- a/share/shwap/row_test.go +++ b/share/shwap/row_test.go @@ -11,28 +11,20 @@ import ( "github.com/celestiaorg/celestia-node/share/eds/edstest" ) -func TestRowFromShares(t *testing.T) { +func TestRowShares(t *testing.T) { const odsSize = 8 eds := edstest.RandEDS(t, odsSize) for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { - for _, side := range []RowSide{Left, Right} { - shrs := eds.Row(uint(rowIdx)) - shares, err := libshare.FromBytes(shrs) + for _, side := range []RowSide{Left, Right, Both} { + row, err := RowFromEDS(eds, rowIdx, side) require.NoError(t, err) - row := RowFromShares(shares, side) + require.Equal(t, side, row.side) + extended, err := row.Shares() require.NoError(t, err) - require.Equal(t, shares, extended) - - var half []libshare.Share - if side == Right { - half = shares[odsSize:] - } else { - half = shares[:odsSize] - } - require.Equal(t, half, row.halfShares) - require.Equal(t, side, row.side) + require.Len(t, extended, odsSize*2) + require.Equal(t, Both, row.side) } } } @@ -44,11 +36,9 @@ func TestRowValidate(t *testing.T) { require.NoError(t, err) for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { - for _, side := range []RowSide{Left, Right} { - shrs := eds.Row(uint(rowIdx)) - shares, err := libshare.FromBytes(shrs) + for _, side := range []RowSide{Left, Right, Both} { + row, err := RowFromEDS(eds, rowIdx, side) require.NoError(t, err) - row := RowFromShares(shares, side) err = row.Verify(root, rowIdx) require.NoError(t, err) @@ -65,10 +55,10 @@ func TestRowValidateNegativeCases(t *testing.T) { shrs := eds.Row(0) shares, err := libshare.FromBytes(shrs) require.NoError(t, err) - row := RowFromShares(shares, Left) + row := NewRow(shares, Left) // Test with incorrect side specification - invalidSideRow := Row{halfShares: row.halfShares, side: RowSide(999)} + invalidSideRow := Row{shares: row.shares, side: RowSide(999)} err = invalidSideRow.Verify(root, 0) require.Error(t, err, "should error on invalid row side") @@ -79,12 +69,12 @@ func TestRowValidateNegativeCases(t *testing.T) { require.NoError(t, err) incorrectShares[i] = *shr } - invalidRow := Row{halfShares: incorrectShares, side: Left} + invalidRow := Row{shares: incorrectShares, side: Left} err = invalidRow.Verify(root, 0) require.Error(t, err, "should error on incorrect number of shares") // Test with empty shares - emptyRow := Row{halfShares: []libshare.Share{}, side: Left} + emptyRow := Row{shares: []libshare.Share{}, side: Left} err = emptyRow.Verify(root, 0) require.Error(t, err, "should error on empty halfShares") @@ -99,16 +89,18 @@ func TestRowProtoEncoding(t *testing.T) { eds := edstest.RandEDS(t, odsSize) for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { - for _, side := range []RowSide{Left, Right} { - shrs := eds.Row(uint(rowIdx)) - shares, err := libshare.FromBytes(shrs) + for _, side := range []RowSide{Left, Right, Both} { + row, err := RowFromEDS(eds, rowIdx, side) require.NoError(t, err) - row := RowFromShares(shares, side) pb := row.ToProto() rowOut, err := RowFromProto(pb) require.NoError(t, err) - require.Equal(t, row, rowOut) + if side == Both { + require.NotEqual(t, row, rowOut) + } else { + require.Equal(t, row, rowOut) + } } } } @@ -120,10 +112,8 @@ func BenchmarkRowValidate(b *testing.B) { eds := edstest.RandEDS(b, odsSize) root, err := share.NewAxisRoots(eds) require.NoError(b, err) - shrs := eds.Row(0) - shares, err := libshare.FromBytes(shrs) + row, err := RowFromEDS(eds, 0, Left) require.NoError(b, err) - row := RowFromShares(shares, Left) b.ResetTimer() for i := 0; i < b.N; i++ {