Skip to content

Commit

Permalink
Merge pull request restic#5235 from MichaelEischer/refactor-ls-sorting
Browse files Browse the repository at this point in the history
Refactor ls sorting
  • Loading branch information
MichaelEischer authored Feb 5, 2025
2 parents 4104a8e + 6cc06e0 commit 9cdf91b
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 97 deletions.
140 changes: 90 additions & 50 deletions cmd/restic/cmd_ls.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ type LsOptions struct {
Recursive bool
HumanReadable bool
Ncdu bool
Sort string
Sort SortMode
Reverse bool
}

Expand All @@ -81,7 +81,7 @@ func init() {
flags.BoolVar(&lsOptions.Recursive, "recursive", false, "include files in subfolders of the listed directories")
flags.BoolVar(&lsOptions.HumanReadable, "human-readable", false, "print sizes in human readable format")
flags.BoolVar(&lsOptions.Ncdu, "ncdu", false, "output NCDU export format (pipe into 'ncdu -f -')")
flags.StringVarP(&lsOptions.Sort, "sort", "s", "name", "sort output by (name|size|time=mtime|atime|ctime|extension)")
flags.VarP(&lsOptions.Sort, "sort", "s", "sort output by (name|size|time=mtime|atime|ctime|extension)")
flags.BoolVar(&lsOptions.Reverse, "reverse", false, "reverse sorted output")
}

Expand Down Expand Up @@ -301,19 +301,13 @@ func runLs(ctx context.Context, opts LsOptions, gopts GlobalOptions, args []stri
if opts.Ncdu && gopts.JSON {
return errors.Fatal("only either '--json' or '--ncdu' can be specified")
}
if opts.Sort != "name" && opts.Ncdu {
if opts.Sort != SortModeName && opts.Ncdu {
return errors.Fatal("--sort and --ncdu are mutually exclusive")
}
if opts.Reverse && opts.Ncdu {
return errors.Fatal("--reverse and --ncdu are mutually exclusive")
}

sortMode := SortModeName
err := sortMode.Set(opts.Sort)
if err != nil {
return err
}

// extract any specific directories to walk
var dirs []string
if len(args) > 1 {
Expand Down Expand Up @@ -376,8 +370,6 @@ func runLs(ctx context.Context, opts LsOptions, gopts GlobalOptions, args []stri
}

var printer lsPrinter
collector := []toSortOutput{}
outputSort := sortMode != SortModeName || opts.Reverse

if gopts.JSON {
printer = &jsonLsPrinter{
Expand All @@ -387,14 +379,20 @@ func runLs(ctx context.Context, opts LsOptions, gopts GlobalOptions, args []stri
printer = &ncduLsPrinter{
out: globalOptions.stdout,
}
outputSort = false
} else {
printer = &textLsPrinter{
dirs: dirs,
ListLong: opts.ListLong,
HumanReadable: opts.HumanReadable,
}
}
if opts.Sort != SortModeName || opts.Reverse {
printer = &sortedPrinter{
printer: printer,
sortMode: opts.Sort,
reverse: opts.Reverse,
}
}

sn, subfolder, err := (&restic.SnapshotFilter{
Hosts: opts.Hosts,
Expand Down Expand Up @@ -425,12 +423,8 @@ func runLs(ctx context.Context, opts LsOptions, gopts GlobalOptions, args []stri
printedDir := false
if withinDir(nodepath) {
// if we're within a target path, print the node
if outputSort {
collector = append(collector, toSortOutput{nodepath, node})
} else {
if err := printer.Node(nodepath, node, false); err != nil {
return err
}
if err := printer.Node(nodepath, node, false); err != nil {
return err
}
printedDir = true

Expand All @@ -445,7 +439,7 @@ func runLs(ctx context.Context, opts LsOptions, gopts GlobalOptions, args []stri
// there yet), signal the walker to descend into any subdirs
if approachingMatchingTree(nodepath) {
// print node leading up to the target paths
if !printedDir && !outputSort {
if !printedDir {
return printer.Node(nodepath, node, true)
}
return nil
Expand Down Expand Up @@ -480,80 +474,103 @@ func runLs(ctx context.Context, opts LsOptions, gopts GlobalOptions, args []stri
return err
}

if outputSort {
printSortedOutput(printer, opts, sortMode, collector)
}

return printer.Close()
}

func printSortedOutput(printer lsPrinter, opts LsOptions, sortMode SortMode, collector []toSortOutput) {
switch sortMode {
type sortedPrinter struct {
printer lsPrinter
collector []toSortOutput
sortMode SortMode
reverse bool
}

func (p *sortedPrinter) Snapshot(sn *restic.Snapshot) error {
return p.printer.Snapshot(sn)
}
func (p *sortedPrinter) Node(path string, node *restic.Node, isPrefixDirectory bool) error {
if !isPrefixDirectory {
p.collector = append(p.collector, toSortOutput{path, node})
}
return nil
}

func (p *sortedPrinter) LeaveDir(_ string) error {
return nil
}
func (p *sortedPrinter) Close() error {
var comparator func(a, b toSortOutput) int
switch p.sortMode {
case SortModeName:
case SortModeSize:
slices.SortStableFunc(collector, func(a, b toSortOutput) int {
comparator = func(a, b toSortOutput) int {
return cmp.Or(
cmp.Compare(a.node.Size, b.node.Size),
cmp.Compare(a.nodepath, b.nodepath),
)
})
}
case SortModeMtime:
slices.SortStableFunc(collector, func(a, b toSortOutput) int {
comparator = func(a, b toSortOutput) int {
return cmp.Or(
a.node.ModTime.Compare(b.node.ModTime),
cmp.Compare(a.nodepath, b.nodepath),
)
})
}
case SortModeAtime:
slices.SortStableFunc(collector, func(a, b toSortOutput) int {
comparator = func(a, b toSortOutput) int {
return cmp.Or(
a.node.AccessTime.Compare(b.node.AccessTime),
cmp.Compare(a.nodepath, b.nodepath),
)
})
}
case SortModeCtime:
slices.SortStableFunc(collector, func(a, b toSortOutput) int {
comparator = func(a, b toSortOutput) int {
return cmp.Or(
a.node.ChangeTime.Compare(b.node.ChangeTime),
cmp.Compare(a.nodepath, b.nodepath),
)
})
}
case SortModeExt:
// map name to extension
mapExt := make(map[string]string, len(collector))
for _, item := range collector {
mapExt := make(map[string]string, len(p.collector))
for _, item := range p.collector {
ext := filepath.Ext(item.nodepath)
mapExt[item.nodepath] = ext
}

slices.SortStableFunc(collector, func(a, b toSortOutput) int {
comparator = func(a, b toSortOutput) int {
return cmp.Or(
cmp.Compare(mapExt[a.nodepath], mapExt[b.nodepath]),
cmp.Compare(a.nodepath, b.nodepath),
)
})
}
}

if opts.Reverse {
slices.Reverse(collector)
if comparator != nil {
slices.SortStableFunc(p.collector, comparator)
}
if p.reverse {
slices.Reverse(p.collector)
}
for _, elem := range collector {
_ = printer.Node(elem.nodepath, elem.node, false)
for _, elem := range p.collector {
if err := p.printer.Node(elem.nodepath, elem.node, false); err != nil {
return err
}
}
return nil
}

// SortMode defines the allowed sorting modes
type SortMode string
type SortMode uint

// Allowed sort modes
const (
SortModeName SortMode = "name"
SortModeSize SortMode = "size"
SortModeAtime SortMode = "atime"
SortModeCtime SortMode = "ctime"
SortModeMtime SortMode = "mtime"
SortModeExt SortMode = "extension"
SortModeInvalid SortMode = "--invalid--"
SortModeName SortMode = iota
SortModeSize
SortModeAtime
SortModeCtime
SortModeMtime
SortModeExt
SortModeInvalid
)

// Set implements the method needed for pflag command flag parsing.
Expand All @@ -573,8 +590,31 @@ func (c *SortMode) Set(s string) error {
*c = SortModeExt
default:
*c = SortModeInvalid
return fmt.Errorf("invalid sort mode %q, must be one of (name|size|atime|ctime|mtime=time|extension)", s)
return fmt.Errorf("invalid sort mode %q, must be one of (name|size|time=mtime|atime|ctime|extension)", s)
}

return nil
}

func (c *SortMode) String() string {
switch *c {
case SortModeName:
return "name"
case SortModeSize:
return "size"
case SortModeAtime:
return "atime"
case SortModeCtime:
return "ctime"
case SortModeMtime:
return "mtime"
case SortModeExt:
return "extension"
default:
return "invalid"
}
}

func (c *SortMode) Type() string {
return "mode"
}
89 changes: 42 additions & 47 deletions cmd/restic/cmd_ls_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"context"
"encoding/json"
"fmt"
"strings"
"testing"

Expand All @@ -19,7 +20,7 @@ func testRunLsWithOpts(t testing.TB, gopts GlobalOptions, opts LsOptions, args [
}

func testRunLs(t testing.TB, gopts GlobalOptions, snapshotID string) []string {
out := testRunLsWithOpts(t, gopts, LsOptions{Sort: "name"}, []string{snapshotID})
out := testRunLsWithOpts(t, gopts, LsOptions{}, []string{snapshotID})
return strings.Split(string(out), "\n")
}

Expand All @@ -45,35 +46,13 @@ func TestRunLsNcdu(t *testing.T) {
{"latest", "/0"},
{"latest", "/0", "/0/9"},
} {
ncdu := testRunLsWithOpts(t, env.gopts, LsOptions{Ncdu: true, Sort: "name"}, paths)
ncdu := testRunLsWithOpts(t, env.gopts, LsOptions{Ncdu: true}, paths)
assertIsValidJSON(t, ncdu)
}
}

func TestRunLsSort(t *testing.T) {
compareName := []string{
"/for_cmd_ls",
"/for_cmd_ls/file1.txt",
"/for_cmd_ls/file2.txt",
"/for_cmd_ls/python.py",
"", // last empty line
}

compareSize := []string{
"/for_cmd_ls",
"/for_cmd_ls/file2.txt",
"/for_cmd_ls/file1.txt",
"/for_cmd_ls/python.py",
"",
}

compareExt := []string{
"/for_cmd_ls",
"/for_cmd_ls/python.py",
"/for_cmd_ls/file1.txt",
"/for_cmd_ls/file2.txt",
"",
}
rtest.Equals(t, SortMode(0), SortModeName, "unexpected default sort mode")

env, cleanup := withTestEnvironment(t)
defer cleanup()
Expand All @@ -82,27 +61,43 @@ func TestRunLsSort(t *testing.T) {
opts := BackupOptions{}
testRunBackup(t, env.testdata+"/0", []string{"for_cmd_ls"}, opts, env.gopts)

// sort by size
out := testRunLsWithOpts(t, env.gopts, LsOptions{Sort: "size"}, []string{"latest"})
fileList := strings.Split(string(out), "\n")
rtest.Assert(t, len(fileList) == 5, "invalid ls --sort size, expected 5 array elements, got %v", len(fileList))
for i, item := range compareSize {
rtest.Assert(t, item == fileList[i], "invalid ls --sort size, expected element '%s', got '%s'", item, fileList[i])
}

// sort by file extension
out = testRunLsWithOpts(t, env.gopts, LsOptions{Sort: "extension"}, []string{"latest"})
fileList = strings.Split(string(out), "\n")
rtest.Assert(t, len(fileList) == 5, "invalid ls --sort extension, expected 5 array elements, got %v", len(fileList))
for i, item := range compareExt {
rtest.Assert(t, item == fileList[i], "invalid ls --sort extension, expected element '%s', got '%s'", item, fileList[i])
}

// explicit name sort
out = testRunLsWithOpts(t, env.gopts, LsOptions{Sort: "name"}, []string{"latest"})
fileList = strings.Split(string(out), "\n")
rtest.Assert(t, len(fileList) == 5, "invalid ls --sort name, expected 5 array elements, got %v", len(fileList))
for i, item := range compareName {
rtest.Assert(t, item == fileList[i], "invalid ls --sort name, expected element '%s', got '%s'", item, fileList[i])
for _, test := range []struct {
mode SortMode
expected []string
}{
{
SortModeSize,
[]string{
"/for_cmd_ls",
"/for_cmd_ls/file2.txt",
"/for_cmd_ls/file1.txt",
"/for_cmd_ls/python.py",
"",
},
},
{
SortModeExt,
[]string{
"/for_cmd_ls",
"/for_cmd_ls/python.py",
"/for_cmd_ls/file1.txt",
"/for_cmd_ls/file2.txt",
"",
},
},
{
SortModeName,
[]string{
"/for_cmd_ls",
"/for_cmd_ls/file1.txt",
"/for_cmd_ls/file2.txt",
"/for_cmd_ls/python.py",
"", // last empty line
},
},
} {
out := testRunLsWithOpts(t, env.gopts, LsOptions{Sort: test.mode}, []string{"latest"})
fileList := strings.Split(string(out), "\n")
rtest.Equals(t, test.expected, fileList, fmt.Sprintf("mismatch for mode %v", test.mode))
}
}

0 comments on commit 9cdf91b

Please sign in to comment.