Skip to content

Commit

Permalink
Fix split OOB and zeroing (#238)
Browse files Browse the repository at this point in the history
* Fix split OOB and zeroing

Fix invalid slicing in #237

Zero values in shards taken from capacity of data shards.

* Add more tests
* Fix swapped params for Go fallback.
* Tweak default tests
* Add conservative retraction
  • Loading branch information
klauspost authored Feb 6, 2023
1 parent 7e59db9 commit 9dac1a1
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 73 deletions.
2 changes: 1 addition & 1 deletion galois_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ func mulAdd8(x, y []byte, log_m ffe8, o *options) {
y = y[done:]
x = x[done:]
}
refMulAdd8(y, x, log_m)
refMulAdd8(x, y, log_m)
}

// 2-way butterfly
Expand Down
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,8 @@ require github.com/klauspost/cpuid/v2 v2.1.1

require golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e // indirect

// https://github.com/klauspost/reedsolomon/pull/229
retract v1.11.2

retract (
v1.11.2 // https://github.com/klauspost/reedsolomon/pull/229
[v1.11.3, v1.11.5] // https://github.com/klauspost/reedsolomon/pull/238
)
26 changes: 19 additions & 7 deletions leopard.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,23 +279,35 @@ func (r *leopardFF16) Split(data []byte) ([][]byte, error) {
// Calculate number of bytes per data shard.
perShard := (len(data) + r.dataShards - 1) / r.dataShards
perShard = ((perShard + 63) / 64) * 64
needTotal := r.totalShards * perShard

if cap(data) > len(data) {
data = data[:cap(data)]
if cap(data) > needTotal {
data = data[:needTotal]
} else {
data = data[:cap(data)]
}
clear := data[dataLen:]
for i := range clear {
clear[i] = 0
}
}

// Only allocate memory if necessary
var padding [][]byte
if len(data) < (r.totalShards * perShard) {
if len(data) < needTotal {
// calculate maximum number of full shards in `data` slice
fullShards := len(data) / perShard
padding = AllocAligned(r.totalShards-fullShards, perShard)
copyFrom := data[perShard*fullShards : dataLen]
for i := range padding {
if len(copyFrom) <= 0 {
break
if dataLen > perShard*fullShards {
// Copy partial shards
copyFrom := data[perShard*fullShards : dataLen]
for i := range padding {
if len(copyFrom) <= 0 {
break
}
copyFrom = copyFrom[copy(padding[i], copyFrom):]
}
copyFrom = copyFrom[copy(padding[i], copyFrom):]
}
} else {
zero := data[dataLen : r.totalShards*perShard]
Expand Down
33 changes: 20 additions & 13 deletions leopard8.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,28 +320,35 @@ func (r *leopardFF8) Split(data []byte) ([][]byte, error) {
// Calculate number of bytes per data shard.
perShard := (len(data) + r.dataShards - 1) / r.dataShards
perShard = ((perShard + 63) / 64) * 64
needTotal := r.totalShards * perShard

if cap(data) > len(data) {
data = data[:cap(data)]
if cap(data) > needTotal {
data = data[:needTotal]
} else {
data = data[:cap(data)]
}
clear := data[dataLen:]
for i := range clear {
clear[i] = 0
}
}

// Only allocate memory if necessary
var padding [][]byte
if len(data) < (r.totalShards * perShard) {
if len(data) < needTotal {
// calculate maximum number of full shards in `data` slice
fullShards := len(data) / perShard
padding = AllocAligned(r.totalShards-fullShards, perShard)
copyFrom := data[perShard*fullShards : dataLen]
for i := range padding {
if len(copyFrom) <= 0 {
break
if dataLen > perShard*fullShards {
// Copy partial shards
copyFrom := data[perShard*fullShards : dataLen]
for i := range padding {
if len(copyFrom) <= 0 {
break
}
copyFrom = copyFrom[copy(padding[i], copyFrom):]
}
copyFrom = copyFrom[copy(padding[i], copyFrom):]
}
} else {
zero := data[dataLen : r.totalShards*perShard]
for i := range zero {
zero[i] = 0
}
}

Expand Down Expand Up @@ -877,7 +884,7 @@ func refMulAdd8(x, y []byte, log_m ffe8) {
for len(x) >= 64 {
// Assert sizes for no bounds checks in loop
src := y[:64]
dst := x[:64] // Needed, but not checked...
dst := x[:len(src)] // Needed, but not checked...
for i, y1 := range src {
dst[i] ^= byte(lut.Value[y1])
}
Expand Down
45 changes: 30 additions & 15 deletions reedsolomon.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,16 @@ type Encoder interface {
Update(shards [][]byte, newDatashards [][]byte) error

// Split a data slice into the number of shards given to the encoder,
// and create empty parity shards.
// and create empty parity shards if necessary.
//
// The data will be split into equally sized shards.
// If the data size isn't dividable by the number of shards,
// If the data size isn't divisible by the number of shards,
// the last shard will contain extra zeros.
//
// If there is extra capacity on the provided data slice
// it will be used instead of allocating parity shards.
// It will be zeroed out.
//
// There must be at least 1 byte otherwise ErrShortData will be
// returned.
//
Expand Down Expand Up @@ -1542,6 +1546,10 @@ var ErrShortData = errors.New("not enough data to fill the number of requested s
// If the data size isn't divisible by the number of shards,
// the last shard will contain extra zeros.
//
// If there is extra capacity on the provided data slice
// it will be used instead of allocating parity shards.
// It will be zeroed out.
//
// There must be at least 1 byte otherwise ErrShortData will be
// returned.
//
Expand All @@ -1558,29 +1566,36 @@ func (r *reedSolomon) Split(data []byte) ([][]byte, error) {
dataLen := len(data)
// Calculate number of bytes per data shard.
perShard := (len(data) + r.dataShards - 1) / r.dataShards
needTotal := r.totalShards * perShard

if cap(data) > len(data) {
data = data[:cap(data)]
if cap(data) > needTotal {
data = data[:needTotal]
} else {
data = data[:cap(data)]
}
clear := data[dataLen:]
for i := range clear {
clear[i] = 0
}
}

// Only allocate memory if necessary
var padding [][]byte
if len(data) < (r.totalShards * perShard) {
if len(data) < needTotal {
// calculate maximum number of full shards in `data` slice
fullShards := len(data) / perShard
padding = AllocAligned(r.totalShards-fullShards, perShard)
copyFrom := data[perShard*fullShards : dataLen]
for i := range padding {
if len(copyFrom) <= 0 {
break

if dataLen > perShard*fullShards {
// Copy partial shards
copyFrom := data[perShard*fullShards : dataLen]
for i := range padding {
if len(copyFrom) <= 0 {
break
}
copyFrom = copyFrom[copy(padding[i], copyFrom):]
}
copyFrom = copyFrom[copy(padding[i], copyFrom):]
}
data = data[0 : perShard*fullShards]
} else {
zero := data[dataLen : r.totalShards*perShard]
for i := range zero {
zero[i] = 0
}
}

Expand Down
109 changes: 74 additions & 35 deletions reedsolomon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func TestBuildMatrixPAR1Singular(t *testing.T) {
func testOpts() [][]Option {
if testing.Short() {
return [][]Option{
{WithPAR1Matrix()}, {WithCauchyMatrix()},
{WithCauchyMatrix()}, {WithLeopardGF16(true)}, {WithLeopardGF(true)},
}
}
opts := [][]Option{
Expand Down Expand Up @@ -1603,7 +1603,7 @@ func testEncoderReconstruct(t *testing.T, o ...Option) {
fillRandom(data)

// Create 5 data slices of 50000 elements each
enc, err := New(5, 3, testOptions(o...)...)
enc, err := New(7, 6, testOptions(o...)...)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1675,43 +1675,82 @@ func testEncoderReconstruct(t *testing.T, o ...Option) {
}

func TestSplitJoin(t *testing.T) {
var data = make([]byte, 250000)
fillRandom(data)

enc, _ := New(5, 3, testOptions()...)
shards, err := enc.Split(data)
if err != nil {
t.Fatal(err)
}

_, err = enc.Split([]byte{})
if err != ErrShortData {
t.Errorf("expected %v, got %v", ErrShortData, err)
}
opts := [][]Option{
testOptions(),
append(testOptions(), WithLeopardGF(true)),
append(testOptions(), WithLeopardGF16(true)),
}
for i, opts := range opts {
t.Run("opt-"+strconv.Itoa(i), func(t *testing.T) {
for _, dp := range [][2]int{{1, 0}, {5, 0}, {5, 1}, {12, 4}, {2, 15}, {17, 1}} {
enc, _ := New(dp[0], dp[1], opts...)
ext := enc.(Extensions)

_, err := enc.Split([]byte{})
if err != ErrShortData {
t.Errorf("expected %v, got %v", ErrShortData, err)
}

buf := new(bytes.Buffer)
err = enc.Join(buf, shards, 50)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf.Bytes(), data[:50]) {
t.Fatal("recovered data does match original")
}
buf := new(bytes.Buffer)
err = enc.Join(buf, [][]byte{}, 0)
if err != ErrTooFewShards {
t.Errorf("expected %v, got %v", ErrTooFewShards, err)
}
for _, size := range []int{ext.DataShards(), 1337, 2699} {
for _, extra := range []int{0, 1, ext.ShardSizeMultiple(), ext.ShardSizeMultiple() * ext.DataShards(), ext.ShardSizeMultiple()*ext.ParityShards() + 1, 255} {
buf.Reset()
t.Run(fmt.Sprintf("d-%d-p-%d-sz-%d-cap%d", ext.DataShards(), ext.ParityShards(), size, extra), func(t *testing.T) {
var data = make([]byte, size, size+extra)
var ref = make([]byte, size, size)
fillRandom(data)
copy(ref, data)

shards, err := enc.Split(data)
if err != nil {
t.Fatal(err)
}
err = enc.Encode(shards)
if err != nil {
t.Fatal(err)
}
_, err = enc.Verify(shards)
if err != nil {
t.Fatal(err)
}
for i := range shards[:ext.ParityShards()] {
// delete data shards up to parity
shards[i] = nil
}
err = enc.Reconstruct(shards)
if err != nil {
t.Fatal(err)
}

err = enc.Join(buf, [][]byte{}, 0)
if err != ErrTooFewShards {
t.Errorf("expected %v, got %v", ErrTooFewShards, err)
}
// Rejoin....
err = enc.Join(buf, shards, size)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf.Bytes(), ref) {
t.Log("")
t.Fatal("recovered data does match original")
}

err = enc.Join(buf, shards, len(data)+1)
if err != ErrShortData {
t.Errorf("expected %v, got %v", ErrShortData, err)
}
err = enc.Join(buf, shards, len(data)+ext.DataShards()*ext.ShardSizeMultiple())
if err != ErrShortData {
t.Errorf("expected %v, got %v", ErrShortData, err)
}

shards[0] = nil
err = enc.Join(buf, shards, len(data))
if err != ErrReconstructRequired {
t.Errorf("expected %v, got %v", ErrReconstructRequired, err)
shards[0] = nil
err = enc.Join(buf, shards, len(data))
if err != ErrReconstructRequired {
t.Errorf("expected %v, got %v", ErrReconstructRequired, err)
}
})
}
}
}
})
}
}

Expand Down

0 comments on commit 9dac1a1

Please sign in to comment.