Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(shwap): cache Both Row sides #4005

Merged
merged 3 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions share/eds/rsmt2d.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,7 @@ func (eds *Rsmt2D) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int
// HalfRow constructs a new shwap.Row from an Extended Data Square based on the specified index and
// side.
func (eds *Rsmt2D) HalfRow(idx int, side shwap.RowSide) (shwap.Row, error) {
shares := eds.ExtendedDataSquare.Row(uint(idx))
sh, err := libshare.FromBytes(shares)
if err != nil {
return shwap.Row{}, fmt.Errorf("while converting shares from bytes: %w", err)
}
return shwap.RowFromShares(sh, side), nil
return shwap.RowFromEDS(eds.ExtendedDataSquare, idx, side)
}

// RowNamespaceData returns data for the given namespace and row index.
Expand Down
97 changes: 64 additions & 33 deletions share/shwap/row.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/celestiaorg/celestia-app/v3/pkg/wrapper"
libshare "github.com/celestiaorg/go-square/v2/share"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap/pb"
Expand All @@ -20,33 +21,42 @@ type RowSide int
const (
Left RowSide = iota // Left side of the row.
Right // Right side of the row.
Both // Both sides of the row.
)

// Row represents a portion of a row in an EDS, either left or right half.
type Row struct {
halfShares []libshare.Share // halfShares holds the shares of either the left or right half of a row.
side RowSide // side indicates whether the row half is left or right.
shares []libshare.Share // holds the shares of Left or Right or Both sides of the row.
side RowSide // side indicates which side the shares belong to.
}

// NewRow creates a new Row with the specified shares and side.
func NewRow(halfShares []libshare.Share, side RowSide) Row {
func NewRow(shares []libshare.Share, side RowSide) Row {
return Row{
halfShares: halfShares,
side: side,
shares: shares,
side: side,
}
}

// RowFromShares constructs a new Row from an Extended Data Square based on the specified index and
// side.
func RowFromShares(shares []libshare.Share, side RowSide) Row {
var halfShares []libshare.Share
if side == Right {
halfShares = shares[len(shares)/2:] // Take the right half of the shares.
} else {
halfShares = shares[:len(shares)/2] // Take the left half of the shares.
// RowFromEDS constructs a new Row from an EDS based on the specified row index and side.
func RowFromEDS(eds *rsmt2d.ExtendedDataSquare, rowIdx int, side RowSide) (Row, error) {
rowBytes := eds.Row(uint(rowIdx))
shares, err := libshare.FromBytes(rowBytes)
if err != nil {
return Row{}, fmt.Errorf("while converting shares from bytes: %w", err)
}

switch side {
case Both:
case Left:
shares = shares[:len(shares)/2]
case Right:
shares = shares[len(shares)/2:]
default:
return Row{}, fmt.Errorf("invalid RowSide: %d", side)
}

return NewRow(halfShares, side)
return NewRow(shares, side), nil
}

// RowFromProto converts a protobuf Row to a Row structure.
Expand All @@ -56,54 +66,75 @@ func RowFromProto(r *pb.Row) (Row, error) {
return Row{}, err
}
return Row{
halfShares: shrs,
side: sideFromProto(r.GetHalfSide()),
shares: shrs,
side: sideFromProto(r.GetHalfSide()),
}, nil
}

// Shares reconstructs the complete row shares from the half provided, using RSMT2D for data
// recovery if needed.
func (r Row) Shares() ([]libshare.Share, error) {
shares := make([]libshare.Share, len(r.halfShares)*2)
offset := 0
if r.side == Right {
offset = len(r.halfShares) // Position the halfShares in the second half if it's the right side.
// It caches the reconstructed shares for future use and converts Row to Both side.
func (r *Row) Shares() ([]libshare.Share, error) {
if r.side == Both {
return r.shares, nil
}
for i, share := range r.halfShares {

shares := make([]libshare.Share, len(r.shares)*2)
offset := len(r.shares) * int(r.side)
for i, share := range r.shares {
shares[i+offset] = share
}

rowShares, err := share.DefaultRSMT2DCodec().Decode(libshare.ToBytes(shares))
if err != nil {
return nil, err
}
return libshare.FromBytes(rowShares)

r.shares, err = libshare.FromBytes(rowShares)
if err != nil {
return nil, err
}

r.side = Both
return r.shares, nil
}

// ToProto converts the Row to its protobuf representation.
func (r Row) ToProto() *pb.Row {
if r.side == Both {
// we don't need to send the whole row over the wire
// so if we have both sides, we can save bandwidth and send the left half only
return &pb.Row{
SharesHalf: SharesToProto(r.shares[:len(r.shares)/2]),
HalfSide: pb.Row_LEFT,
}
}

return &pb.Row{
SharesHalf: SharesToProto(r.halfShares),
SharesHalf: SharesToProto(r.shares),
HalfSide: r.side.ToProto(),
}
}

// IsEmpty reports whether the Row is empty, i.e. doesn't contain any shares.
func (r Row) IsEmpty() bool {
return r.halfShares == nil
return len(r.shares) == 0
}

// Verify checks if the row's shares match the expected number from the root data and validates
// the side of the row.
func (r Row) Verify(roots *share.AxisRoots, idx int) error {
if len(r.halfShares) == 0 {
return fmt.Errorf("empty half row")
func (r *Row) Verify(roots *share.AxisRoots, idx int) error {
if len(r.shares) == 0 {
return fmt.Errorf("empt row")
}
expectedShares := len(roots.RowRoots)
if r.side != Both {
expectedShares /= 2
}
expectedShares := len(roots.RowRoots) / 2
if len(r.halfShares) != expectedShares {
return fmt.Errorf("shares size doesn't match root size: %d != %d", len(r.halfShares), expectedShares)
if len(r.shares) != expectedShares {
return fmt.Errorf("shares size doesn't match root size: %d != %d", len(r.shares), expectedShares)
}
if r.side != Left && r.side != Right {
if r.side != Left && r.side != Right && r.side != Both {
return fmt.Errorf("invalid RowSide: %d", r.side)
}

Expand All @@ -115,7 +146,7 @@ func (r Row) Verify(roots *share.AxisRoots, idx int) error {

// verifyInclusion verifies the integrity of the row's shares against the provided root hash for the
// given row index.
func (r Row) verifyInclusion(roots *share.AxisRoots, idx int) error {
func (r *Row) verifyInclusion(roots *share.AxisRoots, idx int) error {
shrs, err := r.Shares()
if err != nil {
return fmt.Errorf("while extending shares: %w", err)
Expand Down
52 changes: 21 additions & 31 deletions share/shwap/row_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,20 @@ import (
"github.com/celestiaorg/celestia-node/share/eds/edstest"
)

func TestRowFromShares(t *testing.T) {
func TestRowShares(t *testing.T) {
const odsSize = 8
eds := edstest.RandEDS(t, odsSize)

for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ {
for _, side := range []RowSide{Left, Right} {
shrs := eds.Row(uint(rowIdx))
shares, err := libshare.FromBytes(shrs)
for _, side := range []RowSide{Left, Right, Both} {
row, err := RowFromEDS(eds, rowIdx, side)
require.NoError(t, err)
row := RowFromShares(shares, side)
require.Equal(t, side, row.side)

extended, err := row.Shares()
require.NoError(t, err)
require.Equal(t, shares, extended)

var half []libshare.Share
if side == Right {
half = shares[odsSize:]
} else {
half = shares[:odsSize]
}
require.Equal(t, half, row.halfShares)
require.Equal(t, side, row.side)
require.Len(t, extended, odsSize*2)
require.Equal(t, Both, row.side)
}
}
}
Expand All @@ -44,11 +36,9 @@ func TestRowValidate(t *testing.T) {
require.NoError(t, err)

for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ {
for _, side := range []RowSide{Left, Right} {
shrs := eds.Row(uint(rowIdx))
shares, err := libshare.FromBytes(shrs)
for _, side := range []RowSide{Left, Right, Both} {
row, err := RowFromEDS(eds, rowIdx, side)
require.NoError(t, err)
row := RowFromShares(shares, side)

err = row.Verify(root, rowIdx)
require.NoError(t, err)
Expand All @@ -65,10 +55,10 @@ func TestRowValidateNegativeCases(t *testing.T) {
shrs := eds.Row(0)
shares, err := libshare.FromBytes(shrs)
require.NoError(t, err)
row := RowFromShares(shares, Left)
row := NewRow(shares, Left)

// Test with incorrect side specification
invalidSideRow := Row{halfShares: row.halfShares, side: RowSide(999)}
invalidSideRow := Row{shares: row.shares, side: RowSide(999)}
err = invalidSideRow.Verify(root, 0)
require.Error(t, err, "should error on invalid row side")

Expand All @@ -79,12 +69,12 @@ func TestRowValidateNegativeCases(t *testing.T) {
require.NoError(t, err)
incorrectShares[i] = *shr
}
invalidRow := Row{halfShares: incorrectShares, side: Left}
invalidRow := Row{shares: incorrectShares, side: Left}
err = invalidRow.Verify(root, 0)
require.Error(t, err, "should error on incorrect number of shares")

// Test with empty shares
emptyRow := Row{halfShares: []libshare.Share{}, side: Left}
emptyRow := Row{shares: []libshare.Share{}, side: Left}
err = emptyRow.Verify(root, 0)
require.Error(t, err, "should error on empty halfShares")

Expand All @@ -99,16 +89,18 @@ func TestRowProtoEncoding(t *testing.T) {
eds := edstest.RandEDS(t, odsSize)

for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ {
for _, side := range []RowSide{Left, Right} {
shrs := eds.Row(uint(rowIdx))
shares, err := libshare.FromBytes(shrs)
for _, side := range []RowSide{Left, Right, Both} {
row, err := RowFromEDS(eds, rowIdx, side)
require.NoError(t, err)
row := RowFromShares(shares, side)

pb := row.ToProto()
rowOut, err := RowFromProto(pb)
require.NoError(t, err)
require.Equal(t, row, rowOut)
if side == Both {
require.NotEqual(t, row, rowOut)
} else {
require.Equal(t, row, rowOut)
}
}
}
}
Expand All @@ -120,10 +112,8 @@ func BenchmarkRowValidate(b *testing.B) {
eds := edstest.RandEDS(b, odsSize)
root, err := share.NewAxisRoots(eds)
require.NoError(b, err)
shrs := eds.Row(0)
shares, err := libshare.FromBytes(shrs)
row, err := RowFromEDS(eds, 0, Left)
require.NoError(b, err)
row := RowFromShares(shares, Left)

b.ResetTimer()
for i := 0; i < b.N; i++ {
Expand Down
Loading