From 14d878d549f9de415c90ae9cb4e5c67ebb2d97d3 Mon Sep 17 00:00:00 2001 From: son trinh Date: Wed, 13 Dec 2023 03:11:10 +0700 Subject: [PATCH] test: remove test cases that vary based on codec (#274) ## Overview Closed: [#216](https://github.com/celestiaorg/rsmt2d/issues/216) ## Checklist - [ ] New and updated code has appropriate documentation - [x] New and updated code has new and/or updated testing - [x] Required CI checks are passing - [ ] Visual proof for any user facing features like CLI or documentation updates - [x] Linked issues closed with keywords --- extendeddatacrossword_test.go | 440 ++++++++++++++++------------------ 1 file changed, 203 insertions(+), 237 deletions(-) diff --git a/extendeddatacrossword_test.go b/extendeddatacrossword_test.go index 2a840fa..6266b79 100644 --- a/extendeddatacrossword_test.go +++ b/extendeddatacrossword_test.go @@ -24,156 +24,127 @@ type PseudoFraudProof struct { } func TestRepairExtendedDataSquare(t *testing.T) { - tests := []struct { - name string - codec Codec - }{ - {"leopard", NewLeoRSCodec()}, - } + codec := NewLeoRSCodec() + original := createTestEds(codec, shareSize) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - name, codec := test.name, test.codec - original := createTestEds(codec, shareSize) - - rowRoots, err := original.RowRoots() - require.NoError(t, err) - - colRoots, err := original.ColRoots() - require.NoError(t, err) - - // Verify that an EDS can be repaired after the maximum amount of erasures - t.Run("MaximumErasures", func(t *testing.T) { - flattened := original.Flattened() - flattened[0], flattened[2], flattened[3] = nil, nil, nil - flattened[4], flattened[5], flattened[6], flattened[7] = nil, nil, nil, nil - flattened[8], flattened[9], flattened[10] = nil, nil, nil - flattened[12], flattened[13] = nil, nil - - // Re-import the data square. - eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree) - if err != nil { - t.Errorf("ImportExtendedDataSquare failed: %v", err) - } + rowRoots, err := original.RowRoots() + require.NoError(t, err) - err = eds.Repair(rowRoots, colRoots) - if err != nil { - t.Errorf("unexpected err while repairing data square: %v, codec: :%s", err, name) - } else { - assert.Equal(t, original.GetCell(0, 0), bytes.Repeat([]byte{1}, shareSize)) - assert.Equal(t, original.GetCell(0, 1), bytes.Repeat([]byte{2}, shareSize)) - assert.Equal(t, original.GetCell(1, 0), bytes.Repeat([]byte{3}, shareSize)) - assert.Equal(t, original.GetCell(1, 1), bytes.Repeat([]byte{4}, shareSize)) - } - }) - - // Verify that an EDS returns an error when there are too many erasures - t.Run("Unrepairable", func(t *testing.T) { - flattened := original.Flattened() - flattened[0], flattened[2], flattened[3] = nil, nil, nil - flattened[4], flattened[5], flattened[6], flattened[7] = nil, nil, nil, nil - flattened[8], flattened[9], flattened[10] = nil, nil, nil - flattened[12], flattened[13], flattened[14] = nil, nil, nil - - // Re-import the data square. - eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree) - if err != nil { - t.Errorf("ImportExtendedDataSquare failed: %v", err) - } + colRoots, err := original.ColRoots() + require.NoError(t, err) - err = eds.Repair(rowRoots, colRoots) - if err != ErrUnrepairableDataSquare { - t.Errorf("did not return an error on trying to repair an unrepairable square") - } - }) - }) - } + // Verify that an EDS can be repaired after the maximum amount of erasures + t.Run("MaximumErasures", func(t *testing.T) { + flattened := original.Flattened() + flattened[0], flattened[2], flattened[3] = nil, nil, nil + flattened[4], flattened[5], flattened[6], flattened[7] = nil, nil, nil, nil + flattened[8], flattened[9], flattened[10] = nil, nil, nil + flattened[12], flattened[13] = nil, nil + + // Re-import the data square. + eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree) + if err != nil { + t.Errorf("ImportExtendedDataSquare failed: %v", err) + } + + err = eds.Repair(rowRoots, colRoots) + if err != nil { + t.Errorf("unexpected err while repairing data square: %v, codec: :%s", err, codec.Name()) + } else { + assert.Equal(t, original.GetCell(0, 0), bytes.Repeat([]byte{1}, shareSize)) + assert.Equal(t, original.GetCell(0, 1), bytes.Repeat([]byte{2}, shareSize)) + assert.Equal(t, original.GetCell(1, 0), bytes.Repeat([]byte{3}, shareSize)) + assert.Equal(t, original.GetCell(1, 1), bytes.Repeat([]byte{4}, shareSize)) + } + }) + + // Verify that an EDS returns an error when there are too many erasures + t.Run("Unrepairable", func(t *testing.T) { + flattened := original.Flattened() + flattened[0], flattened[2], flattened[3] = nil, nil, nil + flattened[4], flattened[5], flattened[6], flattened[7] = nil, nil, nil, nil + flattened[8], flattened[9], flattened[10] = nil, nil, nil + flattened[12], flattened[13], flattened[14] = nil, nil, nil + + // Re-import the data square. + eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree) + if err != nil { + t.Errorf("ImportExtendedDataSquare failed: %v", err) + } + + err = eds.Repair(rowRoots, colRoots) + if err != ErrUnrepairableDataSquare { + t.Errorf("did not return an error on trying to repair an unrepairable square") + } + }) } func TestValidFraudProof(t *testing.T) { + codec := NewLeoRSCodec() + corruptChunk := bytes.Repeat([]byte{66}, shareSize) - tests := []struct { - name string - codec Codec - }{ - {"leopard", NewLeoRSCodec()}, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - name, codec := test.name, test.codec - original := createTestEds(codec, shareSize) + original := createTestEds(codec, shareSize) - var byzData *ErrByzantineData - corrupted, err := original.deepCopy(codec) - if err != nil { - t.Fatalf("unexpected err while copying original data: %v, codec: :%s", err, name) - } - corrupted.setCell(0, 0, corruptChunk) - assert.NoError(t, err) + var byzData *ErrByzantineData + corrupted, err := original.deepCopy(codec) + if err != nil { + t.Fatalf("unexpected err while copying original data: %v, codec: :%s", err, codec.Name()) + } + corrupted.setCell(0, 0, corruptChunk) + assert.NoError(t, err) - rowRoots, err := corrupted.getRowRoots() - assert.NoError(t, err) + rowRoots, err := corrupted.getRowRoots() + assert.NoError(t, err) - colRoots, err := corrupted.getColRoots() - assert.NoError(t, err) + colRoots, err := corrupted.getColRoots() + assert.NoError(t, err) - err = corrupted.Repair(rowRoots, colRoots) - errors.As(err, &byzData) + err = corrupted.Repair(rowRoots, colRoots) + errors.As(err, &byzData) - // Construct the fraud proof - fraudProof := PseudoFraudProof{0, byzData.Index, byzData.Shares} - // Verify the fraud proof - // TODO in a real fraud proof, also verify Merkle proof for each non-nil share. - rebuiltShares, err := codec.Decode(fraudProof.Shares) - if err != nil { - t.Errorf("could not decode fraud proof shares; got: %v", err) - } - root, err := corrupted.computeSharesRoot(rebuiltShares, byzData.Axis, fraudProof.Index) - assert.NoError(t, err) - rowRoot, err := corrupted.getRowRoot(fraudProof.Index) - assert.NoError(t, err) - if bytes.Equal(root, rowRoot) { - // If the roots match, then the fraud proof should be for invalid erasure coding. - parityShares, err := codec.Encode(rebuiltShares[0:corrupted.originalDataWidth]) - if err != nil { - t.Errorf("could not encode fraud proof shares; %v", fraudProof) - } - startIndex := len(rebuiltShares) - int(corrupted.originalDataWidth) - if bytes.Equal(flattenChunks(parityShares), flattenChunks(rebuiltShares[startIndex:])) { - t.Errorf("invalid fraud proof %v", fraudProof) - } - } - }) + // Construct the fraud proof + fraudProof := PseudoFraudProof{0, byzData.Index, byzData.Shares} + // Verify the fraud proof + // TODO in a real fraud proof, also verify Merkle proof for each non-nil share. + rebuiltShares, err := codec.Decode(fraudProof.Shares) + if err != nil { + t.Errorf("could not decode fraud proof shares; got: %v", err) + } + root, err := corrupted.computeSharesRoot(rebuiltShares, byzData.Axis, fraudProof.Index) + assert.NoError(t, err) + rowRoot, err := corrupted.getRowRoot(fraudProof.Index) + assert.NoError(t, err) + if bytes.Equal(root, rowRoot) { + // If the roots match, then the fraud proof should be for invalid erasure coding. + parityShares, err := codec.Encode(rebuiltShares[0:corrupted.originalDataWidth]) + if err != nil { + t.Errorf("could not encode fraud proof shares; %v", fraudProof) + } + startIndex := len(rebuiltShares) - int(corrupted.originalDataWidth) + if bytes.Equal(flattenChunks(parityShares), flattenChunks(rebuiltShares[startIndex:])) { + t.Errorf("invalid fraud proof %v", fraudProof) + } } } func TestCannotRepairSquareWithBadRoots(t *testing.T) { - corruptChunk := bytes.Repeat([]byte{66}, shareSize) - tests := []struct { - name string - codec Codec - }{ - {"leopard", NewLeoRSCodec()}, - } + codec := NewLeoRSCodec() - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - original := createTestEds(test.codec, shareSize) + corruptChunk := bytes.Repeat([]byte{66}, shareSize) + original := createTestEds(codec, shareSize) - rowRoots, err := original.RowRoots() - require.NoError(t, err) + rowRoots, err := original.RowRoots() + require.NoError(t, err) - colRoots, err := original.ColRoots() - require.NoError(t, err) + colRoots, err := original.ColRoots() + require.NoError(t, err) - original.setCell(0, 0, corruptChunk) - require.NoError(t, err) - err = original.Repair(rowRoots, colRoots) - if err == nil { - t.Errorf("did not return an error on trying to repair a square with bad roots") - } - }) + original.setCell(0, 0, corruptChunk) + require.NoError(t, err) + err = original.Repair(rowRoots, colRoots) + if err == nil { + t.Errorf("did not return an error on trying to repair a square with bad roots") } } @@ -223,36 +194,34 @@ func TestCorruptedEdsReturnsErrByzantineData(t *testing.T) { }, } - for codecName, codec := range codecs { - t.Run(codecName, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - eds := createTestEds(codec, shareSize) - - // compute the rowRoots prior to corruption - rowRoots, err := eds.getRowRoots() - assert.NoError(t, err) + codec := NewLeoRSCodec() - // compute the colRoots prior to corruption - colRoots, err := eds.getColRoots() - assert.NoError(t, err) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + eds := createTestEds(codec, shareSize) - for i, coords := range test.coords { - x := coords[0] - y := coords[1] - eds.setCell(x, y, test.values[i]) - } + // compute the rowRoots prior to corruption + rowRoots, err := eds.getRowRoots() + assert.NoError(t, err) - err = eds.Repair(rowRoots, colRoots) - assert.Error(t, err) + // compute the colRoots prior to corruption + colRoots, err := eds.getColRoots() + assert.NoError(t, err) - // due to parallelisation, the ErrByzantineData axis may be either row or col - var byzData *ErrByzantineData - assert.ErrorAs(t, err, &byzData, "did not return a ErrByzantineData for a bad col or row") - assert.NotEmpty(t, byzData.Shares) - assert.Contains(t, byzData.Shares, corruptChunk) - }) + for i, coords := range test.coords { + x := coords[0] + y := coords[1] + eds.setCell(x, y, test.values[i]) } + + err = eds.Repair(rowRoots, colRoots) + assert.Error(t, err) + + // due to parallelisation, the ErrByzantineData axis may be either row or col + var byzData *ErrByzantineData + assert.ErrorAs(t, err, &byzData, "did not return a ErrByzantineData for a bad col or row") + assert.NotEmpty(t, byzData.Shares) + assert.Contains(t, byzData.Shares, corruptChunk) }) } } @@ -260,67 +229,66 @@ func TestCorruptedEdsReturnsErrByzantineData(t *testing.T) { func BenchmarkRepair(b *testing.B) { // For different ODS sizes for originalDataWidth := 4; originalDataWidth <= 512; originalDataWidth *= 2 { - for codecName, codec := range codecs { - if codec.MaxChunks() < originalDataWidth*originalDataWidth { - // Only test codecs that support this many chunks - continue - } + codec := NewLeoRSCodec() + if codec.MaxChunks() < originalDataWidth*originalDataWidth { + // Only test codecs that support this many chunks + continue + } - // Generate a new range original data square then extend it - square := genRandDS(originalDataWidth, shareSize) - eds, err := ComputeExtendedDataSquare(square, codec, NewDefaultTree) - if err != nil { - b.Error(err) - } + // Generate a new range original data square then extend it + square := genRandDS(originalDataWidth, shareSize) + eds, err := ComputeExtendedDataSquare(square, codec, NewDefaultTree) + if err != nil { + b.Error(err) + } - extendedDataWidth := originalDataWidth * 2 - rowRoots, err := eds.RowRoots() - assert.NoError(b, err) - - colRoots, err := eds.ColRoots() - assert.NoError(b, err) - - b.Run( - fmt.Sprintf( - "%s %dx%dx%d ODS", - codecName, - originalDataWidth, - originalDataWidth, - len(square[0]), - ), - func(b *testing.B) { - for n := 0; n < b.N; n++ { - b.StopTimer() - - flattened := eds.Flattened() - // Randomly remove 1/2 of the shares of each row - for r := 0; r < extendedDataWidth; r++ { - for c := 0; c < originalDataWidth; { - ind := rand.Intn(extendedDataWidth) - if flattened[r*extendedDataWidth+ind] == nil { - continue - } - flattened[r*extendedDataWidth+ind] = nil - c++ + extendedDataWidth := originalDataWidth * 2 + rowRoots, err := eds.RowRoots() + assert.NoError(b, err) + + colRoots, err := eds.ColRoots() + assert.NoError(b, err) + + b.Run( + fmt.Sprintf( + "%s %dx%dx%d ODS", + codec.Name(), + originalDataWidth, + originalDataWidth, + len(square[0]), + ), + func(b *testing.B) { + for n := 0; n < b.N; n++ { + b.StopTimer() + + flattened := eds.Flattened() + // Randomly remove 1/2 of the shares of each row + for r := 0; r < extendedDataWidth; r++ { + for c := 0; c < originalDataWidth; { + ind := rand.Intn(extendedDataWidth) + if flattened[r*extendedDataWidth+ind] == nil { + continue } + flattened[r*extendedDataWidth+ind] = nil + c++ } + } - // Re-import the data square. - eds, _ = ImportExtendedDataSquare(flattened, codec, NewDefaultTree) + // Re-import the data square. + eds, _ = ImportExtendedDataSquare(flattened, codec, NewDefaultTree) - b.StartTimer() + b.StartTimer() - err := eds.Repair( - rowRoots, - colRoots, - ) - if err != nil { - b.Error(err) - } + err := eds.Repair( + rowRoots, + colRoots, + ) + if err != nil { + b.Error(err) } - }, - ) - } + } + }, + ) } } @@ -420,38 +388,36 @@ func TestCorruptedEdsReturnsErrByzantineData_UnorderedShares(t *testing.T) { }, } - for codecName, codec := range codecs { - t.Run(codecName, func(t *testing.T) { - // create a DA header - eds := createTestEdsWithNMT(t, codec, shareSize, namespaceSize, 1, 2, 3, 4) - assert.NotNil(t, eds) - dAHeaderRoots, err := eds.getRowRoots() - assert.NoError(t, err) + codec := NewLeoRSCodec() - dAHeaderCols, err := eds.getColRoots() - assert.NoError(t, err) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // create an eds with the given shares - corruptEds := createTestEdsWithNMT(t, codec, shareSize, namespaceSize, sharesValue...) - assert.NotNil(t, corruptEds) - // corrupt it by setting the values at the given coordinates - for i, coords := range test.coords { - x := coords[0] - y := coords[1] - corruptEds.setCell(x, y, test.values[i]) - } + // create a DA header + eds := createTestEdsWithNMT(t, codec, shareSize, namespaceSize, 1, 2, 3, 4) + assert.NotNil(t, eds) + dAHeaderRoots, err := eds.getRowRoots() + assert.NoError(t, err) - err = corruptEds.Repair(dAHeaderRoots, dAHeaderCols) - assert.Equal(t, err != nil, test.wantErr) - if test.wantErr { - var byzErr *ErrByzantineData - assert.ErrorAs(t, err, &byzErr) - errors.As(err, &byzErr) - assert.Equal(t, byzErr.Axis, test.corruptedAxis) - assert.Equal(t, byzErr.Index, test.corruptedIndex) - } - }) + dAHeaderCols, err := eds.getColRoots() + assert.NoError(t, err) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // create an eds with the given shares + corruptEds := createTestEdsWithNMT(t, codec, shareSize, namespaceSize, sharesValue...) + assert.NotNil(t, corruptEds) + // corrupt it by setting the values at the given coordinates + for i, coords := range test.coords { + x := coords[0] + y := coords[1] + corruptEds.setCell(x, y, test.values[i]) + } + + err = corruptEds.Repair(dAHeaderRoots, dAHeaderCols) + assert.Equal(t, err != nil, test.wantErr) + if test.wantErr { + var byzErr *ErrByzantineData + assert.ErrorAs(t, err, &byzErr) + errors.As(err, &byzErr) + assert.Equal(t, byzErr.Axis, test.corruptedAxis) + assert.Equal(t, byzErr.Index, test.corruptedIndex) } }) }