diff --git a/pkg/crdutil/crdutil.go b/pkg/crdutil/crdutil.go index 7455875..6f06eee 100644 --- a/pkg/crdutil/crdutil.go +++ b/pkg/crdutil/crdutil.go @@ -39,23 +39,21 @@ import ( ) var ( - filenames []string - recursive bool + crdsDir []string ) func initFlags() { pflag.CommandLine.AddGoFlagSet(flag.CommandLine) - pflag.StringSliceVarP(&filenames, "filename", "f", filenames, "The files that contain the configurations to apply.") - pflag.BoolVarP(&recursive, "recursive", "R", false, "Process the directory used in -f, --filename recursively.") + pflag.StringSliceVarP(&crdsDir, "crds-dir", "f", crdsDir, "The files or directories that contain the CRDs to apply.") pflag.Parse() - if len(filenames) == 0 { + if len(crdsDir) == 0 { log.Fatalf("CRDs directory or single CRDs are required") } - for _, crdDir := range filenames { - if _, err := os.Stat(crdDir); os.IsNotExist(err) { - log.Fatalf("CRDs directory %s does not exist", filenames) + for _, path := range crdsDir { + if _, err := os.Stat(path); os.IsNotExist(err) { + log.Fatalf("path does not exist: %s", path) } } } @@ -78,30 +76,32 @@ func EnsureCRDsCmd() { log.Fatalf("Failed to create API extensions client: %v", err) } - dirsToApply, err := walkCRDs(recursive, filenames) + pathsToApply, err := collectYamlPaths(crdsDir) if err != nil { log.Fatalf("Failed to walk through CRDs: %v", err) } - for _, dir := range dirsToApply { - log.Printf("Apply CRDs from file: %s", dir) - if err := applyCRDs(ctx, client.ApiextensionsV1().CustomResourceDefinitions(), dir); err != nil { + for _, path := range pathsToApply { + log.Printf("Apply CRDs from file: %s", path) + if err := applyCRDs(ctx, client.ApiextensionsV1().CustomResourceDefinitions(), path); err != nil { log.Fatalf("Failed to apply CRDs: %v", err) } } } -// walkCRDs walks the CRDs directory and applies each YAML file. -// TODO: add unit test for this function. -func walkCRDs(recursive bool, crdDirs []string) ([]string, error) { - var dirs []string +// collectYamlPaths processes a list of paths and returns all YAML files. +func collectYamlPaths(crdDirs []string) ([]string, error) { + paths := map[string]struct{}{} for _, crdDir := range crdDirs { // We need the parent directory to check if we are in the top-level directory. // This is necessary for the recursive logic. // We can skip the errors as it has been checked in initFlags. - parentDir, _ := os.Stat(crdDir) + parentDir, err := os.Stat(crdDir) + if err != nil { + return []string{}, fmt.Errorf("stat the path %s: %w", crdDir, err) + } // Walk the directory recursively and apply each YAML file. - err := filepath.Walk(crdDir, func(path string, info os.FileInfo, err error) error { + err = filepath.Walk(crdDir, func(path string, info os.FileInfo, err error) error { if err != nil { return err } @@ -113,20 +113,30 @@ func walkCRDs(recursive bool, crdDirs []string) ([]string, error) { if filepath.Ext(path) != ".yaml" && filepath.Ext(path) != ".yml" { return nil } - // If not recursive we want to only apply the CRDs in the top-level directory. + // If we apply a dir we only want to apply the CRDs in the directory (i.e., not recursively). // filepath.Dir() does not add a trailing slash, thus we need to trim it in the crdDir. - if !recursive && parentDir.IsDir() && filepath.Dir(path) != strings.TrimRight(crdDir, "/") { + if parentDir.IsDir() && filepath.Dir(path) != strings.TrimRight(crdDir, "/") { return nil } - dirs = append(dirs, path) + paths[path] = struct{}{} return nil }) if err != nil { return []string{}, fmt.Errorf("walk the path %s: %w", crdDirs, err) } } - return dirs, nil + return mapToSlice(paths), nil +} + +// mapToSlice converts a map to a slice. +// The map is used to deduplicate the paths. +func mapToSlice(m map[string]struct{}) []string { + s := []string{} + for k := range m { + s = append(s, k) + } + return s } // applyCRDs reads a YAML file, splits it into documents, and applies each CRD to the cluster. @@ -183,7 +193,7 @@ func applyCRD( log.Printf("Create CRD %s", crd.Name) _, err = crdClient.Create(ctx, crd, metav1.CreateOptions{}) if err != nil { - return fmt.Errorf("create CRD %s: %w", crd.Name, err) + return fmt.Errorf("create CRD: %w", err) } return nil } diff --git a/pkg/crdutil/crdutil_test.go b/pkg/crdutil/crdutil_test.go index 3de6671..1212256 100644 --- a/pkg/crdutil/crdutil_test.go +++ b/pkg/crdutil/crdutil_test.go @@ -38,6 +38,41 @@ var _ = Describe("CRD Application", func() { Expect(testCRDClient.DeleteCollection(ctx, metav1.DeleteOptions{}, metav1.ListOptions{})).NotTo(HaveOccurred()) }) + Describe("collectYamlPaths", func() { + It("should collect all YAML files in a directory", func() { + By("collecting YAML paths") + paths, err := collectYamlPaths([]string{"test-files"}) + Expect(err).NotTo(HaveOccurred()) + Expect(paths).To(ConsistOf( + "test-files/test-crds.yaml", + "test-files/updated-test-crds.yaml", + )) + }) + + It("should collect a single YAML file", func() { + By("collecting YAML paths") + paths, err := collectYamlPaths([]string{"test-files/test-crds.yaml"}) + Expect(err).NotTo(HaveOccurred()) + Expect(paths).To(ConsistOf("test-files/test-crds.yaml")) + }) + + It("should deduplicate YAML file", func() { + By("collecting YAML paths") + paths, err := collectYamlPaths([]string{"test-files/test-crds.yaml", "test-files"}) + Expect(err).NotTo(HaveOccurred()) + Expect(paths).To(ConsistOf( + "test-files/test-crds.yaml", + "test-files/updated-test-crds.yaml", + )) + }) + + It("should fail to collect non-existent YAML files", func() { + By("collecting YAML paths") + _, err := collectYamlPaths([]string{"test-files/non-existent.yaml"}) + Expect(err).To(HaveOccurred()) + }) + }) + Describe("applyCRDs", func() { It("should apply CRDs multiple times from a valid YAML file", func() { By("applying CRDs") diff --git a/pkg/crdutil/test-files/nested/do-not-apply.yaml b/pkg/crdutil/test-files/nested/do-not-apply.yaml new file mode 100644 index 0000000..eaa7c28 --- /dev/null +++ b/pkg/crdutil/test-files/nested/do-not-apply.yaml @@ -0,0 +1,22 @@ +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + name: foos.example.com +spec: + group: example.com + names: + kind: Foo + listKind: FooList + singular: foo + plural: foos + scope: Namespaced + versions: + - name: v1 + served: true + storage: true + schema: + openAPIV3Schema: + type: object + properties: + spec: + type: object \ No newline at end of file