diff --git a/task_test.go b/task_test.go index 73e464fe64..415286dae7 100644 --- a/task_test.go +++ b/task_test.go @@ -1,6 +1,7 @@ package task_test import ( + "archive/zip" "bytes" "context" "fmt" @@ -1053,20 +1054,32 @@ func TestIncludesMultiLevel(t *testing.T) { func TestIncludesRemote(t *testing.T) { dir := "testdata/includes_remote" + os.RemoveAll(filepath.Join(dir, ".task")) + srv := httptest.NewServer(http.FileServer(http.Dir(dir))) defer srv.Close() + createZipFileOfDir(t, filepath.Join(dir, "tasks-root.zip"), dir) + createZipFileOfDir(t, filepath.Join(dir, "tasks-first.zip"), filepath.Join(dir, "first")) + tcs := []struct { rootTaskfile string firstRemote string secondRemote string extraTasks []string }{ + // // NOTE: When adding content for tests that use `getGitRemoteURL`, // you must commit the test data for the tests to be able to find it. // // These tests will not see data in the working tree because they clone // this repo. + // + { + // Ensure non-remote includes still work + firstRemote: "./first/Taskfile.yml", + secondRemote: "./second/Taskfile.yml", + }, { firstRemote: srv.URL + "/first/Taskfile.yml", secondRemote: srv.URL + "/first/second/Taskfile.yml", @@ -1108,14 +1121,36 @@ func TestIncludesRemote(t *testing.T) { }, }, { - firstRemote: srv.URL + "/tasks.zip", + firstRemote: srv.URL + "/tasks-first.zip", secondRemote: "./second/Taskfile.yml", + extraTasks: []string{ + "first:check-if-neighbor-file-exists", + "first:second:check-if-neighbor-file-exists", + }, }, { rootTaskfile: srv.URL + "/Taskfile.yml", firstRemote: "./first/Taskfile.yml", secondRemote: "./second/Taskfile.yml", }, + { + rootTaskfile: getGitRemoteURL(t, dir), + firstRemote: "./first/Taskfile.yml", + secondRemote: "./second/Taskfile.yml", + extraTasks: []string{ + "first:check-if-neighbor-file-exists", + "first:second:check-if-neighbor-file-exists", + }, + }, + { + rootTaskfile: srv.URL + "/tasks-root.zip", + firstRemote: "./first/Taskfile.yml", + secondRemote: "./second/Taskfile.yml", + extraTasks: []string{ + "first:check-if-neighbor-file-exists", + "first:second:check-if-neighbor-file-exists", + }, + }, } tasks := []string{ @@ -1144,25 +1179,24 @@ func TestIncludesRemote(t *testing.T) { // Without caching AssumeYes: true, Download: true, + Offline: false, + }, + }, + { + name: "offline, use-cache", + executor: &task.Executor{ + Dir: dir, + Entrypoint: tc.rootTaskfile, + Timeout: time.Minute, + Insecure: true, + Verbose: true, + + // With caching + AssumeYes: false, + Download: false, + Offline: true, }, }, - // Disabled until we add caching support for directories - // - // { - // name: "offline, use-cache", - // executor: &task.Executor{ - // Dir: dir, - // Entrypoint: tc.rootTaskfile, - // Timeout: time.Minute, - // Insecure: true, - // Verbose: true, - // - // // With caching - // AssumeYes: false, - // Download: false, - // Offline: true, - // }, - // }, } for j, e := range executors { @@ -1206,6 +1240,17 @@ func TestIncludesRemote(t *testing.T) { } } +func createZipFileOfDir(t *testing.T, zipFilePath string, dir string) { + f, err := os.OpenFile(zipFilePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) + require.NoError(t, err) + defer f.Close() + + w := zip.NewWriter(f) + err = w.AddFS(os.DirFS(dir)) + require.NoError(t, err) + w.Close() +} + func getGitRemoteURL(t *testing.T, path string) string { repoRoot, err := exec.Command("git", "rev-parse", "--show-toplevel").Output() require.NoError(t, err) diff --git a/taskfile/cache.go b/taskfile/cache.go index 2b57c17dd8..62e4148659 100644 --- a/taskfile/cache.go +++ b/taskfile/cache.go @@ -6,12 +6,21 @@ import ( "os" "path/filepath" "strings" + + "gopkg.in/yaml.v3" + + "github.com/go-task/task/v3/errors" ) type Cache struct { dir string } +type metadata struct { + Checksum string + TaskfileName string +} + func NewCache(dir string) (*Cache, error) { dir = filepath.Join(dir, "remote") if err := os.MkdirAll(dir, 0o755); err != nil { @@ -25,46 +34,207 @@ func NewCache(dir string) (*Cache, error) { func checksum(b []byte) string { h := sha256.New() h.Write(b) - return fmt.Sprintf("%x", h.Sum(nil)) + return fmt.Sprintf("%x", h.Sum(nil))[:16] } -func (c *Cache) write(node Node, b []byte) error { - return os.WriteFile(c.cacheFilePath(node), b, 0o644) +func checksumSource(s source) (string, error) { + h := sha256.New() + + entries, err := os.ReadDir(s.FileDirectory) + if err != nil { + return "", fmt.Errorf("could not list files at %s: %w", s.FileDirectory, err) + } + + for _, e := range entries { + if e.Type().IsRegular() { + path := filepath.Join(s.FileDirectory, e.Name()) + f, err := os.Open(path) + if err != nil { + return "", fmt.Errorf("error opening file %s for checksumming: %w", path, err) + } + if _, err := f.WriteTo(h); err != nil { + f.Close() + return "", fmt.Errorf("error reading file %s for checksumming: %w", path, err) + } + f.Close() + } + } + return fmt.Sprintf("%x", h.Sum(nil))[:16], nil } -func (c *Cache) read(node Node) ([]byte, error) { - return os.ReadFile(c.cacheFilePath(node)) +func (c *Cache) write(node Node, src source) (*source, error) { + // Clear metadata file so that if the rest of the operations fail part-way we don't + // end up in an inconsistent state where we've written the contents but have old metadata + if err := c.clearMetadata(node); err != nil { + return nil, err + } + + p, err := c.contentsPath(node) + if err != nil { + return nil, err + } + + switch fi, err := os.Stat(p); { + case errors.Is(err, os.ErrNotExist): + // Nothign to clear, do nothing + + case !fi.IsDir(): + return nil, fmt.Errorf("error writing to contents path %s: not a directory", p) + + case err != nil: + return nil, fmt.Errorf("error cheacking for previous contents path %s: %w", p, err) + + default: + err := os.RemoveAll(p) + if err != nil { + return nil, fmt.Errorf("error clearing contents directory: %s", err) + } + } + + if err := os.Rename(src.FileDirectory, p); err != nil { + return nil, err + } + + // TODO Clean up + src.FileDirectory = p + + cs, err := checksumSource(src) + if err != nil { + return nil, err + } + + m := metadata{ + Checksum: cs, + TaskfileName: src.Filename, + } + + if err := c.storeMetadata(node, m); err != nil { + return nil, fmt.Errorf("error storing metadata for node %s: %w", node.Location(), err) + } + + return &src, nil } -func (c *Cache) writeChecksum(node Node, checksum string) error { - return os.WriteFile(c.checksumFilePath(node), []byte(checksum), 0o644) +func (c *Cache) read(node Node) (*source, error) { + path, err := c.contentsPath(node) + if err != nil { + return nil, err + } + + m, err := c.readMetadata(node) + if err != nil { + return nil, err + } + + taskfileName := m.TaskfileName + + content, err := os.ReadFile(filepath.Join(path, m.TaskfileName)) + if err != nil { + return nil, err + } + + return &source{ + FileContent: content, + FileDirectory: path, + Filename: taskfileName, + }, nil } func (c *Cache) readChecksum(node Node) string { - b, _ := os.ReadFile(c.checksumFilePath(node)) - return string(b) + m, err := c.readMetadata(node) + if err != nil { + return "" + } + return m.Checksum +} + +func (c *Cache) clearMetadata(node Node) error { + path, err := c.metadataFilePath(node) + if err != nil { + return fmt.Errorf("error clearing metadata file at %s: %w", path, err) + } + + fi, err := os.Stat(path) + if errors.Is(err, os.ErrNotExist) { + return nil + } + + if !fi.Mode().IsRegular() { + return fmt.Errorf("path is not a real file when trying to delete metadata file: %s", path) + } + + // if err := os.Remove(path) + if err := os.Remove(path); err != nil { + return fmt.Errorf("error removing metadata file %s: %w", path, err) + } + + return nil +} + +func (c *Cache) storeMetadata(node Node, m metadata) error { + path, err := c.metadataFilePath(node) + if err != nil { + return err + } + + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) + if err != nil { + return fmt.Errorf("error creating metadata file %s: %w", path, err) + } + defer f.Close() + + if err := yaml.NewEncoder(f).Encode(m); err != nil { + return fmt.Errorf("error writing metadata into %s: %w", path, err) + } + + return nil +} + +func (c *Cache) readMetadata(node Node) (*metadata, error) { + path, err := c.metadataFilePath(node) + if err != nil { + return nil, err + } + + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("error opening metadata file %s: %w", path, err) + } + defer f.Close() + + var m *metadata + if err := yaml.NewDecoder(f).Decode(&m); err != nil { + return nil, fmt.Errorf("error reading metadata file %s: %w", path, err) + } + + return m, nil } func (c *Cache) key(node Node) string { return strings.TrimRight(checksum([]byte(node.Location())), "=") } -func (c *Cache) cacheFilePath(node Node) string { - return c.filePath(node, "yaml") +func (c *Cache) contentsPath(node Node) (string, error) { + return c.cacheFilePath(node, "contents") } -func (c *Cache) checksumFilePath(node Node) string { - return c.filePath(node, "checksum") +func (c *Cache) metadataFilePath(node Node) (string, error) { + return c.cacheFilePath(node, "metadata.yaml") } -func (c *Cache) filePath(node Node, suffix string) string { - lastDir, filename := node.FilenameAndLastDir() - prefix := filename +func (c *Cache) cacheFilePath(node Node, filename string) (string, error) { + lastDir, prefix := node.FilenameAndLastDir() // Means it's not "", nor "." nor "/", so it's a valid directory if len(lastDir) > 1 { - prefix = fmt.Sprintf("%s-%s", lastDir, filename) + prefix = fmt.Sprintf("%s-%s", lastDir, prefix) } - return filepath.Join(c.dir, fmt.Sprintf("%s.%s.%s", prefix, c.key(node), suffix)) + + dir := filepath.Join(c.dir, fmt.Sprintf("%s.%s", prefix, c.key(node))) + if err := os.MkdirAll(dir, 0o755); err != nil { + return "", fmt.Errorf("error creating cache dir %s: %w", dir, err) + } + + return filepath.Join(dir, filename), nil } func (c *Cache) Clear() error { diff --git a/taskfile/node_remote.go b/taskfile/node_remote.go index d832a61e5f..660104d2bc 100644 --- a/taskfile/node_remote.go +++ b/taskfile/node_remote.go @@ -40,7 +40,6 @@ func NewRemoteNode( timeout time.Duration, opts ...NodeOption, ) (*RemoteNode, bool, error) { - client := newGetterClient(dir) proto, u, err := extractProtocolFromURL(client, entrypoint) if err != nil { @@ -90,21 +89,8 @@ func (r *RemoteNode) ResolveEntrypoint(entrypoint string) (string, error) { return "", fmt.Errorf("could not resolve protocol for include %s: %w", entrypoint, err) } - switch { - case childProto != "file": - return entrypoint, nil - - case filepath.IsAbs(entrypoint): - return entrypoint, nil - - case r.proto == "http" || r.proto == "https": - // In HTTP, relative includes aren't available locally and are downloaded from the same base URL. - base := *r.url - base.Path = filepath.Join(filepath.Dir(base.Path), entrypoint) - - return base.String(), nil - - default: + if childProto == "file" && !filepath.IsAbs(entrypoint) { + // Relative file paths are resolved as relative to our own source location ctx, cancel := context.WithTimeout(context.Background(), r.timeout) defer cancel() @@ -113,8 +99,26 @@ func (r *RemoteNode) ResolveEntrypoint(entrypoint string) (string, error) { return "", err } - return filepathext.SmartJoin(src.FileDirectory, entrypoint), nil + relativePath := filepath.Join(src.FileDirectory, entrypoint) + if exists, err := fileExists(relativePath); err != nil { + return "", err + } else if exists { + return relativePath, nil + } + + if r.proto == "http" || r.proto == "https" { + // In HTTP, if relative includes are not available locally (eg. from a ZIP file), + // we try to download them relative to the base URL. + rel, err := url.Parse(entrypoint) + if err != nil { + return "", fmt.Errorf("error parsing entrypoint %s as url: %w", entrypoint, err) + } + + return r.url.ResolveReference(rel).String(), nil + } } + + return entrypoint, nil } func (r *RemoteNode) ResolveDir(dir string) (string, error) { @@ -145,16 +149,16 @@ func (r *RemoteNode) FilenameAndLastDir() (string, string) { func (r *RemoteNode) loadSource(ctx context.Context) (*source, error) { if r.cachedSource == nil { - r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetching remote taskfile from %s\n", r.Location(), r.client.Src) - dir, err := os.MkdirTemp("", "taskfile-remote-") if err != nil { return nil, err } + r.client.Ctx = ctx r.client.Src = r.Location() r.client.Dst = dir + r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetching remote taskfile from %s into %s\n", r.Location(), r.client.Src, r.client.Dst) if err := r.client.Get(); err != nil { return nil, err } @@ -243,7 +247,8 @@ var getterURLRegexp = regexp.MustCompile(`^([A-Za-z0-9]+)::(.+)$`) func extractProtocolFromURL(client *getter.Client, src string) (string, *url.URL, error) { if src == "" { - // If empty we assume current directory and let NodeFile logic deal with finding and appropriate file + // If empty we assume current directory and type `file` and let the FileNode logic + // deal with finding and appropriate file. u, err := url.Parse(".") return "file", u, err } @@ -309,6 +314,17 @@ func resolveTaskfileOverride(u *url.URL) (string, *url.URL) { return "", u } +func fileExists(path string) (bool, error) { + switch _, err := os.Stat(path); { + case errors.Is(err, os.ErrNotExist): + return false, nil + case err != nil: + return false, fmt.Errorf("error checking if file exists at %s: %w", path, err) + default: + return true, nil + } +} + // httpGetter wraps getter.HttpGetter to give us the ability // to download single files into a directory, like other getters would type httpGetter struct { diff --git a/taskfile/reader.go b/taskfile/reader.go index f18fbde777..0501a56ad8 100644 --- a/taskfile/reader.go +++ b/taskfile/reader.go @@ -180,10 +180,11 @@ func (r *Reader) include(node Node) error { } func (r *Reader) readNode(node Node) (*ast.Taskfile, error) { - var b []byte - var err error - var cache *Cache - source := &source{} + var ( + err error + cache *Cache + source *source + ) if node.Remote() { cache, err = NewCache(r.tempDir) @@ -194,11 +195,19 @@ func (r *Reader) readNode(node Node) (*ast.Taskfile, error) { // If the file is remote and we're in offline mode, check if we have a cached copy if node.Remote() && r.offline { - if b, err = cache.read(node); errors.Is(err, os.ErrNotExist) { + if source, err = cache.read(node); errors.Is(err, os.ErrNotExist) { return nil, &errors.TaskfileCacheNotFoundError{URI: node.Location()} } else if err != nil { return nil, err } + + // TODO: Find a cleaner way to override source when loading from the cache + // Without this later usages of ResolveEntrypoint will be relative to the old source location + // fr before it got moved into the cache. + if n, ok := node.(*RemoteNode); ok { + n.cachedSource = source + } + r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetched cached copy\n", node.Location()) } else { @@ -215,17 +224,24 @@ func (r *Reader) readNode(node Node) (*ast.Taskfile, error) { return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout} } // Search for any cached copies - if b, err = cache.read(node); errors.Is(err, os.ErrNotExist) { + if source, err = cache.read(node); errors.Is(err, os.ErrNotExist) { return nil, &errors.TaskfileNetworkTimeoutError{URI: node.Location(), Timeout: r.timeout, CheckedCache: true} } else if err != nil { return nil, err } r.logger.VerboseOutf(logger.Magenta, "task: [%s] Network timeout. Fetched cached copy\n", node.Location()) + + // TODO: Find a cleaner way to override source when loading from the cache + // Without this later usages of ResolveEntrypoint will be relative to the old source location + // fr before it got moved into the cache. + if n, ok := node.(*RemoteNode); ok { + n.cachedSource = source + } + } else if err != nil { return nil, err } else { downloaded = true - b = source.FileContent } // If the node was remote, we need to check the checksum @@ -233,8 +249,11 @@ func (r *Reader) readNode(node Node) (*ast.Taskfile, error) { r.logger.VerboseOutf(logger.Magenta, "task: [%s] Fetched remote copy\n", node.Location()) // Get the checksums - checksum := checksum(b) cachedChecksum := cache.readChecksum(node) + checksum, err := checksumSource(*source) + if err != nil { + return nil, err + } var prompt string if cachedChecksum == "" { @@ -252,25 +271,28 @@ func (r *Reader) readNode(node Node) (*ast.Taskfile, error) { // If the hash has changed (or is new) if checksum != cachedChecksum { - // Store the checksum - if err := cache.writeChecksum(node, checksum); err != nil { - return nil, err - } // Cache the file r.logger.VerboseOutf(logger.Magenta, "task: [%s] Caching downloaded file\n", node.Location()) - if err = cache.write(node, b); err != nil { + if source, err = cache.write(node, *source); err != nil { return nil, err } + + // TODO: Find a cleaner way to override source when loading from the cache + // Without this later usages of ResolveEntrypoint will be relative to the old source location + // fr before it got moved into the cache. + if n, ok := node.(*RemoteNode); ok { + n.cachedSource = source + } } } } var tf ast.Taskfile - if err := yaml.Unmarshal(b, &tf); err != nil { + if err := yaml.Unmarshal(source.FileContent, &tf); err != nil { // Decode the taskfile and add the file info the any errors taskfileInvalidErr := &errors.TaskfileDecodeError{} if errors.As(err, &taskfileInvalidErr) { - return nil, taskfileInvalidErr.WithFileInfo(node.Location(), b, 2) + return nil, taskfileInvalidErr.WithFileInfo(node.Location(), source.FileContent, 2) } return nil, &errors.TaskfileInvalidError{URI: filepathext.TryAbsToRel(node.Location()), Err: err} } diff --git a/taskfile/taskfile.go b/taskfile/taskfile.go index ca915eae75..e7642067f9 100644 --- a/taskfile/taskfile.go +++ b/taskfile/taskfile.go @@ -32,6 +32,7 @@ var ( "text/x-yaml", "application/yaml", "application/x-yaml", + "application/zip", } ) diff --git a/testdata/includes_remote/.gitignore b/testdata/includes_remote/.gitignore index 2211df63dd..8d8a34eeab 100644 --- a/testdata/includes_remote/.gitignore +++ b/testdata/includes_remote/.gitignore @@ -1 +1,2 @@ *.txt +*.zip \ No newline at end of file diff --git a/testdata/includes_remote/first/neighbor-of-first b/testdata/includes_remote/first/neighbor-of-first index ee3b91457f..4dd1ef7569 100644 --- a/testdata/includes_remote/first/neighbor-of-first +++ b/testdata/includes_remote/first/neighbor-of-first @@ -1 +1 @@ -This is a file. \ No newline at end of file +This is a file.