From f0fd031b07f841892c0edd454687a83aea5139b2 Mon Sep 17 00:00:00 2001
From: Paul Holzinger
Date: Thu, 23 Jan 2025 18:49:30 +0100
Subject: [PATCH] pkg/detach: fix broken Copy() detach sequence

The code could only handle the detach sequence when it read one byte at
a time. This obviously is not correct and led to some issues with the
automated tests in my podman PR[1], where I added tests for detaching:
how much is read at once is undefined and depends on the input side and
kernel scheduling.

This is a large rework that makes the code check for the key sequence
across the entire buffer. That is of course more work, but it is needed
for this to work correctly. I guess the only reason this was never
noticed is that users normally detach by typing manually, which is much
slower and thus likely delivers the bytes one at a time.

I added a new test to confirm the behavior, and to ensure it works with
various read sizes I made it a fuzz test. I had the fuzzer running for a
while and did not spot any issues. The old code already fails on the
simple test cases.

[1] https://github.com/containers/podman/pull/25083

Signed-off-by: Paul Holzinger
---
 pkg/detach/copy.go      | 100 ++++++++++++++++++++----------
 pkg/detach/copy_test.go | 132 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 201 insertions(+), 31 deletions(-)
 create mode 100644 pkg/detach/copy_test.go
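To make the failure mode concrete, here is a minimal standalone sketch of
the case the old code missed: the detach sequence arriving split across two
reads. It assumes the github.com/containers/common/pkg/detach import path;
the twoByteReader helper is made up for the demo and is not part of this
patch.

	package main

	import (
		"bytes"
		"errors"
		"fmt"
		"io"

		"github.com/containers/common/pkg/detach"
	)

	// twoByteReader returns at most two bytes per Read so the detach
	// sequence "AB" in the input below is split across a read boundary:
	// "da", "ta", "xA", "B".
	type twoByteReader struct{ inner io.Reader }

	func (r *twoByteReader) Read(p []byte) (int, error) {
		return r.inner.Read(p[:min(2, len(p))])
	}

	func main() {
		src := &twoByteReader{inner: bytes.NewBufferString("dataxAB")}
		dst := &bytes.Buffer{}
		written, err := detach.Copy(dst, src, []byte("AB"))
		// With this patch: written == 5, dst holds "datax", and the detach
		// is reported via ErrDetach instead of "AB" being copied through.
		fmt.Println(written, dst.String(), errors.Is(err, detach.ErrDetach))
	}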
diff --git a/pkg/detach/copy.go b/pkg/detach/copy.go
index 97a250e44..a0ca101c5 100644
--- a/pkg/detach/copy.go
+++ b/pkg/detach/copy.go
@@ -1,6 +1,7 @@
 package detach
 
 import (
+	"bytes"
 	"errors"
 	"io"
 )
@@ -11,47 +12,84 @@
 
 // Copy is similar to io.Copy but support a detach key sequence to break out.
 func Copy(dst io.Writer, src io.Reader, keys []byte) (written int64, err error) {
+	// if there is no key sequence we can use the fast stdlib implementation
+	if len(keys) == 0 {
+		return io.Copy(dst, src)
+	}
 	buf := make([]byte, 32*1024)
+
+	// When key 1 is in one read and key 2 in the second read we cannot do a normal full match on the buffer.
+	// Thus we use this index to store how far into keys we matched at the end of the previous buffer.
+	keySequenceIndex := 0
+outer:
 	for {
 		nr, er := src.Read(buf)
-		if nr > 0 {
-			preservBuf := []byte{}
-			for i, key := range keys {
-				preservBuf = append(preservBuf, buf[0:nr]...)
-				if nr != 1 || buf[0] != key {
-					break
-				}
-				if i == len(keys)-1 {
-					return 0, ErrDetach
+		// Do not check the error right away, i.e. on EOF this code still must flush the last partial key sequence first.
+		// If the previous buffer ended with the start of the key sequence
+		// then we must continue matching here.
+		if keySequenceIndex > 0 {
+			bytesToCheck := min(nr, len(keys)-keySequenceIndex)
+			if bytes.Equal(buf[:bytesToCheck], keys[keySequenceIndex:keySequenceIndex+bytesToCheck]) {
+				if keySequenceIndex+bytesToCheck == len(keys) {
+					// we are done
+					return written, ErrDetach
 				}
-				nr, er = src.Read(buf)
-			}
-			var nw int
-			var ew error
-			if len(preservBuf) > 0 {
-				nw, ew = dst.Write(preservBuf)
-				nr = len(preservBuf)
-			} else {
-				nw, ew = dst.Write(buf[0:nr])
-			}
-			if nw > 0 {
-				written += int64(nw)
+				// still not at the end of the sequence, must continue to read
+				keySequenceIndex += bytesToCheck
+				continue outer
 			}
+			// No match, write the buffered keys now
+			nw, ew := dst.Write(keys[:keySequenceIndex])
 			if ew != nil {
-				err = ew
-				break
-			}
-			if nr != nw {
-				err = io.ErrShortWrite
-				break
+				return written, ew
 			}
+			written += int64(nw)
+			keySequenceIndex = 0
 		}
+
+		// Now we can handle and return the error.
 		if er != nil {
-			if er != io.EOF {
-				err = er
+			if er == io.EOF {
+				return written, nil
 			}
-			break
+			return written, er
 		}
+
+		// Check the buffer from 0 to (end - sequence length); after that there cannot be a full match.
+		// Walk the buffer and try to perform a full sequence match.
+		readMinusKeys := nr - len(keys)
+		for i := range readMinusKeys {
+			if bytes.Equal(buf[i:i+len(keys)], keys) {
+				if i > 0 {
+					nw, ew := dst.Write(buf[:i])
+					if ew != nil {
+						return written, ew
+					}
+					written += int64(nw)
+				}
+				return written, ErrDetach
+			}
+		}
+
+		// Now check the rest of the buffer up to the end for a partial match of the sequence.
+		// Note that readMinusKeys can be < 0 on reads smaller than the sequence length. Thus we must
+		// ensure it is at least 0, otherwise the index access would panic.
+		for i := max(readMinusKeys, 0); i < nr; i++ {
+			if bytes.Equal(buf[i:nr], keys[:nr-i]) {
+				nw, ew := dst.Write(buf[:i])
+				if ew != nil {
+					return written, ew
+				}
+				written += int64(nw)
+				keySequenceIndex = nr - i
+				continue outer
+			}
+		}
+
+		nw, ew := dst.Write(buf[:nr])
+		if ew != nil {
+			return written, ew
+		}
+		written += int64(nw)
 	}
-	return written, err
 }
diff --git a/pkg/detach/copy_test.go b/pkg/detach/copy_test.go
new file mode 100644
index 000000000..fbaf7f43d
--- /dev/null
+++ b/pkg/detach/copy_test.go
@@ -0,0 +1,132 @@
+package detach
+
+import (
+	"bytes"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+var (
+	smallBytes = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
+	bigBytes   = []byte(strings.Repeat("0F", 32*1024+30))
+)
+
+func newCustomReader(buf *bytes.Buffer, readsize uint) *customReader {
+	return &customReader{
+		inner:    buf,
+		readsize: readsize,
+	}
+}
+
+type customReader struct {
+	inner    *bytes.Buffer
+	readsize uint
+}
+
+func (c *customReader) Read(p []byte) (n int, err error) {
+	return c.inner.Read(p[:min(int(c.readsize), len(p))])
+}
+
+func FuzzCopy(f *testing.F) {
+	for _, i := range []uint{1, 2, 3, 5, 10, 100, 200, 1000, 1024, 32 * 1024} {
+		f.Add(i)
+	}
+
+	f.Fuzz(func(t *testing.T, readSize uint) {
+		// 0 is not a valid read size
+		if readSize == 0 {
+			t.Skip()
+		}
+
+		tests := []struct {
+			name         string
+			from         []byte
+			expected     []byte
+			expectDetach bool
+			keys         []byte
+		}{
+			{
+				name:     "small copy",
+				from:     smallBytes,
+				expected: smallBytes,
+				keys:     nil,
+			},
+			{
+				name:     "small copy with detach keys",
+				from:     smallBytes,
+				expected: smallBytes,
+				keys:     []byte{'A', 'B'},
+			},
+			{
+				name:     "big copy",
+				from:     bigBytes,
+				expected: bigBytes,
+				keys:     nil,
+			},
+			{
+				name:     "big copy with detach keys",
+				from:     bigBytes,
+				expected: bigBytes,
+				keys:     []byte{'A', 'B'},
+			},
+			{
+				name:         "simple detach 1 key",
+				from:         append(smallBytes, 'A'),
+				expected:     smallBytes,
+				expectDetach: true,
+				keys:         []byte{'A'},
+			},
+			{
+				name:         "simple detach 2 keys",
+				from:         append(smallBytes, 'A', 'B'),
+				expected:     smallBytes,
+				expectDetach: true,
+				keys:         []byte{'A', 'B'},
+			},
+			{
+				name:         "simple detach 3 keys",
+				from:         append(smallBytes, 'A', 'B', 'C'),
+				expected:     smallBytes,
+				expectDetach: true,
+				keys:         []byte{'A', 'B', 'C'},
+			},
+			{
+				name:         "detach early",
+				from:         append(smallBytes, 'A', 'B', 'B', 'A'),
+				expected:     smallBytes,
+				expectDetach: true,
+				keys:         []byte{'A', 'B'},
+			},
+			{
+				name:         "detach with partial match",
+				from:         append(smallBytes, 'A', 'A', 'A', 'B'),
+				expected:     append(smallBytes, 'A', 'A'),
+				expectDetach: true,
+				keys:         []byte{'A', 'B'},
+			},
+			{
+				name:         "big buffer detach with partial match",
+				from:         append(bigBytes, 'A', 'A', 'A', 'B'),
+				expected:     append(bigBytes, 'A', 'A'),
+				expectDetach: true,
+				keys:         []byte{'A', 'B'},
+			},
+		}
+		for _, tt := range tests {
+			t.Run(tt.name, func(t *testing.T) {
+				dst := &bytes.Buffer{}
+				src := newCustomReader(bytes.NewBuffer(tt.from), readSize)
+				written, err := Copy(dst, src, tt.keys)
+				if tt.expectDetach {
+					assert.ErrorIs(t, err, ErrDetach)
+				} else {
+					assert.NoError(t, err)
+				}
+				assert.Equal(t, dst.Len(), int(written), "bytes written matches buffer")
+				assert.Equal(t, tt.expected, dst.Bytes(), "buffer matches")
+			})
+		}
+	})
+}
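The fuzz test can be exercised with the standard Go tooling; for example,
from the repository root:

	go test ./pkg/detach/ -run FuzzCopy    # run the seed corpus only
	go test ./pkg/detach/ -fuzz FuzzCopy   # keep fuzzing for new inputs

The seed read sizes cover the one-byte case the old code relied on as well
as reads far larger than the detach sequence.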