Skip to content

Commit

Permalink
refactor(paths): centralize path logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nikoksr committed Dec 26, 2022
1 parent 6f6a79f commit 917444f
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 25 deletions.
16 changes: 12 additions & 4 deletions internal/cli/proji/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,15 @@ func buildProject(ctx context.Context, project *domain.ProjectAdd) error {

// Get package manager from session
logger.Debug("getting package manager from cli session")
pama := cli.SessionFromContext(ctx).PackageManager
session := cli.SessionFromContext(ctx)

// Get mandatory paths for finding plugins and templates in the filesystem
config := session.Config
pluginsDir := config.PluginsDir()
templatesDir := config.TemplatesDir()

// Get package manager from session
pama := session.PackageManager
if pama == nil {
return errors.New("no package manager found")
}
Expand Down Expand Up @@ -134,7 +142,7 @@ func buildProject(ctx context.Context, project *domain.ProjectAdd) error {
}

if !filepath.IsAbs(path) {
path = filepath.Join("/home/niko/.config/proji/plugins", path)
path = filepath.Join(pluginsDir, path)
}

logger.Infof("Running plugin %q", filepath.Base(path))
Expand Down Expand Up @@ -172,7 +180,7 @@ func buildProject(ctx context.Context, project *domain.ProjectAdd) error {

// Check if template path is absolute
if !filepath.IsAbs(tmplPath) {
tmplPath = filepath.Join("/home/niko/.config/proji/templates", tmplPath)
tmplPath = filepath.Join(templatesDir, tmplPath)
}

// Read template from filesystem
Expand Down Expand Up @@ -222,7 +230,7 @@ func buildProject(ctx context.Context, project *domain.ProjectAdd) error {
}

if !filepath.IsAbs(path) {
path = filepath.Join("/home/niko/.config/proji/plugins", path)
path = filepath.Join(pluginsDir, path)
}

logger.Infof("Running plugin %q", filepath.Base(path))
Expand Down
11 changes: 10 additions & 1 deletion internal/cli/proji/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,16 @@ func rootCommand() *cobra.Command {
}

// Create package manager
pama, err := manager.NewPackageManager(ctx, serverAddress, db, &conf.Auth)
pama, err := manager.NewPackageManager(ctx, manager.Config{
Address: serverAddress,
DB: db,
Auth: &conf.Auth,
LocalPaths: &manager.LocalPaths{
Base: conf.BaseDir(),
Templates: conf.TemplatesDir(),
Plugins: conf.PluginsDir(),
},
})
if err != nil {
return errors.Wrap(err, "setup package manager")
}
Expand Down
37 changes: 26 additions & 11 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ func newProvider(path string) *viper.Viper {
provider := viper.New()

// Allow for cross-platform paths
path = filepath.Clean(path)
path = filepath.FromSlash(path)
dir := filepath.Dir(path)

Expand All @@ -158,8 +159,8 @@ func (conf *Config) setupInfrastructure() error {
return errors.New("config provider is nil")
}

// Get directory for config file and make sure it's cross-platform compatible
configPath := filepath.FromSlash(conf.provider.ConfigFileUsed())
// Get directory for config file
configPath := conf.provider.ConfigFileUsed()
baseDir := filepath.Dir(configPath)

// Create subdirectories; this also implicitly creates the base directory
Expand Down Expand Up @@ -236,6 +237,14 @@ func (conf *Config) readFlags(cmdFlags *pflag.FlagSet) error {
func load(ctx context.Context, path string, flags *pflag.FlagSet) (conf *Config, err error) {
logger := simplog.FromContext(ctx)

// Clean up path
path = filepath.Clean(path)

path, err = filepath.Abs(path)
if err != nil {
return nil, errors.Wrap(err, "get absolute config path")
}

// If no explicit path is given, use default path
if path == "" {
path, err = defaultConfigPath()
Expand All @@ -246,15 +255,6 @@ func load(ctx context.Context, path string, flags *pflag.FlagSet) (conf *Config,
logger.Debugf("no explicit config path given, using default path: %q", path)
}

// Make sure the path is cross-platform compatible
path = filepath.FromSlash(path)

// Make config path absolute
path, err = filepath.Abs(path)
if err != nil {
return nil, errors.Wrap(err, "get absolute config path")
}

// Create default config
logger.Debugf("creating config provider with path: %q", path)
conf = &Config{
Expand Down Expand Up @@ -321,3 +321,18 @@ func (conf *Config) Validate() error {

return nil
}

// BaseDir returns the base directory of the configuration file.
func (conf *Config) BaseDir() string {
return filepath.Dir(conf.provider.ConfigFileUsed())
}

// PluginsDir returns the plugins' directory.
func (conf *Config) PluginsDir() string {
return filepath.Join(conf.BaseDir(), defaultPluginsDir)
}

// TemplatesDir returns the templates' directory.
func (conf *Config) TemplatesDir() string {
return filepath.Join(conf.BaseDir(), defaultTemplatesDir)
}
50 changes: 50 additions & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,53 @@ func Test_defaultConfigPath(t *testing.T) {
}
}
}

func TestConfig_BaseDir(t *testing.T) {
t.Parallel()

type args struct {
path string
}
cases := []struct {
name string
args args
want string
}{
{
name: "default",
args: args{
path: "/home/user/.config/proji.toml",
},
want: filepath.Join("/home", "user", ".config"),
},
{
name: "complex",
args: args{
path: "/home/user/.config/proji/this/is/a/very/long/path/to/a/config/file.toml",
},
want: filepath.Join("/home", "user", ".config", "proji", "this", "is", "a", "very", "long", "path", "to", "a", "config"),
},
{
name: "current dir",
args: args{
path: "./proji.toml",
},
want: ".",
},
// TODO: Add Windows tests
}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

conf := &Config{
provider: newProvider(tc.args.path),
}
if got := conf.BaseDir(); got != filepath.FromSlash(tc.want) {
t.Fatalf("BaseDir() returned unexpected path: %q;\nbase directory of %q should be %q", got, tc.args.path, tc.want)
}
})
}
}
43 changes: 34 additions & 9 deletions internal/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,51 @@ import (
// TODO: Should probably be configurable.
const defaultServiceTimeout = 5 * time.Second

type LocalPaths struct {
Base string
Plugins string
Templates string
}

// Config is a package manager configuration. This config is shared between different types of package managers.
type Config struct {
// Address is the address of the remote package manager. If empty, the local package manager will be used.
Address string

// Auth contains the authentication information for the remote package manager.
Auth *config.Auth

// DB is the database connection.
DB *database.DB

// LocalPaths contains local filesystem paths that point to the base directory, the plugins/ directory and the
// templates/ directory. These paths are used by the local package manager to persist packages and templates.
LocalPaths *LocalPaths
}

// NewPackageManager is a convenience function that connects to a package manager based on the given address. If the
// address is empty, it will connect to the local package manager. Otherwise, it will connect to the remote package
// manager.
func NewPackageManager(ctx context.Context, address string, db *database.DB, auth *config.Auth) (packages.Manager, error) {
logger := simplog.FromContext(ctx)
func NewPackageManager(ctx context.Context, config Config) (packages.Manager, error) {
if config == (Config{}) {
return nil, errors.New("config is required")
}

// If an address is given, interpret that as an intent to connect to a remote package manager.
logger := simplog.FromContext(ctx)
logger.Debugf("creating a package manager")

address = strings.TrimSpace(address)
if address != "" {
logger.Debugf("server address not empty, connecting to remote package manager at %s", address)
// If an address is given, interpret that as an intent to connect to a remote package manager.
config.Address = strings.TrimSpace(config.Address)
if config.Address != "" {
logger.Debugf("server address not empty, connecting to remote package manager at %q", config.Address)

return packages.NewRemoteManager(address)
return packages.NewRemoteManager(config.Address)
}

// Otherwise, connect to the local package manager.
logger.Debugf("server address is empty, creating a local package manager")

repo, err := packageRepo.New(db)
repo, err := packageRepo.New(config.DB)
if err != nil {
return nil, errors.Wrap(err, "create package repository")
}
Expand All @@ -52,7 +77,7 @@ func NewPackageManager(ctx context.Context, address string, db *database.DB, aut
}

// Create the local package manager.
return packages.NewLocalManager(auth, service)
return packages.NewLocalManager(config.Auth, service)
}

// NewProjectManager returns a new project manager. Compared to the package manager, the project manager is always local,
Expand Down

0 comments on commit 917444f

Please sign in to comment.