From e192bd248f78e9b48e8b1846e1fbe6094caf0423 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 14 Feb 2024 11:09:00 -0800 Subject: [PATCH 01/12] Include fileSize to the call to Consume() Some future consumers will need to know the expected fileSize depending on implementation (e.g. unzip). This wires up basic support for adding the fileSize as an argument to Consume; the value is already available at the time Consume is called. --- pkg/consumer/consumer.go | 2 +- pkg/consumer/null.go | 2 +- pkg/consumer/tar_extractor.go | 2 +- pkg/consumer/write_file.go | 2 +- pkg/pget.go | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/consumer/consumer.go b/pkg/consumer/consumer.go index 39bedef..82675ee 100644 --- a/pkg/consumer/consumer.go +++ b/pkg/consumer/consumer.go @@ -3,5 +3,5 @@ package consumer import "io" type Consumer interface { - Consume(reader io.Reader, destPath string) error + Consume(reader io.Reader, destPath string, fileSize int64) error } diff --git a/pkg/consumer/null.go b/pkg/consumer/null.go index 6bb3836..9407d6c 100644 --- a/pkg/consumer/null.go +++ b/pkg/consumer/null.go @@ -8,7 +8,7 @@ type NullWriter struct{} var _ Consumer = &NullWriter{} -func (f *NullWriter) Consume(reader io.Reader, destPath string) error { +func (f *NullWriter) Consume(reader io.Reader, _ string, _ int64) error { // io.Discard is explicitly designed to always succeed, ignore errors. 
_, _ = io.Copy(io.Discard, reader) return nil diff --git a/pkg/consumer/tar_extractor.go b/pkg/consumer/tar_extractor.go index 012a1e6..220273d 100644 --- a/pkg/consumer/tar_extractor.go +++ b/pkg/consumer/tar_extractor.go @@ -11,7 +11,7 @@ type TarExtractor struct{} var _ Consumer = &TarExtractor{} -func (f *TarExtractor) Consume(reader io.Reader, destPath string) error { +func (f *TarExtractor) Consume(reader io.Reader, destPath string, _ int64) error { err := extract.TarFile(reader, destPath) if err != nil { return fmt.Errorf("error extracting file: %w", err) diff --git a/pkg/consumer/write_file.go b/pkg/consumer/write_file.go index 008b21f..547375b 100644 --- a/pkg/consumer/write_file.go +++ b/pkg/consumer/write_file.go @@ -10,7 +10,7 @@ type FileWriter struct{} var _ Consumer = &FileWriter{} -func (f *FileWriter) Consume(reader io.Reader, destPath string) error { +func (f *FileWriter) Consume(reader io.Reader, destPath string, _ int64) error { // NOTE(morgan): We check if the file exists early on allowing a fast fail, it is safe // to just apply os.O_TRUNC. Getting to this point without checking existence and // the `--force` flag is a programming error further up the stack. diff --git a/pkg/pget.go b/pkg/pget.go index 14fc1df..e4a8955 100644 --- a/pkg/pget.go +++ b/pkg/pget.go @@ -47,7 +47,7 @@ func (g *Getter) DownloadFile(ctx context.Context, url string, dest string) (int // downloadElapsed := time.Since(downloadStartTime) // writeStartTime := time.Now() - err = g.Consumer.Consume(buffer, dest) + err = g.Consumer.Consume(buffer, dest, fileSize) if err != nil { return fileSize, 0, fmt.Errorf("error writing file: %w", err) } From a6fdb2221040898f741eafbdfded9dc20359e419 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 14 Feb 2024 15:12:42 -0800 Subject: [PATCH 02/12] Add multiReader multiReader is a reader that implements the ReadAt functionality needed for some future consumers (e.g. unzip). 
The multiReader at a basic level consumes a multiChanReader via the NewMultiReader() function and returns an io.ReaderAt implementation. bufferedReader now has a .len() calculation that will report the content-length once that header is received. Since we do not know the actual content length until the download starts, there is a new signal channel to indicate the download has started and allows us to read the size of the bufferedReader. This means that there is the real likelihood that reading from multiReader may block more often than chanMultiReader. MultiReader may be able to implement Seek() and other related functions for reading the data out of strict order. --- pkg/download/buffered_reader.go | 25 ++++- pkg/download/multi_reader.go | 68 ++++++++++++ pkg/download/multi_reader_test.go | 169 ++++++++++++++++++++++++++++++ 3 files changed, 257 insertions(+), 5 deletions(-) create mode 100644 pkg/download/multi_reader.go create mode 100644 pkg/download/multi_reader_test.go diff --git a/pkg/download/buffered_reader.go b/pkg/download/buffered_reader.go index 973eeea..bb8f532 100644 --- a/pkg/download/buffered_reader.go +++ b/pkg/download/buffered_reader.go @@ -12,17 +12,21 @@ import ( // It implements io.Reader.
type bufferedReader struct { // ready channel is closed when we're ready to read - ready chan struct{} - buf *bytes.Buffer - err error + ready chan struct{} + started chan struct{} + buf *bytes.Buffer + err error + size int64 } var _ io.Reader = &bufferedReader{} func newBufferedReader(capacity int64) *bufferedReader { return &bufferedReader{ - ready: make(chan struct{}), - buf: bytes.NewBuffer(make([]byte, 0, capacity)), + ready: make(chan struct{}), + started: make(chan struct{}), + buf: bytes.NewBuffer(make([]byte, 0, capacity)), + size: -1, } } @@ -40,8 +44,14 @@ func (b *bufferedReader) done() { close(b.ready) } +func (b *bufferedReader) contentLengthReceived() { + close(b.started) +} + func (b *bufferedReader) downloadBody(resp *http.Response) error { expectedBytes := resp.ContentLength + b.size = expectedBytes + b.contentLengthReceived() n, err := b.buf.ReadFrom(resp.Body) if err != nil && err != io.EOF { b.err = fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err) @@ -53,3 +63,8 @@ func (b *bufferedReader) downloadBody(resp *http.Response) error { } return nil } + +func (b *bufferedReader) len() int64 { + <-b.started + return b.size +} diff --git a/pkg/download/multi_reader.go b/pkg/download/multi_reader.go new file mode 100644 index 0000000..eb6a092 --- /dev/null +++ b/pkg/download/multi_reader.go @@ -0,0 +1,68 @@ +package download + +import ( + "errors" + "io" +) + +var ( + ErrInvalidOffset = errors.New("download.multiReader: Negative offset") +) + +var _ io.ReaderAt = &multiReader{} + +type multiReader struct { + readers []*bufferedReader +} + +func NewMultiReader(reader io.Reader) (io.ReaderAt, error) { + chanMultiReader, ok := reader.(*chanMultiReader) + if !ok { + // future may support converting a standard reader into a multi reader with a single reader + // for now, we only support chanMultiReader + return nil, errors.New("reader is not a chanMultiReader") + } + multiReader := &multiReader{ + readers: 
make([]*bufferedReader, 0), + } + for { + reader, ok := <-chanMultiReader.ch + if !ok { + break + } + bufferedReader, ok := reader.(*bufferedReader) + if !ok { + // future may support converting a standard reader into a bufferedReader, + // for now we only support bufferedReader + return nil, errors.New("reader is not a bufferedReader") + } + multiReader.readers = append(multiReader.readers, bufferedReader) + } + return multiReader, nil +} + +func (m *multiReader) ReadAt(p []byte, off int64) (n int, err error) { + var readerBytes int64 + if off < 0 { + return 0, ErrInvalidOffset + } + for i, r := range m.readers { + readerBytes += int64(r.len()) + // if offset is less than the bytes found in the reader slice to this point, + // we can start reading from this reader. + if off < readerBytes { + // Calculate the offset within the reader + innerOffset := off - (readerBytes - int64(r.len())) + n = copy(p, r.buf.Bytes()[innerOffset:]) + if i == len(m.readers)-1 && n < len(p) { + // We are at the last reader and the buffer is not full + // We need to return io.EOF + return n, io.EOF + } + return n, nil + } + } + // If we are here, we have run through all the possible readers and the offset puts us past the end of the last + // reader, meaning we should return 0 and io.EOF to indicate there is nothing to read. 
+ return 0, io.EOF +} diff --git a/pkg/download/multi_reader_test.go b/pkg/download/multi_reader_test.go new file mode 100644 index 0000000..6aa1a4c --- /dev/null +++ b/pkg/download/multi_reader_test.go @@ -0,0 +1,169 @@ +package download + +import ( + "bytes" + "io" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewMultiReader(t *testing.T) { + tests := []struct { + name string + input io.Reader + wantErr bool + errorText string + }{ + { + name: "ErrorWhenReaderIsNotAChanMultiReader", + input: bytes.NewBuffer([]byte("not a chanMultiReader")), + wantErr: true, + errorText: "reader is not a chanMultiReader", + }, + { + name: "ErrorWhenChanMultiReaderContainsNonBufferedReader", + input: func() io.Reader { + ch := make(chan io.Reader, 1) + ch <- bytes.NewBuffer([]byte("not a bufferedReader")) + // explicitly close the channel so that the multiReader can know it's complete + close(ch) + return &chanMultiReader{ch: ch} + }(), + wantErr: true, + errorText: "reader is not a bufferedReader", + }, + { + name: "SuccessfullyCreateMultiReader", + input: func() io.Reader { + ch := make(chan io.Reader, 1) + ch <- &bufferedReader{buf: bytes.NewBuffer([]byte("data"))} + // explicitly close the channel so that the multiReader can know it's complete + close(ch) + return &chanMultiReader{ch: ch} + }(), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewMultiReader(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("NewMultiReader() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil && err.Error() != tt.errorText { + t.Errorf("NewMultiReader() error = %v, wantErr %v", err, tt.errorText) + } + }) + } +} + +func TestMultiReader_ReadAt(t *testing.T) { + // Create buffered channel for the multiChanReader so the channel can be closed for the testing case + count := 10 + expected := "" + ch := make(chan io.Reader, count) + 
for i := 0; i < count; i++ { + str := strings.Repeat(strconv.Itoa(i), 100) + expected = expected + str + br := &bufferedReader{ + buf: bytes.NewBuffer([]byte(str)), + size: int64(len(str)), + ready: make(chan struct{}), + started: make(chan struct{}), + } + br.done() + br.contentLengthReceived() + ch <- br + } + + // explicitly close the channel so that the multiReader can know it's complete + close(ch) + multiChanReader := &chanMultiReader{ch: ch} + multiReader, err := NewMultiReader(multiChanReader) + require.NoError(t, err) + + tests := []struct { + name string + offset int64 + buffer []byte + expectedN int + expectedErr error + expectedData []byte + }{ + { + name: "Read Within First Reader", + offset: 0, + buffer: make([]byte, 50), + expectedN: 50, + expectedErr: nil, + expectedData: []byte(expected[:50]), + }, + { + name: "Read Within Last Reader", + offset: int64(len(expected) - 75), + buffer: make([]byte, 50), + expectedN: 50, + expectedErr: nil, + expectedData: []byte(expected[len(expected)-75 : len(expected)-25]), + }, + { + name: "Read Across Multiple Readers", + offset: 50, + buffer: make([]byte, 100), + expectedN: 50, + expectedErr: nil, + expectedData: []byte(expected[50:100]), + }, + { + name: "Read Past End Of Last Reader", + // offset is greater than the total size of the readers + offset: int64(len(expected) - 1), + buffer: make([]byte, 100), + expectedN: 1, + expectedErr: io.EOF, + expectedData: []byte(expected[len(expected)-1:]), + }, + { + name: "Read At Negative Offset", + offset: -1, + expectedErr: ErrInvalidOffset, + }, + { + name: "Read At Offset Greater Than Total Size", + offset: int64(len(expected)) + 1, + buffer: make([]byte, 100), + expectedN: 0, + expectedErr: io.EOF, + expectedData: []byte{}, + }, + { + name: "Read At Offset Equal To Total Size", + offset: int64(len(expected)), + buffer: make([]byte, 100), + expectedN: 0, + expectedErr: io.EOF, + expectedData: []byte{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + n, err := multiReader.ReadAt(tt.buffer, tt.offset) + assert.Equal(t, tt.expectedN, n) + if tt.expectedErr != nil { + assert.ErrorIs(t, err, tt.expectedErr) + } + assert.Equal(t, tt.expectedData, tt.buffer[:tt.expectedN]) + if len(tt.buffer) > tt.expectedN { + emptyData := bytes.Repeat([]byte{0}, len(tt.buffer)-tt.expectedN) + assert.Equal(t, tt.buffer[tt.expectedN:], emptyData) + } + }) + } + +} From 8ebeb05ab5d900afe524fa1036efedbaf2a3bdb4 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 14 Feb 2024 21:29:03 -0800 Subject: [PATCH 03/12] Implement ZipExtractor consumer Implement ZipExtractor consumer --- cmd/multifile/multifile.go | 3 ++ pkg/config/config.go | 3 ++ pkg/consumer/zip_extractor.go | 25 ++++++++++ pkg/download/multi_reader.go | 2 +- pkg/extract/zip.go | 89 +++++++++++++++++++++++++++++++++++ 5 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 pkg/consumer/zip_extractor.go create mode 100644 pkg/extract/zip.go diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go index 84bf045..b43d148 100644 --- a/cmd/multifile/multifile.go +++ b/cmd/multifile/multifile.go @@ -69,6 +69,9 @@ func multifilePreRunE(cmd *cobra.Command, args []string) error { if viper.GetString(config.OptOutputConsumer) == config.ConsumerTarExtractor { return fmt.Errorf("cannot use --output-consumer tar-extractor with multifile mode") } + if viper.GetString(config.OptOutputConsumer) == config.ConsumerZipExtractor { + return fmt.Errorf("cannot use --output-consumer zip-extractor with multifile mode") + } return nil } diff --git a/pkg/config/config.go b/pkg/config/config.go index 8cef10a..1cf59ea 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -20,6 +20,7 @@ const ( ConsumerFile = "file" ConsumerTarExtractor = "tar-extractor" ConsumerNull = "null" + ConsumerZipExtractor = "zip-extractor" ) var ( @@ -160,6 +161,8 @@ func GetConsumer() (consumer.Consumer, error) { return &consumer.FileWriter{}, nil case 
ConsumerTarExtractor: return &consumer.TarExtractor{}, nil + case ConsumerZipExtractor: + return &consumer.ZipExtractor{}, nil case ConsumerNull: return &consumer.NullWriter{}, nil default: diff --git a/pkg/consumer/zip_extractor.go b/pkg/consumer/zip_extractor.go new file mode 100644 index 0000000..bb8b6bd --- /dev/null +++ b/pkg/consumer/zip_extractor.go @@ -0,0 +1,25 @@ +package consumer + +import ( + "fmt" + "io" + + "github.com/replicate/pget/pkg/download" + "github.com/replicate/pget/pkg/extract" +) + +type ZipExtractor struct{} + +var _ Consumer = &ZipExtractor{} + +func (f *ZipExtractor) Consume(reader io.Reader, destPath string, size int64) error { + readerAt, err := download.NewMultiReader(reader) + if err != nil { + return fmt.Errorf("error converting to multi reader: %w", err) + } + err = extract.ZipFile(readerAt, destPath, size) + if err != nil { + return fmt.Errorf("error extracting file: %w", err) + } + return nil +} diff --git a/pkg/download/multi_reader.go b/pkg/download/multi_reader.go index eb6a092..7720386 100644 --- a/pkg/download/multi_reader.go +++ b/pkg/download/multi_reader.go @@ -52,7 +52,7 @@ func (m *multiReader) ReadAt(p []byte, off int64) (n int, err error) { // we can start reading from this reader. if off < readerBytes { // Calculate the offset within the reader - innerOffset := off - (readerBytes - int64(r.len())) + innerOffset := off - (readerBytes - r.len()) n = copy(p, r.buf.Bytes()[innerOffset:]) if i == len(m.readers)-1 && n < len(p) { // We are at the last reader and the buffer is not full diff --git a/pkg/extract/zip.go b/pkg/extract/zip.go new file mode 100644 index 0000000..4edd62b --- /dev/null +++ b/pkg/extract/zip.go @@ -0,0 +1,89 @@ +package extract + +import ( + "archive/zip" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" +) + +// ZipFile extracts a zip file to the given destination path. 
+func ZipFile(reader io.ReaderAt, destPath string, size int64) error { + err := os.MkdirAll(destPath, 0755) + if err != nil { + return fmt.Errorf("error creating destination directory: %w", err) + } + + zipReader, err := zip.NewReader(reader, size) + if err != nil { + return fmt.Errorf("error creating zip reader: %w", err) + } + + for _, file := range zipReader.File { + err := handleFileFromZip(file, destPath) + if err != nil { + return fmt.Errorf("error extracting file: %w", err) + } + } + return nil +} + +func handleFileFromZip(file *zip.File, outputDir string) error { + target := outputDir + file.Name + targetDir := filepath.Dir(target) + if file.FileInfo().IsDir() { + return extractDir(file, targetDir) + } else if file.FileInfo().Mode().IsRegular() { + return extractFile(file, targetDir) + } else { + return fmt.Errorf("unsupported file type (not dir or regular): %s (%d)", file.Name, file.FileInfo().Mode().Type()) + } + +} + +func extractDir(file *zip.File, outputDir string) error { + target := outputDir + file.Name + err := os.MkdirAll(target, file.Mode().Perm()) + if err != nil { + return fmt.Errorf("error creating directory: %w", err) + } + return applyPermissions(target, file.Mode().Perm()) +} + +func extractFile(file *zip.File, outputDir string) error { + target := outputDir + file.Name + targetDir := filepath.Dir(target) + err := os.MkdirAll(targetDir, 0755) + if err != nil { + return fmt.Errorf("error creating directory: %w", err) + } + + // Open the file inside the zip archive + zipFile, err := file.Open() + if err != nil { + return fmt.Errorf("error opening file: %w", err) + } + defer zipFile.Close() + + // Create the file on the filesystem + out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode()) + if err != nil { + return fmt.Errorf("error creating file: %w", err) + } + defer out.Close() + + // Copy the file contents + _, err = io.Copy(out, zipFile) + if err != nil { + return fmt.Errorf("error copying file: %w", err) + } + 
return applyPermissions(target, file.Mode().Perm()) +} + +func applyPermissions(filepath string, fileMode fs.FileMode) error { + // Do not apply setuid/gid/sticky bits. + perms := fileMode &^ os.ModeSetuid &^ os.ModeSetgid &^ os.ModeSticky + return os.Chmod(filepath, perms) +} From 4cd9172a1cafa5feec103e014ee57c77c7e5b3c3 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 14 Feb 2024 22:24:39 -0800 Subject: [PATCH 04/12] Log warning on -x if consumer doesn't match If the consumer is not File or tar-extractor when -x is used, log a warning that the tar-extractor supersedes the specified consumer. --- cmd/root/root.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cmd/root/root.go b/cmd/root/root.go index 3815418..ccd08f9 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -165,6 +165,10 @@ func hideAndDeprecateFlags(cmd *cobra.Command) error { func rootCmdPreRun(cmd *cobra.Command, args []string) { if viper.GetBool(config.OptExtract) { + currentConsumer := viper.GetString(config.OptOutputConsumer) + if currentConsumer != config.ConsumerFile && currentConsumer != config.ConsumerTarExtractor { + log.Warn().Msg("Tar Extract Enabled, overriding output consumer to `tar-extractor`") + } viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor) } } From 46ed2db1ce14d781499254ef22d1a9349cc4083c Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 14 Feb 2024 22:10:48 -0800 Subject: [PATCH 05/12] Handle Overwriting in the Consumer Make the consumer handle overwriting explicitly. This addresses edge cases with tar and zip consumer when extracting files. 
--- cmd/multifile/multifile.go | 4 +++ cmd/root/root.go | 4 +++ pkg/consumer/consumer.go | 2 ++ pkg/consumer/null.go | 7 ++++- pkg/consumer/tar_extractor.go | 10 ++++-- pkg/consumer/write_file.go | 17 ++++++++--- pkg/consumer/zip_extractor.go | 10 ++++-- pkg/extract/tar.go | 36 ++++++++++++++++++---- pkg/extract/tar_test.go | 57 +++++++++++++++++++++++++++++------ pkg/extract/zip.go | 48 ++++++++++++++++++----------- 10 files changed, 153 insertions(+), 42 deletions(-) diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go index b43d148..c541ed8 100644 --- a/cmd/multifile/multifile.go +++ b/cmd/multifile/multifile.go @@ -134,6 +134,10 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error { return fmt.Errorf("error getting consumer: %w", err) } + if viper.GetBool(config.OptForce) { + consumer.EnableOverwrite() + } + getter := &pget.Getter{ Downloader: download.GetBufferMode(downloadOpts), Consumer: consumer, diff --git a/cmd/root/root.go b/cmd/root/root.go index ccd08f9..2d7637d 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -233,6 +233,10 @@ func rootExecute(ctx context.Context, urlString, dest string) error { return err } + if viper.GetBool(config.OptForce) { + consumer.EnableOverwrite() + } + getter := pget.Getter{ Downloader: download.GetBufferMode(downloadOpts), Consumer: consumer, diff --git a/pkg/consumer/consumer.go b/pkg/consumer/consumer.go index 82675ee..c99757c 100644 --- a/pkg/consumer/consumer.go +++ b/pkg/consumer/consumer.go @@ -4,4 +4,6 @@ import "io" type Consumer interface { Consume(reader io.Reader, destPath string, fileSize int64) error + // EnableOverwrite sets the overwrite flag for the consumer, allowing it to overwrite files if necessary/supported + EnableOverwrite() } diff --git a/pkg/consumer/null.go b/pkg/consumer/null.go index 9407d6c..f2b7d97 100644 --- a/pkg/consumer/null.go +++ b/pkg/consumer/null.go @@ -4,7 +4,8 @@ import ( "io" ) -type NullWriter struct{} +type NullWriter struct { +} var 
_ Consumer = &NullWriter{} @@ -13,3 +14,7 @@ func (f *NullWriter) Consume(reader io.Reader, _ string, _ int64) error { _, _ = io.Copy(io.Discard, reader) return nil } + +func (f *NullWriter) EnableOverwrite() { + // no-op +} diff --git a/pkg/consumer/tar_extractor.go b/pkg/consumer/tar_extractor.go index 220273d..8f00ae0 100644 --- a/pkg/consumer/tar_extractor.go +++ b/pkg/consumer/tar_extractor.go @@ -7,14 +7,20 @@ import ( "github.com/replicate/pget/pkg/extract" ) -type TarExtractor struct{} +type TarExtractor struct { + overwrite bool +} var _ Consumer = &TarExtractor{} func (f *TarExtractor) Consume(reader io.Reader, destPath string, _ int64) error { - err := extract.TarFile(reader, destPath) + err := extract.TarFile(reader, destPath, f.overwrite) if err != nil { return fmt.Errorf("error extracting file: %w", err) } return nil } + +func (f *TarExtractor) EnableOverwrite() { + f.overwrite = true +} diff --git a/pkg/consumer/write_file.go b/pkg/consumer/write_file.go index 547375b..afc35c0 100644 --- a/pkg/consumer/write_file.go +++ b/pkg/consumer/write_file.go @@ -6,15 +6,18 @@ import ( "os" ) -type FileWriter struct{} +type FileWriter struct { + overwrite bool +} var _ Consumer = &FileWriter{} func (f *FileWriter) Consume(reader io.Reader, destPath string, _ int64) error { - // NOTE(morgan): We check if the file exists early on allowing a fast fail, it is safe - // to just apply os.O_TRUNC. Getting to this point without checking existence and - // the `--force` flag is a programming error further up the stack. 
- out, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + openFlags := os.O_WRONLY | os.O_CREATE + if f.overwrite { + openFlags |= os.O_TRUNC + } + out, err := os.OpenFile(destPath, openFlags, 0644) if err != nil { return fmt.Errorf("error writing file: %w", err) } @@ -26,3 +29,7 @@ func (f *FileWriter) Consume(reader io.Reader, destPath string, _ int64) error { } return nil } + +func (f *FileWriter) EnableOverwrite() { + f.overwrite = true +} diff --git a/pkg/consumer/zip_extractor.go b/pkg/consumer/zip_extractor.go index bb8b6bd..adbfc85 100644 --- a/pkg/consumer/zip_extractor.go +++ b/pkg/consumer/zip_extractor.go @@ -8,7 +8,9 @@ import ( "github.com/replicate/pget/pkg/extract" ) -type ZipExtractor struct{} +type ZipExtractor struct { + overwrite bool +} var _ Consumer = &ZipExtractor{} @@ -17,9 +19,13 @@ func (f *ZipExtractor) Consume(reader io.Reader, destPath string, size int64) er if err != nil { return fmt.Errorf("error converting to multi reader: %w", err) } - err = extract.ZipFile(readerAt, destPath, size) + err = extract.ZipFile(readerAt, destPath, size, f.overwrite) if err != nil { return fmt.Errorf("error extracting file: %w", err) } return nil } + +func (f *ZipExtractor) EnableOverwrite() { + f.overwrite = true +} diff --git a/pkg/extract/tar.go b/pkg/extract/tar.go index 23b836d..aa74b25 100644 --- a/pkg/extract/tar.go +++ b/pkg/extract/tar.go @@ -17,7 +17,7 @@ type link struct { newName string } -func TarFile(reader io.Reader, destDir string) error { +func TarFile(reader io.Reader, destDir string, overwrite bool) error { var links []*link startTime := time.Now() @@ -49,7 +49,11 @@ func TarFile(reader io.Reader, destDir string) error { return err } case tar.TypeReg: - targetFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY, os.FileMode(header.Mode)) + openFlags := os.O_CREATE | os.O_WRONLY + if overwrite { + openFlags |= os.O_TRUNC + } + targetFile, err := os.OpenFile(target, openFlags, os.FileMode(header.Mode)) if err != 
nil { return err } @@ -68,7 +72,7 @@ func TarFile(reader io.Reader, destDir string) error { } } - if err := createLinks(links, destDir); err != nil { + if err := createLinks(links, destDir, overwrite); err != nil { return fmt.Errorf("error creating links: %w", err) } @@ -81,7 +85,7 @@ func TarFile(reader io.Reader, destDir string) error { return nil } -func createLinks(links []*link, destDir string) error { +func createLinks(links []*link, destDir string, overwrite bool) error { for _, link := range links { targetDir := filepath.Dir(link.newName) if err := os.MkdirAll(targetDir, 0755); err != nil { @@ -90,11 +94,11 @@ func createLinks(links []*link, destDir string) error { switch link.linkType { case tar.TypeLink: oldPath := filepath.Join(destDir, link.oldName) - if err := os.Link(oldPath, link.newName); err != nil { + if err := createHardLink(oldPath, link.newName, overwrite); err != nil { return fmt.Errorf("error creating hard link from %s to %s: %w", oldPath, link.newName, err) } case tar.TypeSymlink: - if err := os.Symlink(link.oldName, link.newName); err != nil { + if err := createSymlink(link.oldName, link.newName, overwrite); err != nil { return fmt.Errorf("error creating symlink from %s to %s: %w", link.oldName, link.newName, err) } default: @@ -103,3 +107,23 @@ func createLinks(links []*link, destDir string) error { } return nil } + +func createHardLink(oldName, newName string, overwrite bool) error { + if overwrite { + err := os.Remove(newName) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("error removing existing file: %w", err) + } + } + return os.Link(oldName, newName) +} + +func createSymlink(oldName, newName string, overwrite bool) error { + if overwrite { + err := os.Remove(newName) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("error removing existing symlink/file: %w", err) + } + } + return os.Symlink(oldName, newName) +} diff --git a/pkg/extract/tar_test.go b/pkg/extract/tar_test.go index e21f5af..409c774 100644 --- 
a/pkg/extract/tar_test.go +++ b/pkg/extract/tar_test.go @@ -12,34 +12,70 @@ import ( func TestCreateLinks(t *testing.T) { tests := []struct { - name string - links []*link - expectedError bool + name string + links []*link + expectedError bool + overwrite bool + createOverwritenFile bool }{ { - name: "EmptyLink", + name: "Empty Link", links: []*link{}, }, { - name: "ValidHardLink", + name: "Valid Hard Link", links: []*link{{tar.TypeLink, "", "testLinkHard"}}, }, { - name: "ValidSymlink", + name: "Valid Symlink", links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, }, { - name: "InvalidLinkType", + name: "Invalid LinkType", links: []*link{{'!', "", "x"}}, expectedError: true, }, { - name: "ValidMultipleLinks", + name: "Valid Multiple Links", links: []*link{ {tar.TypeLink, "", "testLinkHard"}, {tar.TypeSymlink, "", "testLinkSym"}, }, }, + { + name: "HardLink_OverwriteEnabled_File Exists", + links: []*link{{tar.TypeLink, "", "testLinkHard"}}, + overwrite: true, + createOverwritenFile: true, + }, + { + name: "HardLink_OverwriteDisabled_FileExists", + links: []*link{{tar.TypeLink, "", "testLinkHard"}}, + createOverwritenFile: true, + expectedError: true, + }, + { + name: "HardLink_OverwriteEnabled_FileDoesNotExist", + links: []*link{{tar.TypeLink, "", "testLinkHard"}}, + overwrite: true, + }, + { + name: "SymLink_OverwriteEnabled_FileExists", + links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, + overwrite: true, + createOverwritenFile: true, + }, + { + name: "SymLink_OverwriteDisabled_FileExists", + links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, + createOverwritenFile: true, + expectedError: true, + }, + { + name: "SymLink_OverwriteEnabled_FileDoesNotExist", + links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, + overwrite: true, + }, } for _, tt := range tests { @@ -56,6 +92,9 @@ func TestCreateLinks(t *testing.T) { for _, link := range tt.links { if link.linkType == tar.TypeLink || link.linkType == tar.TypeSymlink { testFile, err := 
os.CreateTemp(destDir, "test-") + if tt.createOverwritenFile { + _, err = os.Create(filepath.Join(destDir, link.newName)) + } if err != nil { t.Fatalf("Test failed, could not create test file: %v", err) } @@ -65,7 +104,7 @@ func TestCreateLinks(t *testing.T) { } } - err = createLinks(tt.links, destDir) + err = createLinks(tt.links, destDir, tt.overwrite) // Validation if tt.expectedError { diff --git a/pkg/extract/zip.go b/pkg/extract/zip.go index 4edd62b..a991c5a 100644 --- a/pkg/extract/zip.go +++ b/pkg/extract/zip.go @@ -4,13 +4,12 @@ import ( "archive/zip" "fmt" "io" - "io/fs" "os" "path/filepath" ) // ZipFile extracts a zip file to the given destination path. -func ZipFile(reader io.ReaderAt, destPath string, size int64) error { +func ZipFile(reader io.ReaderAt, destPath string, size int64, overwrite bool) error { err := os.MkdirAll(destPath, 0755) if err != nil { return fmt.Errorf("error creating destination directory: %w", err) @@ -22,7 +21,7 @@ func ZipFile(reader io.ReaderAt, destPath string, size int64) error { } for _, file := range zipReader.File { - err := handleFileFromZip(file, destPath) + err := handleFileFromZip(file, destPath, overwrite) if err != nil { return fmt.Errorf("error extracting file: %w", err) } @@ -30,13 +29,13 @@ func ZipFile(reader io.ReaderAt, destPath string, size int64) error { return nil } -func handleFileFromZip(file *zip.File, outputDir string) error { +func handleFileFromZip(file *zip.File, outputDir string, overwrite bool) error { target := outputDir + file.Name targetDir := filepath.Dir(target) if file.FileInfo().IsDir() { return extractDir(file, targetDir) } else if file.FileInfo().Mode().IsRegular() { - return extractFile(file, targetDir) + return extractFile(file, targetDir, overwrite) } else { return fmt.Errorf("unsupported file type (not dir or regular): %s (%d)", file.Name, file.FileInfo().Mode().Type()) } @@ -45,14 +44,29 @@ func handleFileFromZip(file *zip.File, outputDir string) error { func extractDir(file 
*zip.File, outputDir string) error { target := outputDir + file.Name - err := os.MkdirAll(target, file.Mode().Perm()) - if err != nil { + perms := file.Mode().Perm() &^ os.ModeSetuid &^ os.ModeSetgid &^ os.ModeSticky + info, err := os.Stat(target) + if err == nil && !info.IsDir() { + return fmt.Errorf("error creating directory: %s already exists and is not a directory", target) + } + if err != nil && !os.IsNotExist(err) { return fmt.Errorf("error creating directory: %w", err) } - return applyPermissions(target, file.Mode().Perm()) + if os.IsNotExist(err) { + err := os.MkdirAll(target, perms) + if err != nil { + return fmt.Errorf("error creating directory: %w", err) + } + } else { + err := os.Chmod(target, perms) + if err != nil { + return fmt.Errorf("error changing directory permissions: %w", err) + } + } + return nil } -func extractFile(file *zip.File, outputDir string) error { +func extractFile(file *zip.File, outputDir string, overwrite bool) error { target := outputDir + file.Name targetDir := filepath.Dir(target) err := os.MkdirAll(targetDir, 0755) @@ -68,7 +82,13 @@ func extractFile(file *zip.File, outputDir string) error { defer zipFile.Close() // Create the file on the filesystem - out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode()) + openFlags := os.O_WRONLY | os.O_CREATE + if overwrite { + openFlags |= os.O_TRUNC + } + // Do not apply setuid/gid/sticky bits. + perms := file.Mode().Perm() &^ os.ModeSetuid &^ os.ModeSetgid &^ os.ModeSticky + out, err := os.OpenFile(target, openFlags, perms) if err != nil { return fmt.Errorf("error creating file: %w", err) } @@ -79,11 +99,5 @@ func extractFile(file *zip.File, outputDir string) error { if err != nil { return fmt.Errorf("error copying file: %w", err) } - return applyPermissions(target, file.Mode().Perm()) -} - -func applyPermissions(filepath string, fileMode fs.FileMode) error { - // Do not apply setuid/gid/sticky bits. 
- perms := fileMode &^ os.ModeSetuid &^ os.ModeSetgid &^ os.ModeSticky - return os.Chmod(filepath, perms) + return nil } From 9ca65bd7a18893498432b0b54c1740c3e57bea7a Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 14 Feb 2024 22:26:22 -0800 Subject: [PATCH 06/12] Address import loop Move the ConsistentHashingStrategyKey to client not config. --- pkg/client/client.go | 7 +++++-- pkg/client/client_test.go | 3 +-- pkg/config/config.go | 4 ---- pkg/download/consistent_hashing.go | 3 +-- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/pkg/client/client.go b/pkg/client/client.go index e87bc67..ef2bbea 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -12,7 +12,6 @@ import ( "github.com/hashicorp/go-retryablehttp" - "github.com/replicate/pget/pkg/config" "github.com/replicate/pget/pkg/logging" "github.com/replicate/pget/pkg/version" ) @@ -24,6 +23,10 @@ const ( retryMaxWait = 1250 * time.Millisecond ) +type ConsistentHashingStrategy struct{} + +var ConsistentHashingStrategyKey ConsistentHashingStrategy + var ErrStrategyFallback = errors.New("fallback to next strategy") // HTTPClient is a wrapper around http.Client that allows for limiting the number of concurrent connections per host @@ -111,7 +114,7 @@ func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, err // While type assertions are not ideal, alternatives are limited to adding custom data in the request // or in the context. The context clearly isolates this data. 
- consistentHashing, ok := ctx.Value(config.ConsistentHashingStrategyKey).(bool) + consistentHashing, ok := ctx.Value(ConsistentHashingStrategyKey).(bool) if ok && consistentHashing { if fallbackError(err) { return false, ErrStrategyFallback diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index e4f85fd..3d21087 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/replicate/pget/pkg/client" - "github.com/replicate/pget/pkg/config" ) func TestGetSchemeHostKey(t *testing.T) { @@ -24,7 +23,7 @@ func TestGetSchemeHostKey(t *testing.T) { func TestRetryPolicy(t *testing.T) { bgCtx := context.Background() - chCtx := context.WithValue(bgCtx, config.ConsistentHashingStrategyKey, true) + chCtx := context.WithValue(bgCtx, client.ConsistentHashingStrategyKey, true) errContext, cancel := context.WithCancel(bgCtx) cancel() diff --git a/pkg/config/config.go b/pkg/config/config.go index 1cf59ea..d4b3557 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -27,10 +27,6 @@ var ( DefaultCacheURIPrefixes = []string{"https://weights.replicate.delivery"} ) -type ConsistentHashingStrategy struct{} - -var ConsistentHashingStrategyKey ConsistentHashingStrategy - type DeprecatedFlag struct { Flag string Msg string diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go index 07f4060..e70532b 100644 --- a/pkg/download/consistent_hashing.go +++ b/pkg/download/consistent_hashing.go @@ -13,7 +13,6 @@ import ( "golang.org/x/sync/errgroup" "github.com/replicate/pget/pkg/client" - "github.com/replicate/pget/pkg/config" "github.com/replicate/pget/pkg/consistent" "github.com/replicate/pget/pkg/logging" ) @@ -250,7 +249,7 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io } func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) { - chContext := 
context.WithValue(ctx, config.ConsistentHashingStrategyKey, true) + chContext := context.WithValue(ctx, client.ConsistentHashingStrategyKey, true) req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil) if err != nil { return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err) From e0b5c392d8080c6712aa5d3d173492b9cfcd24e8 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 14 Feb 2024 22:58:37 -0800 Subject: [PATCH 07/12] Rename `zip-extractor` to `unzip` 'unzip' is the binary used in Linux to extract from a zip file, let's stick with names that are more aligned with the CLI tools we otherwise use. --- cmd/multifile/multifile.go | 2 +- pkg/config/config.go | 2 +- pkg/consumer/{zip_extractor.go => unzip.go} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename pkg/consumer/{zip_extractor.go => unzip.go} (100%) diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go index c541ed8..1911fd0 100644 --- a/cmd/multifile/multifile.go +++ b/cmd/multifile/multifile.go @@ -70,7 +70,7 @@ func multifilePreRunE(cmd *cobra.Command, args []string) error { return fmt.Errorf("cannot use --output-consumer tar-extractor with multifile mode") } if viper.GetString(config.OptOutputConsumer) == config.ConsumerZipExtractor { - return fmt.Errorf("cannot use --output-consumer zip-extractor with multifile mode") + return fmt.Errorf("cannot use --output-consumer unzip with multifile mode") } return nil } diff --git a/pkg/config/config.go b/pkg/config/config.go index d4b3557..dbe485e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -20,7 +20,7 @@ const ( ConsumerFile = "file" ConsumerTarExtractor = "tar-extractor" ConsumerNull = "null" - ConsumerZipExtractor = "zip-extractor" + ConsumerZipExtractor = "unzip" ) var ( diff --git a/pkg/consumer/zip_extractor.go b/pkg/consumer/unzip.go similarity index 100% rename from pkg/consumer/zip_extractor.go rename to pkg/consumer/unzip.go From e822d33b470aa28887f605861ce8316609b9d1b2 
Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 14 Feb 2024 23:37:40 -0800 Subject: [PATCH 08/12] Add ContentType support for Consumer use Fetch now returns contentType and consumers take ContentType as an argument. This is in preparation of multifile being able to direct different contentTypes to different consumers in the case of tar/zip extraction. --- pkg/consumer/consumer.go | 2 +- pkg/consumer/null.go | 2 +- pkg/consumer/tar_extractor.go | 2 +- pkg/consumer/unzip.go | 4 +-- pkg/consumer/write_file.go | 2 +- pkg/download/buffer.go | 26 +++++++++------ pkg/download/buffer_test.go | 2 +- pkg/download/buffer_unit_test.go | 2 +- pkg/download/consistent_hashing.go | 43 ++++++++++++++----------- pkg/download/consistent_hashing_test.go | 20 ++++++------ pkg/download/strategy.go | 2 +- pkg/pget.go | 4 +-- 12 files changed, 62 insertions(+), 49 deletions(-) diff --git a/pkg/consumer/consumer.go b/pkg/consumer/consumer.go index c99757c..11089d7 100644 --- a/pkg/consumer/consumer.go +++ b/pkg/consumer/consumer.go @@ -3,7 +3,7 @@ package consumer import "io" type Consumer interface { - Consume(reader io.Reader, destPath string, fileSize int64) error + Consume(reader io.Reader, destPath string, fileSize int64, contentType string) error // EnableOverwrite sets the overwrite flag for the consumer, allowing it to overwrite files if necessary/supported EnableOverwrite() } diff --git a/pkg/consumer/null.go b/pkg/consumer/null.go index f2b7d97..0c31d6d 100644 --- a/pkg/consumer/null.go +++ b/pkg/consumer/null.go @@ -9,7 +9,7 @@ type NullWriter struct { var _ Consumer = &NullWriter{} -func (f *NullWriter) Consume(reader io.Reader, _ string, _ int64) error { +func (f *NullWriter) Consume(reader io.Reader, destPath string, fileSize int64, contentType string) error { // io.Discard is explicitly designed to always succeed, ignore errors. 
_, _ = io.Copy(io.Discard, reader) return nil diff --git a/pkg/consumer/tar_extractor.go b/pkg/consumer/tar_extractor.go index 8f00ae0..b4f530b 100644 --- a/pkg/consumer/tar_extractor.go +++ b/pkg/consumer/tar_extractor.go @@ -13,7 +13,7 @@ type TarExtractor struct { var _ Consumer = &TarExtractor{} -func (f *TarExtractor) Consume(reader io.Reader, destPath string, _ int64) error { +func (f *TarExtractor) Consume(reader io.Reader, destPath string, fileSize int64, contentType string) error { err := extract.TarFile(reader, destPath, f.overwrite) if err != nil { return fmt.Errorf("error extracting file: %w", err) diff --git a/pkg/consumer/unzip.go b/pkg/consumer/unzip.go index adbfc85..17e9ea4 100644 --- a/pkg/consumer/unzip.go +++ b/pkg/consumer/unzip.go @@ -14,12 +14,12 @@ type ZipExtractor struct { var _ Consumer = &ZipExtractor{} -func (f *ZipExtractor) Consume(reader io.Reader, destPath string, size int64) error { +func (f *ZipExtractor) Consume(reader io.Reader, destPath string, fileSize int64, contentType string) error { readerAt, err := download.NewMultiReader(reader) if err != nil { return fmt.Errorf("error converting to multi reader: %w", err) } - err = extract.ZipFile(readerAt, destPath, size, f.overwrite) + err = extract.ZipFile(readerAt, destPath, fileSize, f.overwrite) if err != nil { return fmt.Errorf("error extracting file: %w", err) } diff --git a/pkg/consumer/write_file.go b/pkg/consumer/write_file.go index afc35c0..a51ff68 100644 --- a/pkg/consumer/write_file.go +++ b/pkg/consumer/write_file.go @@ -12,7 +12,7 @@ type FileWriter struct { var _ Consumer = &FileWriter{} -func (f *FileWriter) Consume(reader io.Reader, destPath string, _ int64) error { +func (f *FileWriter) Consume(reader io.Reader, destPath string, fileSize int64, contentType string) error { openFlags := os.O_WRONLY | os.O_CREATE if f.overwrite { openFlags |= os.O_TRUNC diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go index 6020710..81a4a83 100644 --- 
a/pkg/download/buffer.go +++ b/pkg/download/buffer.go @@ -60,16 +60,19 @@ func (m *BufferMode) getFileSizeFromContentRange(contentRange string) (int64, er } type firstReqResult struct { - fileSize int64 - trueURL string - err error + fileSize int64 + trueURL string + err error + contentType string } -func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, error) { +func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, string, error) { logger := logging.GetLogger() br := newBufferedReader(m.minChunkSize()) + var contentType string + firstReqResultCh := make(chan firstReqResult) m.queue.submit(func() { m.sem.Go(func() error { @@ -94,7 +97,10 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e firstReqResultCh <- firstReqResult{err: err} return err } - firstReqResultCh <- firstReqResult{fileSize: fileSize, trueURL: trueURL} + firstReqResultCh <- firstReqResult{ + fileSize: fileSize, + trueURL: trueURL, + contentType: firstChunkResp.Header.Get("Content-Type")} return br.downloadBody(firstChunkResp) }) @@ -105,8 +111,10 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e panic("logic error in BufferMode: first request didn't return any output") } + contentType = firstReqResult.contentType + if firstReqResult.err != nil { - return nil, -1, firstReqResult.err + return nil, -1, contentType, firstReqResult.err } fileSize := firstReqResult.fileSize @@ -114,7 +122,7 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e if fileSize <= m.minChunkSize() { // we only need a single chunk: just download it and finish - return br, fileSize, nil + return br, fileSize, contentType, nil } remainingBytes := fileSize - m.minChunkSize() @@ -134,7 +142,7 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e chunkSize := remainingBytes / int64(numChunks) if chunkSize < 0 { - return nil, -1, fmt.Errorf("error: chunksize 
incorrect - result is negative, %d", chunkSize) + return nil, -1, contentType, fmt.Errorf("error: chunksize incorrect - result is negative, %d", chunkSize) } m.queue.submit(func() { @@ -169,7 +177,7 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e } }) - return newChanMultiReader(readersCh), fileSize, nil + return newChanMultiReader(readersCh), fileSize, contentType, nil } func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) { diff --git a/pkg/download/buffer_test.go b/pkg/download/buffer_test.go index 250d80a..a7ab79c 100644 --- a/pkg/download/buffer_test.go +++ b/pkg/download/buffer_test.go @@ -25,7 +25,7 @@ func benchmarkDownloadURL(opts download.Options, url string, b *testing.B) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - _, _, err := bufferMode.Fetch(ctx, url) + _, _, _, err := bufferMode.Fetch(ctx, url) assert.NoError(b, err) } } diff --git a/pkg/download/buffer_unit_test.go b/pkg/download/buffer_unit_test.go index e9b7d00..ec165d9 100644 --- a/pkg/download/buffer_unit_test.go +++ b/pkg/download/buffer_unit_test.go @@ -112,7 +112,7 @@ func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) { opts.MinChunkSize = tc.minChunkSize bufferMode := GetBufferMode(opts) path, _ := url.JoinPath(server.URL, testFilePath) - download, size, err := bufferMode.Fetch(context.Background(), path) + download, size, _, err := bufferMode.Fetch(context.Background(), path) require.NoError(t, err) data, err := io.ReadAll(download) assert.NoError(t, err) diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go index e70532b..69b5893 100644 --- a/pkg/download/consistent_hashing.go +++ b/pkg/download/consistent_hashing.go @@ -6,7 +6,7 @@ import ( "fmt" "io" "net/http" - "net/url" + neturl "net/url" "strconv" "strings" @@ -29,7 +29,7 @@ type ConsistentHashingMode struct { } type CacheKey struct { - URL *url.URL `hash:"string"` + URL 
*neturl.URL `hash:"string"` Slice int64 } @@ -78,12 +78,14 @@ func (m *ConsistentHashingMode) getFileSizeFromContentRange(contentRange string) return strconv.ParseInt(groups[1], 10, 64) } -func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io.Reader, int64, error) { +func (m *ConsistentHashingMode) Fetch(ctx context.Context, url string) (io.Reader, int64, string, error) { logger := logging.GetLogger() - parsed, err := url.Parse(urlString) + var contentType string + + parsed, err := neturl.Parse(url) if err != nil { - return nil, -1, err + return nil, -1, contentType, err } shouldContinue := false if prefixes, ok := m.CacheableURIPrefixes[parsed.Host]; ok { @@ -97,10 +99,10 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io // Use our fallback mode if we're not downloading from a consistent-hashing enabled domain if !shouldContinue { logger.Debug(). - Str("url", urlString). + Str("url", url). Str("reason", fmt.Sprintf("consistent hashing not enabled for %s", parsed.Host)). 
Msg("fallback strategy") - return m.FallbackStrategy.Fetch(ctx, urlString) + return m.FallbackStrategy.Fetch(ctx, url) } br := newBufferedReader(m.minChunkSize()) @@ -109,7 +111,7 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io m.sem.Go(func() error { defer close(firstReqResultCh) defer br.done() - firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, urlString) + firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, url) if err != nil { br.err = err firstReqResultCh <- firstReqResult{err: err} @@ -122,7 +124,10 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io firstReqResultCh <- firstReqResult{err: err} return err } - firstReqResultCh <- firstReqResult{fileSize: fileSize} + firstReqResultCh <- firstReqResult{ + fileSize: fileSize, + contentType: firstChunkResp.Header.Get("Content-Type"), + } return br.downloadBody(firstChunkResp) }) @@ -138,19 +143,19 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io if errors.Is(firstReqResult.err, client.ErrStrategyFallback) { // TODO(morgan): we should indicate the fallback strategy we're using in the logs logger.Info(). - Str("url", urlString). + Str("url", url). Str("type", "file"). Err(err). Msg("consistent hash fallback") - return m.FallbackStrategy.Fetch(ctx, urlString) + return m.FallbackStrategy.Fetch(ctx, url) } - return nil, -1, firstReqResult.err + return nil, -1, contentType, firstReqResult.err } fileSize := firstReqResult.fileSize if fileSize <= m.minChunkSize() { // we only need a single chunk: just download it and finish - return br, fileSize, nil + return br, fileSize, contentType, nil } totalSlices := fileSize / m.SliceSize @@ -175,7 +180,7 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io readersCh := make(chan io.Reader, m.maxConcurrency()+1) readersCh <- br - logger.Debug().Str("url", urlString). + logger.Debug().Str("url", url). Int64("size", fileSize). 
Int("concurrency", m.maxConcurrency()). Ints64("chunks_per_slice", chunksPerSlice). @@ -217,7 +222,7 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io m.sem.Go(func() error { defer br.done() logger.Debug().Int64("start", chunkStart).Int64("end", chunkEnd).Msg("starting request") - resp, err := m.DoRequest(ctx, chunkStart, chunkEnd, urlString) + resp, err := m.DoRequest(ctx, chunkStart, chunkEnd, url) if err != nil { // in the case that an error indicating an issue with the cache server, networking, etc is returned, // this will use the fallback strategy. This is a case where the whole file will perform the fall-back @@ -225,11 +230,11 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io if errors.Is(err, client.ErrStrategyFallback) { // TODO(morgan): we should indicate the fallback strategy we're using in the logs logger.Info(). - Str("url", urlString). + Str("url", url). Str("type", "chunk"). Err(err). Msg("consistent hash fallback") - resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString) + resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, url) } if err != nil { br.err = err @@ -245,7 +250,7 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io } }) - return newChanMultiReader(readersCh), fileSize, nil + return newChanMultiReader(readersCh), fileSize, contentType, nil } func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) { @@ -307,7 +312,7 @@ func (m *ConsistentHashingMode) rewriteRequestToCacheHost(req *http.Request, sta } if m.CacheUsePathProxy { // prepend the hostname to the start of the path. 
The consistent-hash nodes will use this to determine the proxy - newPath, err := url.JoinPath(strings.ToLower(req.URL.Host), req.URL.Path) + newPath, err := neturl.JoinPath(strings.ToLower(req.URL.Host), req.URL.Path) if err != nil { return -1, err } diff --git a/pkg/download/consistent_hashing_test.go b/pkg/download/consistent_hashing_test.go index 6d26250..a126163 100644 --- a/pkg/download/consistent_hashing_test.go +++ b/pkg/download/consistent_hashing_test.go @@ -266,7 +266,7 @@ func TestConsistentHashing(t *testing.T) { require.NoError(t, err) assert.Equal(t, tc.numCacheHosts, len(strategy.Options.CacheHosts)) - reader, _, err := strategy.Fetch(ctx, "http://test.replicate.com/hello.txt") + reader, _, _, err := strategy.Fetch(ctx, "http://test.replicate.com/hello.txt") require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) @@ -317,7 +317,7 @@ func TestConsistentHashingPathBased(t *testing.T) { require.NoError(t, err) assert.Equal(t, tc.numCacheHosts, len(strategy.Options.CacheHosts)) - reader, _, err := strategy.Fetch(ctx, fmt.Sprintf("http://%s/hello.txt", hostname)) + reader, _, _, err := strategy.Fetch(ctx, fmt.Sprintf("http://%s/hello.txt", hostname)) require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) @@ -348,7 +348,7 @@ func TestConsistentHashRetries(t *testing.T) { strategy, err := download.GetConsistentHashingMode(opts) require.NoError(t, err) - reader, _, err := strategy.Fetch(ctx, "http://fake.replicate.delivery/hello.txt") + reader, _, _, err := strategy.Fetch(ctx, "http://fake.replicate.delivery/hello.txt") require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) @@ -383,7 +383,7 @@ func TestConsistentHashRetriesMissingHostname(t *testing.T) { strategy, err := download.GetConsistentHashingMode(opts) require.NoError(t, err) - reader, _, err := strategy.Fetch(ctx, "http://fake.replicate.delivery/hello.txt") + reader, _, _, err := strategy.Fetch(ctx, 
"http://fake.replicate.delivery/hello.txt") require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) @@ -417,7 +417,7 @@ func TestConsistentHashRetriesTwoHosts(t *testing.T) { strategy, err := download.GetConsistentHashingMode(opts) require.NoError(t, err) - reader, _, err := strategy.Fetch(ctx, "http://testing.replicate.delivery/hello.txt") + reader, _, _, err := strategy.Fetch(ctx, "http://testing.replicate.delivery/hello.txt") require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) @@ -444,7 +444,7 @@ func TestConsistentHashingHasFallback(t *testing.T) { strategy, err := download.GetConsistentHashingMode(opts) require.NoError(t, err) - reader, _, err := strategy.Fetch(ctx, "http://fake.replicate.delivery/hello.txt") + reader, _, _, err := strategy.Fetch(ctx, "http://fake.replicate.delivery/hello.txt") require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) @@ -471,9 +471,9 @@ type testStrategy struct { mut sync.Mutex } -func (s *testStrategy) Fetch(ctx context.Context, url string) (io.Reader, int64, error) { +func (s *testStrategy) Fetch(ctx context.Context, url string) (io.Reader, int64, string, error) { s.fetchCalledCount++ - return io.NopCloser(strings.NewReader("00")), -1, nil + return io.NopCloser(strings.NewReader("00")), -1, "", nil } func (s *testStrategy) DoRequest(ctx context.Context, start, end int64, url string) (*http.Response, error) { @@ -541,7 +541,7 @@ func TestConsistentHashingFileFallback(t *testing.T) { strategy.FallbackStrategy = fallbackStrategy urlString := "http://fake.replicate.delivery/hello.txt" - _, _, err = strategy.Fetch(ctx, urlString) + _, _, _, err = strategy.Fetch(ctx, urlString) if tc.expectedError != nil { assert.ErrorIs(t, err, tc.expectedError) } @@ -603,7 +603,7 @@ func TestConsistentHashingChunkFallback(t *testing.T) { strategy.FallbackStrategy = fallbackStrategy urlString := "http://fake.replicate.delivery/hello.txt" - out, _, err := 
strategy.Fetch(ctx, urlString) + out, _, _, err := strategy.Fetch(ctx, urlString) assert.ErrorIs(t, err, tc.expectedError) if err == nil { // eagerly read the whole output reader to force all the diff --git a/pkg/download/strategy.go b/pkg/download/strategy.go index a430dd2..c228114 100644 --- a/pkg/download/strategy.go +++ b/pkg/download/strategy.go @@ -13,7 +13,7 @@ type Strategy interface { // Fetch retrieves the content from a given URL and returns it as an io.Reader along with the file size. // If an error occurs during the process, it returns nil for the reader, 0 for the fileSize, and the error itself. // This is the primary method that should be called to initiate a download of a file. - Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, err error) + Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, contentType string, err error) // DoRequest sends an HTTP GET request with a specified range of bytes to the given URL using the provided context. // It returns the HTTP response and any error encountered during the request. 
It is intended that Fetch calls DoRequest diff --git a/pkg/pget.go b/pkg/pget.go index e4a8955..723c810 100644 --- a/pkg/pget.go +++ b/pkg/pget.go @@ -40,14 +40,14 @@ func (g *Getter) DownloadFile(ctx context.Context, url string, dest string) (int } logger := logging.GetLogger() downloadStartTime := time.Now() - buffer, fileSize, err := g.Downloader.Fetch(ctx, url) + buffer, fileSize, contentType, err := g.Downloader.Fetch(ctx, url) if err != nil { return fileSize, 0, err } // downloadElapsed := time.Since(downloadStartTime) // writeStartTime := time.Now() - err = g.Consumer.Consume(buffer, dest, fileSize) + err = g.Consumer.Consume(buffer, dest, fileSize, contentType) if err != nil { return fileSize, 0, fmt.Errorf("error writing file: %w", err) } From fbab194bb7b4ee77a0f174536578d3230cf23048 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 14 Feb 2024 23:48:46 -0800 Subject: [PATCH 09/12] Enable Multifile to extract tar and zip Multifile can now extract tar and zip files based upon the content-type. The -u and -t flags for multifile command control unzip and untar capabilities respectively. 
--- cmd/multifile/consumer.go | 39 ++++++++++++++++++++++++++++++++++++++ cmd/multifile/multifile.go | 39 ++++++++++++++++++++++++++++---------- pkg/config/config.go | 4 ++++ 3 files changed, 72 insertions(+), 10 deletions(-) create mode 100644 cmd/multifile/consumer.go diff --git a/cmd/multifile/consumer.go b/cmd/multifile/consumer.go new file mode 100644 index 0000000..0710a8f --- /dev/null +++ b/cmd/multifile/consumer.go @@ -0,0 +1,39 @@ +package multifile + +import ( + "io" + + "github.com/replicate/pget/pkg/config" + "github.com/replicate/pget/pkg/consumer" +) + +type MultiConsumer struct { + consumerMap map[string]consumer.Consumer + defaultConsumer consumer.Consumer +} + +var _ consumer.Consumer = &MultiConsumer{} + +func (f MultiConsumer) Consume(reader io.Reader, destPath string, fileSize int64, contentType string) error { + if c, ok := f.consumerMap[contentType]; ok { + return c.Consume(reader, destPath, fileSize, contentType) + } + return f.defaultConsumer.Consume(reader, destPath, fileSize, contentType) +} + +func (f MultiConsumer) EnableOverwrite() { + f.defaultConsumer.EnableOverwrite() + for _, c := range f.consumerMap { + c.EnableOverwrite() + } +} + +func (f MultiConsumer) addConsumer(contentType, consumerName string) error { + // TODO: Consider making this check content-type instead of just file extension + c, err := config.GetConsumerByName(consumerName) + if err != nil { + return err + } + f.consumerMap[contentType] = c + return nil +} diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go index 1911fd0..7608921 100644 --- a/cmd/multifile/multifile.go +++ b/cmd/multifile/multifile.go @@ -37,6 +37,11 @@ const multifileExamples = ` cat multifile.txt | pget multifile - ` +const ( + OptUnzip = "unzip" + OptTarExtract = "tar" +) + // test seam type Getter interface { DownloadFile(ctx context.Context, url string, dest string) (int64, time.Duration, error) @@ -53,6 +58,9 @@ func GetCommand() *cobra.Command { Example: multifileExamples, } 
+ cmd.Flags().BoolP(OptUnzip, "u", false, "Extract .zip files in multifile mode") + cmd.Flags().BoolP(OptTarExtract, "t", false, "Extract .tar files in multifile mode") + err := viper.BindPFlags(cmd.PersistentFlags()) if err != nil { fmt.Println(err) @@ -63,15 +71,7 @@ func GetCommand() *cobra.Command { } func multifilePreRunE(cmd *cobra.Command, args []string) error { - if viper.GetBool(config.OptExtract) { - return fmt.Errorf("cannot use --extract with multifile mode") - } - if viper.GetString(config.OptOutputConsumer) == config.ConsumerTarExtractor { - return fmt.Errorf("cannot use --output-consumer tar-extractor with multifile mode") - } - if viper.GetString(config.OptOutputConsumer) == config.ConsumerZipExtractor { - return fmt.Errorf("cannot use --output-consumer unzip with multifile mode") - } + // Add any pre-run checks that may return an error here. return nil } @@ -129,11 +129,30 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error { MaxConcurrentFiles: maxConcurrentFiles(), } - consumer, err := config.GetConsumer() + configConsumer, err := config.GetConsumer() if err != nil { return fmt.Errorf("error getting consumer: %w", err) } + consumer := MultiConsumer{ + defaultConsumer: configConsumer, + } + + // Handle zip extraction if unzip flag is set + if viper.GetBool(OptUnzip) { + if err := consumer.addConsumer("application/zip", config.ConsumerZipExtractor); err != nil { + return err + } + } + + // Handle tar extraction if tar flag is set + if viper.GetBool(OptUnzip) { + if err := consumer.addConsumer("application/x-tar", config.ConsumerTarExtractor); err != nil { + return err + } + } + + // Enable overwrite if the force flag is set if viper.GetBool(config.OptForce) { consumer.EnableOverwrite() } diff --git a/pkg/config/config.go b/pkg/config/config.go index dbe485e..3ebccae 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -152,6 +152,10 @@ func ResolveOverridesToMap(resolveOverrides []string) (map[string]string, 
error) // calls viper.GetString(OptExtract) internally. func GetConsumer() (consumer.Consumer, error) { consumerName := viper.GetString(OptOutputConsumer) + return GetConsumerByName(consumerName) +} + +func GetConsumerByName(consumerName string) (consumer.Consumer, error) { switch consumerName { case ConsumerFile: return &consumer.FileWriter{}, nil From 27a505471b72c6f030bef7974cd9b73f63f22bd8 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Mon, 19 Feb 2024 12:10:31 -0800 Subject: [PATCH 10/12] Fix unzip * MultiReader was not blocking on the buffered reader ready signal * Unzip now joins the path name to the target instead of using '+' incorrectly --- pkg/download/multi_reader.go | 13 ++++++++++++- pkg/extract/zip.go | 11 +++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/pkg/download/multi_reader.go b/pkg/download/multi_reader.go index 7720386..c3047ea 100644 --- a/pkg/download/multi_reader.go +++ b/pkg/download/multi_reader.go @@ -2,6 +2,7 @@ package download import ( "errors" + "fmt" "io" ) @@ -47,12 +48,22 @@ func (m *multiReader) ReadAt(p []byte, off int64) (n int, err error) { return 0, ErrInvalidOffset } for i, r := range m.readers { - readerBytes += int64(r.len()) + readerBytes += r.len() // if offset is less than the bytes found in the reader slice to this point, // we can start reading from this reader. 
if off < readerBytes { + //innerOffset 1024 off 2301808284 readerBytes 2301809308 r.len() 47621039 + //innerOffset 66560 off 2301742748 readerBytes 2301809308 r.len() 47621039 + //panic: runtime error: slice bounds out of range [66560:15095] + // // Calculate the offset within the reader innerOffset := off - (readerBytes - r.len()) + if innerOffset > r.len() { + return 0, fmt.Errorf("innerOffset %d off %d readerBytes %d r.len() %d", innerOffset, off, readerBytes, r.len()) + } + //innerOffset := off - (readerBytes - r.len()) + //fmt.Println("innerOffset", innerOffset, "off", off, "readerBytes", readerBytes, "r.len()", r.len()) + <-r.ready n = copy(p, r.buf.Bytes()[innerOffset:]) if i == len(m.readers)-1 && n < len(p) { // We are at the last reader and the buffer is not full diff --git a/pkg/extract/zip.go b/pkg/extract/zip.go index a991c5a..b4015c4 100644 --- a/pkg/extract/zip.go +++ b/pkg/extract/zip.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "os" + "path" "path/filepath" ) @@ -30,12 +31,10 @@ func ZipFile(reader io.ReaderAt, destPath string, size int64, overwrite bool) er } func handleFileFromZip(file *zip.File, outputDir string, overwrite bool) error { - target := outputDir + file.Name - targetDir := filepath.Dir(target) if file.FileInfo().IsDir() { - return extractDir(file, targetDir) + return extractDir(file, outputDir) } else if file.FileInfo().Mode().IsRegular() { - return extractFile(file, targetDir, overwrite) + return extractFile(file, outputDir, overwrite) } else { return fmt.Errorf("unsupported file type (not dir or regular): %s (%d)", file.Name, file.FileInfo().Mode().Type()) } @@ -43,7 +42,7 @@ func handleFileFromZip(file *zip.File, outputDir string, overwrite bool) error { } func extractDir(file *zip.File, outputDir string) error { - target := outputDir + file.Name + target := path.Join(outputDir, file.Name) perms := file.Mode().Perm() &^ os.ModeSetuid &^ os.ModeSetgid &^ os.ModeSticky info, err := os.Stat(target) if err == nil && !info.IsDir() { @@ -67,7 
+66,7 @@ func extractDir(file *zip.File, outputDir string) error { } func extractFile(file *zip.File, outputDir string, overwrite bool) error { - target := outputDir + file.Name + target := path.Join(outputDir, file.Name) targetDir := filepath.Dir(target) err := os.MkdirAll(targetDir, 0755) if err != nil { From c21731c6ca0ed956444fd048247c7444da50fc48 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Mon, 19 Feb 2024 12:34:18 -0800 Subject: [PATCH 11/12] Update: Tar/Unzip * Implement `-u` short hand for `--unzip` * `--unzip` option for invoking the unzip consumer added * multifile mode utilizes `--unzip/-u` and `--extract/-x` for tar and unzip modes * Improved Debugging logs for tar and unzip * Update README --- README.md | 22 ++++++++++++++-------- cmd/multifile/multifile.go | 12 ++---------- cmd/root/root.go | 14 +++++++++++++- pkg/config/optnames.go | 1 + pkg/extract/tar.go | 23 +++++++++++++++++++++++ pkg/extract/zip.go | 16 +++++++++++++++- 6 files changed, 68 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index da97d1d..868fd6a 100644 --- a/README.md +++ b/README.md @@ -49,11 +49,6 @@ This builds a static binary that can work inside containers. - -c concurrency: The number of concurrent downloads. Default is 4 times the number of cores. - -x: Extract the tar file after download. If not set, the downloaded file will be saved as is. -#### Default-Mode Command-Line Options -- `-x`, `--extract` - - Extract archive after download - - Type: `bool` - - Default: `false` #### Example @@ -101,9 +96,9 @@ https://example.com/music.mp3 /local/path/to/music.mp3 ### Global Command-Line Options - `--max-chunks` - - Maximum number of chunks for downloading a given file - - Type: `Integer` - - Default: `4 * runtime.NumCPU()` + - Maximum number of chunks for downloading a given file + - Type: `Integer` + - Default: `4 * runtime.NumCPU()` - `--connect-timeout` - Timeout for establishing a connection, format is , e.g. 
10s - Type: `Duration` @@ -131,6 +126,17 @@ https://example.com/music.mp3 /local/path/to/music.mp3 - Verbose mode (equivalent to `--log-level debug`) - Type: `bool` - Default: `false` +- `-x`, `--extract` + - Extract archive after download + - Type: `bool` + - Default: `false` + - In multifile mode this option will only extract tar files where `content-type` header is `application/x-tar`. This option may be combined with `--unzip` only in multifile mode. +- `-u`, `--unzip` + - Unzip archive after download + - Type: `bool` + - Default: `false` + - In multifile mode this option will only extract tar files where `content-type` header is `application/zip`. This option may be combined with `--extract` only in multifile mode. + #### Deprecated - `--concurrency` (deprecated, use `--max-chunks` instead) diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go index 7608921..c51d5ff 100644 --- a/cmd/multifile/multifile.go +++ b/cmd/multifile/multifile.go @@ -37,11 +37,6 @@ const multifileExamples = ` cat multifile.txt | pget multifile - ` -const ( - OptUnzip = "unzip" - OptTarExtract = "tar" -) - // test seam type Getter interface { DownloadFile(ctx context.Context, url string, dest string) (int64, time.Duration, error) @@ -58,9 +53,6 @@ func GetCommand() *cobra.Command { Example: multifileExamples, } - cmd.Flags().BoolP(OptUnzip, "u", false, "Extract .zip files in multifile mode") - cmd.Flags().BoolP(OptTarExtract, "t", false, "Extract .tar files in multifile mode") - err := viper.BindPFlags(cmd.PersistentFlags()) if err != nil { fmt.Println(err) @@ -139,14 +131,14 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error { } // Handle zip extraction if unzip flag is set - if viper.GetBool(OptUnzip) { + if viper.GetBool(config.OptUnzip) { if err := consumer.addConsumer("application/zip", config.ConsumerZipExtractor); err != nil { return err } } // Handle tar extraction if tar flag is set - if viper.GetBool(OptUnzip) { + if 
viper.GetBool(config.OptExtract) {
config.AddFlagAlias(cmd, config.OptConcurrency, config.OptMaxChunks); err != nil { return err @@ -169,6 +178,9 @@ func rootCmdPreRun(cmd *cobra.Command, args []string) { if currentConsumer != config.ConsumerFile && currentConsumer != config.ConsumerTarExtractor { log.Warn().Msg("Tar Extract Enabled, overriding output consumer to `tar-extractor`") } + if currentConsumer != config.ConsumerFile && currentConsumer != config.ConsumerZipExtractor { + log.Warn().Msg("Unzip Enabled, overriding output consumer to `unzip`") + } viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor) } } diff --git a/pkg/config/optnames.go b/pkg/config/optnames.go index 4409e18..6f72a3f 100644 --- a/pkg/config/optnames.go +++ b/pkg/config/optnames.go @@ -24,5 +24,6 @@ const ( OptPIDFile = "pid-file" OptResolve = "resolve" OptRetries = "retries" + OptUnzip = "unzip" OptVerbose = "verbose" ) diff --git a/pkg/extract/tar.go b/pkg/extract/tar.go index aa74b25..07b670b 100644 --- a/pkg/extract/tar.go +++ b/pkg/extract/tar.go @@ -27,6 +27,8 @@ func TarFile(reader io.Reader, destDir string, overwrite bool) error { logger.Debug(). Str("extractor", "tar"). Str("status", "starting"). + Bool("overwrite", overwrite). + Str("destDir", destDir). Msg("Extract") for { header, err := tarReader.Next() @@ -45,6 +47,10 @@ func TarFile(reader io.Reader, destDir string, overwrite bool) error { switch header.Typeflag { case tar.TypeDir: + logger.Debug(). + Str("target", target). + Str("perms", fmt.Sprintf("%o", header.Mode)). + Msg("Tar: Directory") if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil { return err } @@ -53,6 +59,10 @@ func TarFile(reader io.Reader, destDir string, overwrite bool) error { if overwrite { openFlags |= os.O_TRUNC } + logger.Debug(). + Str("target", target). + Str("perms", fmt.Sprintf("%o", header.Mode)). 
+ Msg("Tar: File") targetFile, err := os.OpenFile(target, openFlags, os.FileMode(header.Mode)) if err != nil { return err @@ -66,6 +76,10 @@ func TarFile(reader io.Reader, destDir string, overwrite bool) error { } case tar.TypeSymlink, tar.TypeLink: // Defer creation of + logger.Debug().Str("link_type", string(header.Typeflag)). + Str("old_name", header.Linkname). + Str("new_name", target). + Msg("Tar: (Defer) Link") links = append(links, &link{linkType: header.Typeflag, oldName: header.Linkname, newName: target}) default: return fmt.Errorf("unsupported file type for %s, typeflag %s", header.Name, string(header.Typeflag)) @@ -86,6 +100,7 @@ func TarFile(reader io.Reader, destDir string, overwrite bool) error { } func createLinks(links []*link, destDir string, overwrite bool) error { + logger := logging.GetLogger() for _, link := range links { targetDir := filepath.Dir(link.newName) if err := os.MkdirAll(targetDir, 0755); err != nil { @@ -94,10 +109,18 @@ func createLinks(links []*link, destDir string, overwrite bool) error { switch link.linkType { case tar.TypeLink: oldPath := filepath.Join(destDir, link.oldName) + logger.Debug(). + Str("old_path", oldPath). + Str("new_path", link.newName). + Msg("Tar: creating hard link") if err := createHardLink(oldPath, link.newName, overwrite); err != nil { return fmt.Errorf("error creating hard link from %s to %s: %w", oldPath, link.newName, err) } case tar.TypeSymlink: + logger.Debug(). + Str("old_path", link.oldName). + Str("new_path", link.newName). 
+ Msg("Tar: creating symlink") if err := createSymlink(link.oldName, link.newName, overwrite); err != nil { return fmt.Errorf("error creating symlink from %s to %s: %w", link.oldName, link.newName, err) } diff --git a/pkg/extract/zip.go b/pkg/extract/zip.go index b4015c4..f6f5373 100644 --- a/pkg/extract/zip.go +++ b/pkg/extract/zip.go @@ -7,15 +7,24 @@ import ( "os" "path" "path/filepath" + + "github.com/replicate/pget/pkg/logging" ) // ZipFile extracts a zip file to the given destination path. func ZipFile(reader io.ReaderAt, destPath string, size int64, overwrite bool) error { + logger := logging.GetLogger() err := os.MkdirAll(destPath, 0755) if err != nil { return fmt.Errorf("error creating destination directory: %w", err) } + logger.Debug(). + Str("extractor", "zip"). + Str("status", "starting"). + Bool("overwrite", overwrite). + Str("destDir", destPath). + Msg("Extract") zipReader, err := zip.NewReader(reader, size) if err != nil { return fmt.Errorf("error creating zip reader: %w", err) @@ -42,8 +51,11 @@ func handleFileFromZip(file *zip.File, outputDir string, overwrite bool) error { } func extractDir(file *zip.File, outputDir string) error { + logger := logging.GetLogger() target := path.Join(outputDir, file.Name) + // Strip setuid/setgid/sticky bits perms := file.Mode().Perm() &^ os.ModeSetuid &^ os.ModeSetgid &^ os.ModeSticky + logger.Debug().Str("target", target).Str("perms", fmt.Sprintf("%o", perms)).Msg("Unzip: directory") info, err := os.Stat(target) if err == nil && !info.IsDir() { return fmt.Errorf("error creating directory: %s already exists and is not a directory", target) @@ -66,6 +78,7 @@ func extractDir(file *zip.File, outputDir string) error { } func extractFile(file *zip.File, outputDir string, overwrite bool) error { + logger := logging.GetLogger() target := path.Join(outputDir, file.Name) targetDir := filepath.Dir(target) err := os.MkdirAll(targetDir, 0755) @@ -85,8 +98,9 @@ func extractFile(file *zip.File, outputDir string, overwrite bool) 
error { if overwrite { openFlags |= os.O_TRUNC } - // Do not apply setuid/gid/sticky bits. + // Strip setuid/gid/sticky bits. perms := file.Mode().Perm() &^ os.ModeSetuid &^ os.ModeSetgid &^ os.ModeSticky + logger.Debug().Str("target", target).Str("perms", fmt.Sprintf("%o", perms)).Msg("Unzip: file") out, err := os.OpenFile(target, openFlags, perms) if err != nil { return fmt.Errorf("error creating file: %w", err) From 3e026208cc6411dec3fa3e0f5bbb008dd4e6c85f Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Tue, 20 Feb 2024 10:28:44 -0800 Subject: [PATCH 12/12] PreRun and PreRunE are not both run. PreRun and PreRunE are mutually exclusive. This moves the extraction and unzip consumer handling via short-hand options to PreRunE where we validate that -x and -u are not consurrently used. --- cmd/root/root.go | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/cmd/root/root.go b/cmd/root/root.go index ce5803a..5421ad3 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -18,6 +18,7 @@ import ( "github.com/replicate/pget/pkg/client" "github.com/replicate/pget/pkg/config" "github.com/replicate/pget/pkg/download" + "github.com/replicate/pget/pkg/logging" ) const rootLongDesc = ` @@ -51,7 +52,6 @@ func GetCommand() *cobra.Command { PersistentPreRunE: rootPersistentPreRunEFunc, PersistentPostRunE: rootPersistentPostRunEFunc, PreRunE: rootPreRunEFunc, - PreRun: rootCmdPreRun, RunE: runRootCMD, Args: cobra.ExactArgs(2), Example: ` pget https://example.com/file.tar ./target-dir`, @@ -120,9 +120,26 @@ func rootPersistentPostRunEFunc(cmd *cobra.Command, args []string) error { } func rootPreRunEFunc(cmd *cobra.Command, args []string) error { + logger := logging.GetLogger() + if viper.GetBool(config.OptExtract) && viper.GetBool(config.OptUnzip) { return fmt.Errorf("cannot use --unzip and --extract together") } + + currentConsumer := viper.GetString(config.OptOutputConsumer) + + if viper.GetBool(config.OptExtract) { 
+ if currentConsumer != config.ConsumerFile && currentConsumer != config.ConsumerTarExtractor { + logger.Warn().Msg("Tar Extract Enabled, overriding output consumer to `tar-extractor`") + } + viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor) + } + if viper.GetBool(config.OptUnzip) { + if currentConsumer != config.ConsumerFile && currentConsumer != config.ConsumerZipExtractor { + logger.Warn().Msg("Unzip Enabled, overriding output consumer to `unzip`") + } + viper.Set(config.OptOutputConsumer, config.ConsumerZipExtractor) + } return nil } @@ -172,19 +189,6 @@ func hideAndDeprecateFlags(cmd *cobra.Command) error { } -func rootCmdPreRun(cmd *cobra.Command, args []string) { - if viper.GetBool(config.OptExtract) { - currentConsumer := viper.GetString(config.OptOutputConsumer) - if currentConsumer != config.ConsumerFile && currentConsumer != config.ConsumerTarExtractor { - log.Warn().Msg("Tar Extract Enabled, overriding output consumer to `tar-extractor`") - } - if currentConsumer != config.ConsumerFile && currentConsumer != config.ConsumerZipExtractor { - log.Warn().Msg("Unzip Enabled, overriding output consumer to `unzip`") - } - viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor) - } -} - func runRootCMD(cmd *cobra.Command, args []string) error { // After we run through the PreRun functions we want to silence usage from being printed // on all errors @@ -254,11 +258,11 @@ func rootExecute(ctx context.Context, urlString, dest string) error { Consumer: consumer, } - if viper.GetBool(config.OptExtract) { - // TODO: decide what to do when --output is set *and* --extract is set - log.Debug().Msg("Tar Extract Enabled") - viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor) - } + //if viper.GetBool(config.OptExtract) { + // // TODO: decide what to do when --output is set *and* --extract is set + // log.Debug().Msg("Tar Extract Enabled") + // viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor) + //} // TODO DRY this if 
srvName := config.GetCacheSRV(); srvName != "" {