diff --git a/internal/utils/copy/copy.go b/internal/utils/copy/copy.go index 4e2aacd24..314e93764 100644 --- a/internal/utils/copy/copy.go +++ b/internal/utils/copy/copy.go @@ -15,6 +15,7 @@ package copy import ( + "fmt" "io" "os" "path/filepath" @@ -23,43 +24,33 @@ import ( "github.com/ZupIT/horusec-devkit/pkg/utils/logger" ) +// Copy copy src directory to dst ignoring files make skip function return true. +// +// Note that symlink files will be ignored by default. func Copy(src, dst string, skip func(src string) bool) error { if err := os.MkdirAll(dst, os.ModePerm); err != nil { return err } return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { - if err != nil { + if err != nil || skip(path) || info.Mode()&os.ModeSymlink != 0 { return err } - if isToSkip := skip(path); !isToSkip { - return copyByType(src, dst, path, info) + logger.LogTraceWithLevel(fmt.Sprintf("Copying src: %s dst: %s path: %s", src, dst, path)) + if info.IsDir() { + return copyDir(src, dst, path) } - return nil - }) -} - -func copyByType(src, dst, path string, info os.FileInfo) error { - logger.LogTraceWithLevel("Copying ", "src: "+src, "dst: "+dst, "path: "+path) - switch { - case info.IsDir(): - return copyDir(src, dst, path) - case info.Mode()&os.ModeSymlink != 0: - return copyLink(src, dst, path) - default: return copyFile(src, dst, path) - } + }) } func copyFile(src, dst, path string) error { file, err := os.Create(replacePathSrcToDst(path, src, dst)) - if file != nil { - defer func() { - logger.LogError("Error defer file close", file.Close()) - }() - } if err != nil { return err } + defer func() { + logger.LogError("Error defer file close", file.Close()) + }() return copyContentSrcFileToDstFile(path, file) } @@ -69,14 +60,12 @@ func replacePathSrcToDst(path, src, dst string) string { func copyContentSrcFileToDstFile(srcPath string, dstFile *os.File) error { srcFile, err := os.Open(srcPath) - if srcFile != nil { - defer func() { - logger.LogError("Error defer file close", srcFile.Close()) - }() - } if err != nil { return err } + defer func() { + logger.LogError("Error defer file close", srcFile.Close()) + }() _, err = io.Copy(dstFile, srcFile) return err @@ -86,17 +75,3 @@ func copyDir(src, dst, path string) error { newPath := replacePathSrcToDst(path, src, dst) return os.MkdirAll(newPath, os.ModePerm) } - -func copyLink(src, dst, path string) error { - orig, err := filepath.EvalSymlinks(src) - if err != nil { - return err - } - - info, err := os.Lstat(orig) - if err != nil { - return err - } - - return copyByType(orig, dst, path, info) -} diff --git a/internal/utils/copy/copy_test.go b/internal/utils/copy/copy_test.go index 201c5bcf8..a996adbf1 100644 --- a/internal/utils/copy/copy_test.go +++ b/internal/utils/copy/copy_test.go @@ -12,31 +12,52 @@ // See the License for the specific language governing permissions and // limitations under the License. -package copy +package copy_test import ( - "fmt" "os" "path/filepath" "testing" + "github.com/ZupIT/horusec/internal/utils/copy" + "github.com/ZupIT/horusec/internal/utils/testutil" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCopy(t *testing.T) { - t.Run("Should success copy dir", func(t *testing.T) { - srcPath, err := filepath.Abs("../../../config") - assert.NoError(t, err) + src := testutil.GoExample1 + dst := filepath.Join(t.TempDir(), t.Name()) - dstPath, err := filepath.Abs(".") - assert.NoError(t, err) + tmpFile, err := os.CreateTemp(os.TempDir(), "test-symlink") + require.Nil(t, err, "Expected nil error to create temp file") - dstPath = fmt.Sprintf(dstPath+"%s", "/tmp-test") + symlinkFile := filepath.Join(src, "symlink") + err = os.Symlink(tmpFile.Name(), symlinkFile) + require.NoError(t, err, "Expected nil error to create symlink file: %v", err) - err = Copy(srcPath, dstPath, func(src string) bool { return false }) - assert.NoError(t, err) + t.Cleanup(func() { + err := os.Remove(symlinkFile) + assert.NoError(t, err, "Expected nil error to clean up symlink file: %v", err) + }) - err = os.RemoveAll(dstPath) - assert.NoError(t, err) + err = copy.Copy(src, dst, func(src string) bool { + ext := filepath.Ext(src) + return ext == ".mod" || ext == ".sum" }) + + assert.NoError(t, err) + + assert.DirExists(t, dst) + assert.DirExists(t, filepath.Join(dst, "api", "routes")) + assert.DirExists(t, filepath.Join(dst, "api", "util")) + + assert.NoFileExists(t, filepath.Join(dst, "symlink")) + assert.FileExists(t, filepath.Join(dst, "api", "server.go")) + assert.FileExists(t, filepath.Join(dst, "api", "routes", "healthcheck.go")) + assert.FileExists(t, filepath.Join(dst, "api", "util", "util.go")) + + assert.NoFileExists(t, filepath.Join(dst, "go.mod")) + assert.NoFileExists(t, filepath.Join(dst, "go.sum")) + }