diff --git a/internal/cli/proji/new.go b/internal/cli/proji/new.go index 1beca7b..8e83b5a 100644 --- a/internal/cli/proji/new.go +++ b/internal/cli/proji/new.go @@ -74,7 +74,15 @@ func buildProject(ctx context.Context, project *domain.ProjectAdd) error { // Get package manager from session logger.Debug("getting package manager from cli session") - pama := cli.SessionFromContext(ctx).PackageManager + session := cli.SessionFromContext(ctx) + + // Get mandatory paths for finding plugins and templates in the filesystem + config := session.Config + pluginsDir := config.PluginsDir() + templatesDir := config.TemplatesDir() + + // Get package manager from session + pama := session.PackageManager if pama == nil { return errors.New("no package manager found") } @@ -134,7 +142,7 @@ func buildProject(ctx context.Context, project *domain.ProjectAdd) error { } if !filepath.IsAbs(path) { - path = filepath.Join("/home/niko/.config/proji/plugins", path) + path = filepath.Join(pluginsDir, path) } logger.Infof("Running plugin %q", filepath.Base(path)) @@ -172,7 +180,7 @@ func buildProject(ctx context.Context, project *domain.ProjectAdd) error { // Check if template path is absolute if !filepath.IsAbs(tmplPath) { - tmplPath = filepath.Join("/home/niko/.config/proji/templates", tmplPath) + tmplPath = filepath.Join(templatesDir, tmplPath) } // Read template from filesystem @@ -222,7 +230,7 @@ func buildProject(ctx context.Context, project *domain.ProjectAdd) error { } if !filepath.IsAbs(path) { - path = filepath.Join("/home/niko/.config/proji/plugins", path) + path = filepath.Join(pluginsDir, path) } logger.Infof("Running plugin %q", filepath.Base(path)) diff --git a/internal/cli/proji/root.go b/internal/cli/proji/root.go index 9a7d1dc..85defc2 100644 --- a/internal/cli/proji/root.go +++ b/internal/cli/proji/root.go @@ -75,7 +75,16 @@ func rootCommand() *cobra.Command { } // Create package manager - pama, err := manager.NewPackageManager(ctx, serverAddress, db, &conf.Auth) + pama, err := manager.NewPackageManager(ctx, manager.Config{ + Address: serverAddress, + DB: db, + Auth: &conf.Auth, + LocalPaths: &manager.LocalPaths{ + Base: conf.BaseDir(), + Templates: conf.TemplatesDir(), + Plugins: conf.PluginsDir(), + }, + }) if err != nil { return errors.Wrap(err, "setup package manager") } diff --git a/internal/config/config.go b/internal/config/config.go index 6753323..4de68e4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -138,6 +138,7 @@ func newProvider(path string) *viper.Viper { provider := viper.New() // Allow for cross-platform paths + path = filepath.Clean(path) path = filepath.FromSlash(path) dir := filepath.Dir(path) @@ -158,8 +159,8 @@ func (conf *Config) setupInfrastructure() error { return errors.New("config provider is nil") } - // Get directory for config file and make sure it's cross-platform compatible - configPath := filepath.FromSlash(conf.provider.ConfigFileUsed()) + // Get directory for config file + configPath := conf.provider.ConfigFileUsed() baseDir := filepath.Dir(configPath) // Create subdirectories; this also implicitly creates the base directory @@ -236,6 +237,14 @@ func (conf *Config) readFlags(cmdFlags *pflag.FlagSet) error { func load(ctx context.Context, path string, flags *pflag.FlagSet) (conf *Config, err error) { logger := simplog.FromContext(ctx) + // Clean up path + path = filepath.Clean(path) + + path, err = filepath.Abs(path) + if err != nil { + return nil, errors.Wrap(err, "get absolute config path") + } + // If no explicit path is given, use default path if path == "" { path, err = defaultConfigPath() @@ -246,15 +255,6 @@ func load(ctx context.Context, path string, flags *pflag.FlagSet) (conf *Config, logger.Debugf("no explicit config path given, using default path: %q", path) } - // Make sure the path is cross-platform compatible - path = filepath.FromSlash(path) - - // Make config path absolute - path, err = filepath.Abs(path) - if err != nil { - return nil, errors.Wrap(err, "get absolute config path") - } - // Create default config logger.Debugf("creating config provider with path: %q", path) conf = &Config{ @@ -321,3 +321,18 @@ func (conf *Config) Validate() error { return nil } + +// BaseDir returns the base directory of the configuration file. +func (conf *Config) BaseDir() string { + return filepath.Dir(conf.provider.ConfigFileUsed()) +} + +// PluginsDir returns the plugins' directory. +func (conf *Config) PluginsDir() string { + return filepath.Join(conf.BaseDir(), defaultPluginsDir) +} + +// TemplatesDir returns the templates' directory. +func (conf *Config) TemplatesDir() string { + return filepath.Join(conf.BaseDir(), defaultTemplatesDir) +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 2cde271..e8ecf1f 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -223,3 +223,53 @@ func Test_defaultConfigPath(t *testing.T) { } } } + +func TestConfig_BaseDir(t *testing.T) { + t.Parallel() + + type args struct { + path string + } + cases := []struct { + name string + args args + want string + }{ + { + name: "default", + args: args{ + path: "/home/user/.config/proji.toml", + }, + want: filepath.Join("/home", "user", ".config"), + }, + { + name: "complex", + args: args{ + path: "/home/user/.config/proji/this/is/a/very/long/path/to/a/config/file.toml", + }, + want: filepath.Join("/home", "user", ".config", "proji", "this", "is", "a", "very", "long", "path", "to", "a", "config"), + }, + { + name: "current dir", + args: args{ + path: "./proji.toml", + }, + want: ".", + }, + // TODO: Add Windows tests + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + conf := &Config{ + provider: newProvider(tc.args.path), + } + if got := conf.BaseDir(); got != filepath.FromSlash(tc.want) { + t.Fatalf("BaseDir() returned unexpected path: %q;\nbase directory of %q should be %q", got, tc.args.path, tc.want) + } + }) + } +} diff --git a/internal/manager/manager.go b/internal/manager/manager.go index 06afb5d..5f30267 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -22,26 +22,51 @@ import ( // TODO: Should probably be configurable. const defaultServiceTimeout = 5 * time.Second +type LocalPaths struct { + Base string + Plugins string + Templates string +} + +// Config is a package manager configuration. This config is shared between different types of package managers. +type Config struct { + // Address is the address of the remote package manager. If empty, the local package manager will be used. + Address string + + // Auth contains the authentication information for the remote package manager. + Auth *config.Auth + + // DB is the database connection. + DB *database.DB + + // LocalPaths contains local filesystem paths that point to the base directory, the plugins/ directory and the + // templates/ directory. These paths are used by the local package manager to persist packages and templates. + LocalPaths *LocalPaths +} + // NewPackageManager is a convenience function that connects to a package manager based on the given address. If the // address is empty, it will connect to the local package manager. Otherwise, it will connect to the remote package // manager. -func NewPackageManager(ctx context.Context, address string, db *database.DB, auth *config.Auth) (packages.Manager, error) { - logger := simplog.FromContext(ctx) +func NewPackageManager(ctx context.Context, config Config) (packages.Manager, error) { + if config == (Config{}) { + return nil, errors.New("config is required") + } - // If an address is given, interpret that as an intent to connect to a remote package manager. + logger := simplog.FromContext(ctx) logger.Debugf("creating a package manager") - address = strings.TrimSpace(address) - if address != "" { - logger.Debugf("server address not empty, connecting to remote package manager at %s", address) + // If an address is given, interpret that as an intent to connect to a remote package manager. + config.Address = strings.TrimSpace(config.Address) + if config.Address != "" { + logger.Debugf("server address not empty, connecting to remote package manager at %q", config.Address) - return packages.NewRemoteManager(address) + return packages.NewRemoteManager(config.Address) } // Otherwise, connect to the local package manager. logger.Debugf("server address is empty, creating a local package manager") - repo, err := packageRepo.New(db) + repo, err := packageRepo.New(config.DB) if err != nil { return nil, errors.Wrap(err, "create package repository") } @@ -52,7 +77,7 @@ func NewPackageManager(ctx context.Context, address string, db *database.DB, aut } // Create the local package manager. - return packages.NewLocalManager(auth, service) + return packages.NewLocalManager(config.Auth, service) } // NewProjectManager returns a new project manager. Compared to the package manager, the project manager is always local,