Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement zip extract #158

Closed
wants to merge 12 commits into from
22 changes: 14 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,6 @@ This builds a static binary that can work inside containers.
- -c concurrency: The number of concurrent downloads. Default is 4 times the number of cores.
- -x: Extract the tar file after download. If not set, the downloaded file will be saved as is.

#### Default-Mode Command-Line Options
- `-x`, `--extract`
- Extract archive after download
- Type: `bool`
- Default: `false`

#### Example

Expand Down Expand Up @@ -101,9 +96,9 @@ https://example.com/music.mp3 /local/path/to/music.mp3

### Global Command-Line Options
- `--max-chunks`
- Maximum number of chunks for downloading a given file
- Type: `Integer`
- Default: `4 * runtime.NumCPU()`
- Maximum number of chunks for downloading a given file
- Type: `Integer`
- Default: `4 * runtime.NumCPU()`
- `--connect-timeout`
- Timeout for establishing a connection, format is <number><unit>, e.g. 10s
- Type: `Duration`
Expand Down Expand Up @@ -131,6 +126,17 @@ https://example.com/music.mp3 /local/path/to/music.mp3
- Verbose mode (equivalent to `--log-level debug`)
- Type: `bool`
- Default: `false`
- `-x`, `--extract`
- Extract archive after download
- Type: `bool`
- Default: `false`
- In multifile mode this option will only extract tar files where `content-type` header is `application/x-tar`. This option may be combined with `--unzip` only in multifile mode.
- `-u`, `--unzip`
- Unzip archive after download
- Type: `bool`
- Default: `false`
- In multifile mode this option will only extract zip files where `content-type` header is `application/zip`. This option may be combined with `--extract` only in multifile mode.


#### Deprecated
- `--concurrency` (deprecated, use `--max-chunks` instead)
Expand Down
39 changes: 39 additions & 0 deletions cmd/multifile/consumer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package multifile

import (
"io"

"github.com/replicate/pget/pkg/config"
"github.com/replicate/pget/pkg/consumer"
)

// MultiConsumer is a consumer.Consumer that hands each download to another
// consumer selected by the download's content type.
type MultiConsumer struct {
// consumerMap holds the consumers registered through addConsumer, keyed by
// the content type they handle.
consumerMap map[string]consumer.Consumer
// defaultConsumer handles every content type that has no entry in
// consumerMap.
defaultConsumer consumer.Consumer
}

var _ consumer.Consumer = &MultiConsumer{}

// Consume forwards the download to the consumer registered for contentType,
// or to the default consumer when no consumer is registered for it. The
// arguments are passed through unchanged and the chosen consumer's error is
// returned as-is.
func (f MultiConsumer) Consume(reader io.Reader, destPath string, fileSize int64, contentType string) error {
	target := f.defaultConsumer
	// The lookup is an exact string match on the content type.
	if registered, found := f.consumerMap[contentType]; found {
		target = registered
	}
	return target.Consume(reader, destPath, fileSize, contentType)
}

// EnableOverwrite turns on overwriting for the default consumer and then for
// every consumer registered in consumerMap.
func (f MultiConsumer) EnableOverwrite() {
	f.defaultConsumer.EnableOverwrite()
	for contentType := range f.consumerMap {
		f.consumerMap[contentType].EnableOverwrite()
	}
}

// addConsumer registers the consumer identified by consumerName as the handler
// for downloads whose content type equals contentType. It returns the error
// from config.GetConsumerByName when consumerName cannot be resolved; in that
// case nothing is registered.
//
// The receiver is a pointer so the map can be allocated on first use: a
// MultiConsumer built from a struct literal that omits consumerMap has a nil
// map, and assigning into a nil map panics. Callers holding an addressable
// MultiConsumer value can keep calling this method unchanged.
//
// TODO: Consume matches the content type by exact string comparison, so a
// value carrying media-type parameters would not match an entry added here.
func (f *MultiConsumer) addConsumer(contentType, consumerName string) error {
	c, err := config.GetConsumerByName(consumerName)
	if err != nil {
		return err
	}
	if f.consumerMap == nil {
		f.consumerMap = make(map[string]consumer.Consumer)
	}
	f.consumerMap[contentType] = c
	return nil
}
32 changes: 25 additions & 7 deletions cmd/multifile/multifile.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,7 @@ func GetCommand() *cobra.Command {
}

func multifilePreRunE(cmd *cobra.Command, args []string) error {
if viper.GetBool(config.OptExtract) {
return fmt.Errorf("cannot use --extract with multifile mode")
}
if viper.GetString(config.OptOutputConsumer) == config.ConsumerTarExtractor {
return fmt.Errorf("cannot use --output-consumer tar-extractor with multifile mode")
}
// Add any pre-run checks that may return an error here.
return nil
}

Expand Down Expand Up @@ -126,11 +121,34 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error {
MaxConcurrentFiles: maxConcurrentFiles(),
}

consumer, err := config.GetConsumer()
configConsumer, err := config.GetConsumer()
if err != nil {
return fmt.Errorf("error getting consumer: %w", err)
}

consumer := MultiConsumer{
defaultConsumer: configConsumer,
}

// Handle zip extraction if unzip flag is set
if viper.GetBool(config.OptUnzip) {
if err := consumer.addConsumer("application/zip", config.ConsumerZipExtractor); err != nil {
return err
}
}

// Handle tar extraction if the extract flag is set. This must test
// OptExtract, not OptUnzip: otherwise --extract never registers the tar
// extractor and --unzip registers both extractors.
if viper.GetBool(config.OptExtract) {
	if err := consumer.addConsumer("application/x-tar", config.ConsumerTarExtractor); err != nil {
		return err
	}
}

// Enable overwrite if the force flag is set
if viper.GetBool(config.OptForce) {
consumer.EnableOverwrite()
}

getter := &pget.Getter{
Downloader: download.GetBufferMode(downloadOpts),
Consumer: consumer,
Expand Down
50 changes: 37 additions & 13 deletions cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/replicate/pget/pkg/client"
"github.com/replicate/pget/pkg/config"
"github.com/replicate/pget/pkg/download"
"github.com/replicate/pget/pkg/logging"
)

const rootLongDesc = `
Expand Down Expand Up @@ -50,12 +51,11 @@ func GetCommand() *cobra.Command {
Long: rootLongDesc,
PersistentPreRunE: rootPersistentPreRunEFunc,
PersistentPostRunE: rootPersistentPostRunEFunc,
PreRun: rootCmdPreRun,
PreRunE: rootPreRunEFunc,
RunE: runRootCMD,
Args: cobra.ExactArgs(2),
Example: ` pget https://example.com/file.tar ./target-dir`,
}
cmd.Flags().BoolP(config.OptExtract, "x", false, "OptExtract archive after download")
cmd.SetUsageTemplate(cli.UsageTemplate)
config.ViperInit()
if err := persistentFlags(cmd); err != nil {
Expand Down Expand Up @@ -119,6 +119,30 @@ func rootPersistentPostRunEFunc(cmd *cobra.Command, args []string) error {
return nil
}

// rootPreRunEFunc validates the archive flags and translates them into the
// output-consumer setting. It returns an error when --extract and --unzip are
// both set. Otherwise, whichever flag is set forces the output consumer to the
// matching extractor, logging a warning first if the previously configured
// consumer was neither the file consumer nor that extractor.
func rootPreRunEFunc(cmd *cobra.Command, args []string) error {
	logger := logging.GetLogger()

	wantTar := viper.GetBool(config.OptExtract)
	wantZip := viper.GetBool(config.OptUnzip)
	if wantTar && wantZip {
		return fmt.Errorf("cannot use --unzip and --extract together")
	}

	// Read once, before any override, so the warning reflects the value that
	// was configured prior to this function running.
	previous := viper.GetString(config.OptOutputConsumer)

	switch {
	case wantTar:
		if !(previous == config.ConsumerFile || previous == config.ConsumerTarExtractor) {
			logger.Warn().Msg("Tar Extract Enabled, overriding output consumer to `tar-extractor`")
		}
		viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor)
	case wantZip:
		if !(previous == config.ConsumerFile || previous == config.ConsumerZipExtractor) {
			logger.Warn().Msg("Unzip Enabled, overriding output consumer to `unzip`")
		}
		viper.Set(config.OptOutputConsumer, config.ConsumerZipExtractor)
	}
	return nil
}

func persistentFlags(cmd *cobra.Command) error {
// Persistent Flags (applies to all commands/subcommands)
cmd.PersistentFlags().IntVarP(&concurrency, config.OptConcurrency, "c", runtime.GOMAXPROCS(0)*4, "Maximum number of concurrent downloads/maximum number of chunks for a given file")
Expand All @@ -134,6 +158,8 @@ func persistentFlags(cmd *cobra.Command) error {
cmd.PersistentFlags().Int(config.OptMaxConnPerHost, 40, "Maximum number of (global) concurrent connections per host")
cmd.PersistentFlags().StringP(config.OptOutputConsumer, "o", "file", "Output Consumer (file, tar, null)")
cmd.PersistentFlags().String(config.OptPIDFile, defaultPidFilePath(), "PID file path")
cmd.PersistentFlags().BoolP(config.OptExtract, "x", false, "Extract tar archive after download")
cmd.PersistentFlags().BoolP(config.OptUnzip, "u", false, "Unzip archive after download")

if err := config.AddFlagAlias(cmd, config.OptConcurrency, config.OptMaxChunks); err != nil {
return err
Expand Down Expand Up @@ -163,12 +189,6 @@ func hideAndDeprecateFlags(cmd *cobra.Command) error {

}

func rootCmdPreRun(cmd *cobra.Command, args []string) {
if viper.GetBool(config.OptExtract) {
viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor)
}
}

func runRootCMD(cmd *cobra.Command, args []string) error {
// After we run through the PreRun functions we want to silence usage from being printed
// on all errors
Expand Down Expand Up @@ -229,16 +249,20 @@ func rootExecute(ctx context.Context, urlString, dest string) error {
return err
}

if viper.GetBool(config.OptForce) {
consumer.EnableOverwrite()
}

getter := pget.Getter{
Downloader: download.GetBufferMode(downloadOpts),
Consumer: consumer,
}

if viper.GetBool(config.OptExtract) {
// TODO: decide what to do when --output is set *and* --extract is set
log.Debug().Msg("Tar Extract Enabled")
viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor)
}
//if viper.GetBool(config.OptExtract) {
// // TODO: decide what to do when --output is set *and* --extract is set
// log.Debug().Msg("Tar Extract Enabled")
// viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor)
//}

// TODO DRY this
if srvName := config.GetCacheSRV(); srvName != "" {
Expand Down
7 changes: 5 additions & 2 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

"github.com/hashicorp/go-retryablehttp"

"github.com/replicate/pget/pkg/config"
"github.com/replicate/pget/pkg/logging"
"github.com/replicate/pget/pkg/version"
)
Expand All @@ -24,6 +23,10 @@ const (
retryMaxWait = 1250 * time.Millisecond
)

// ConsistentHashingStrategy is the type of the context key that RetryPolicy
// uses to look up whether the consistent-hashing strategy is in effect.
type ConsistentHashingStrategy struct{}

// ConsistentHashingStrategyKey is the context key read by RetryPolicy. The
// value stored under it is type-asserted to bool; when it is true and the
// error qualifies for fallback, RetryPolicy returns ErrStrategyFallback.
var ConsistentHashingStrategyKey ConsistentHashingStrategy

var ErrStrategyFallback = errors.New("fallback to next strategy")

// HTTPClient is a wrapper around http.Client that allows for limiting the number of concurrent connections per host
Expand Down Expand Up @@ -111,7 +114,7 @@ func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, err

// While type assertions are not ideal, alternatives are limited to adding custom data in the request
// or in the context. The context clearly isolates this data.
consistentHashing, ok := ctx.Value(config.ConsistentHashingStrategyKey).(bool)
consistentHashing, ok := ctx.Value(ConsistentHashingStrategyKey).(bool)
if ok && consistentHashing {
if fallbackError(err) {
return false, ErrStrategyFallback
Expand Down
3 changes: 1 addition & 2 deletions pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/stretchr/testify/assert"

"github.com/replicate/pget/pkg/client"
"github.com/replicate/pget/pkg/config"
)

func TestGetSchemeHostKey(t *testing.T) {
Expand All @@ -24,7 +23,7 @@ func TestGetSchemeHostKey(t *testing.T) {

func TestRetryPolicy(t *testing.T) {
bgCtx := context.Background()
chCtx := context.WithValue(bgCtx, config.ConsistentHashingStrategyKey, true)
chCtx := context.WithValue(bgCtx, client.ConsistentHashingStrategyKey, true)
errContext, cancel := context.WithCancel(bgCtx)
cancel()

Expand Down
11 changes: 7 additions & 4 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@ const (
ConsumerFile = "file"
ConsumerTarExtractor = "tar-extractor"
ConsumerNull = "null"
ConsumerZipExtractor = "unzip"
)

var (
DefaultCacheURIPrefixes = []string{"https://weights.replicate.delivery"}
)

type ConsistentHashingStrategy struct{}

var ConsistentHashingStrategyKey ConsistentHashingStrategy

type DeprecatedFlag struct {
Flag string
Msg string
Expand Down Expand Up @@ -155,11 +152,17 @@ func ResolveOverridesToMap(resolveOverrides []string) (map[string]string, error)
// calls viper.GetString(OptExtract) internally.
func GetConsumer() (consumer.Consumer, error) {
	// The name stored under OptOutputConsumer selects the implementation.
	return GetConsumerByName(viper.GetString(OptOutputConsumer))
}

func GetConsumerByName(consumerName string) (consumer.Consumer, error) {
switch consumerName {
case ConsumerFile:
return &consumer.FileWriter{}, nil
case ConsumerTarExtractor:
return &consumer.TarExtractor{}, nil
case ConsumerZipExtractor:
return &consumer.ZipExtractor{}, nil
case ConsumerNull:
return &consumer.NullWriter{}, nil
default:
Expand Down
1 change: 1 addition & 0 deletions pkg/config/optnames.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@ const (
OptPIDFile = "pid-file"
OptResolve = "resolve"
OptRetries = "retries"
OptUnzip = "unzip"
OptVerbose = "verbose"
)
4 changes: 3 additions & 1 deletion pkg/consumer/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@ package consumer
import "io"

type Consumer interface {
Consume(reader io.Reader, destPath string) error
Consume(reader io.Reader, destPath string, fileSize int64, contentType string) error
// EnableOverwrite sets the overwrite flag for the consumer, allowing it to overwrite files if necessary/supported
EnableOverwrite()
}
9 changes: 7 additions & 2 deletions pkg/consumer/null.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@ import (
"io"
)

type NullWriter struct{}
type NullWriter struct {
}

var _ Consumer = &NullWriter{}

func (f *NullWriter) Consume(reader io.Reader, destPath string) error {
func (f *NullWriter) Consume(reader io.Reader, destPath string, fileSize int64, contentType string) error {
// io.Discard is explicitly designed to always succeed, ignore errors.
_, _ = io.Copy(io.Discard, reader)
return nil
}

// EnableOverwrite does nothing; it exists so NullWriter satisfies Consumer.
// NullWriter's Consume copies its input to io.Discard, so there is no output
// to overwrite.
func (f *NullWriter) EnableOverwrite() {
// no-op
}
12 changes: 9 additions & 3 deletions pkg/consumer/tar_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@ import (
"github.com/replicate/pget/pkg/extract"
)

type TarExtractor struct{}
type TarExtractor struct {
overwrite bool
}

var _ Consumer = &TarExtractor{}

func (f *TarExtractor) Consume(reader io.Reader, destPath string) error {
err := extract.TarFile(reader, destPath)
func (f *TarExtractor) Consume(reader io.Reader, destPath string, fileSize int64, contentType string) error {
err := extract.TarFile(reader, destPath, f.overwrite)
if err != nil {
return fmt.Errorf("error extracting file: %w", err)
}
return nil
}

// EnableOverwrite sets the overwrite flag, which Consume passes through to
// extract.TarFile on subsequent calls.
func (f *TarExtractor) EnableOverwrite() {
f.overwrite = true
}
Loading