Skip to content

Commit

Permalink
feat: allow to pass a custom fs
Browse files Browse the repository at this point in the history
closes #17
  • Loading branch information
caarlos0 committed Feb 3, 2025
1 parent 1dfb9e3 commit 4ce1454
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 14 deletions.
51 changes: 51 additions & 0 deletions fs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package keygen

import (
"io/fs"
"os"
)

// KeyFS is an interface that defines everything we need to read and write
// keys.
type KeyFS interface {
fs.ReadFileFS
fs.StatFS
Chmod(name string, mode fs.FileMode) error
MkdirAll(path string, perm fs.FileMode) error
WriteFile(path string, data []byte, perm fs.FileMode) error
}

var _ KeyFS = &RealFS{}

// RealFS is a KeyFS implementation that uses the real filesystem.
type RealFS struct{}

// WriteFile implements KeyFS.
func (n *RealFS) WriteFile(path string, data []byte, perm fs.FileMode) error {
return os.WriteFile(path, data, perm)
}

// MkdirAll implements KeyFS.
func (n *RealFS) MkdirAll(path string, perm fs.FileMode) error {
return os.MkdirAll(path, perm)
}

// Chmod implements KeyFS.
func (n *RealFS) Chmod(name string, mode fs.FileMode) error {
return os.Chmod(name, mode)
}

// Open implements KeyFS.
func (n *RealFS) Open(name string) (fs.File, error) {
return os.Open(name)
}

// ReadFile implements KeyFS.
func (n *RealFS) ReadFile(name string) ([]byte, error) {
return os.ReadFile(name)
}

// Stat implements KeyFS.
func (n *RealFS) Stat(name string) (fs.FileInfo, error) {
return os.Stat(name)
}
37 changes: 23 additions & 14 deletions keygen.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ type SSHKeyPair = KeyPair
type KeyPair struct {
path string // private key filename path; public key will have .pub appended
writeKeys bool
fs KeyFS
passphrase []byte
rsaBitSize int
ec elliptic.Curve
Expand Down Expand Up @@ -138,6 +139,13 @@ func WithWrite() Option {
}
}

// WithFS allows to set a different KeyFS implementation.
func WithFS(fs KeyFS) Option {
return func(s *KeyPair) {
s.fs = fs
}
}

// WithEllipticCurve sets the elliptic curve for the ECDSA key pair.
// Supported curves are P-256, P-384, and P-521.
// The default curve is P-384.
Expand All @@ -160,6 +168,7 @@ func New(path string, opts ...Option) (*KeyPair, error) {
rsaBitSize: rsaDefaultBits,
ec: elliptic.P384(),
keyType: Ed25519,
fs: &RealFS{},
}

for _, opt := range opts {
Expand All @@ -174,7 +183,7 @@ func New(path string, opts ...Option) (*KeyPair, error) {
}

if s.KeyPairExists() {
privData, err := os.ReadFile(s.privateKeyPath())
privData, err := s.fs.ReadFile(s.privateKeyPath())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -420,10 +429,10 @@ func (s *KeyPair) prepFilesystem() error {
return err
}

info, err := os.Stat(keyDir)
info, err := s.fs.Stat(keyDir)
if os.IsNotExist(err) {
// Directory doesn't exist: create it
return os.MkdirAll(keyDir, 0o700)
return s.fs.MkdirAll(keyDir, 0o700)
}
if err != nil {
// There was another error statting the directory; something is awry
Expand All @@ -435,17 +444,17 @@ func (s *KeyPair) prepFilesystem() error {
}
if info.Mode().Perm() != 0o700 {
// Permissions are wrong: fix 'em
if err := os.Chmod(keyDir, 0o700); err != nil {
if err := s.fs.Chmod(keyDir, 0o700); err != nil {
return FilesystemErr{Err: err}
}
}
}

// Make sure the files we're going to write to don't already exist
if fileExists(s.privateKeyPath()) {
if fileExists(s.fs, s.privateKeyPath()) {
return SSHKeysAlreadyExistErr{Path: s.privateKeyPath()}
}
if fileExists(s.publicKeyPath()) {
if fileExists(s.fs, s.publicKeyPath()) {
return SSHKeysAlreadyExistErr{Path: s.publicKeyPath()}
}

Expand All @@ -465,7 +474,7 @@ func (s *KeyPair) WriteKeys() error {
return err
}

if err := writeKeyToFile(priv, s.privateKeyPath()); err != nil {
if err := writeKeyToFile(s.fs, priv, s.privateKeyPath()); err != nil {
return err
}

Expand All @@ -474,23 +483,23 @@ func (s *KeyPair) WriteKeys() error {
ak = fmt.Sprintf("%s %s", ak, memo)
}

return writeKeyToFile([]byte(ak), s.publicKeyPath())
return writeKeyToFile(s.fs, []byte(ak), s.publicKeyPath())
}

// KeyPairExists checks if the SSH key pair exists on disk.
func (s *KeyPair) KeyPairExists() bool {
return fileExists(s.privateKeyPath())
return fileExists(s.fs, s.privateKeyPath())
}

func writeKeyToFile(keyBytes []byte, path string) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
return os.WriteFile(path, keyBytes, 0o600)
func writeKeyToFile(fs KeyFS, keyBytes []byte, path string) error {
if _, err := fs.Stat(path); os.IsNotExist(err) {
return fs.WriteFile(path, keyBytes, 0o600)
}
return FilesystemErr{Err: fmt.Errorf("file %s already exists", path)}
}

func fileExists(path string) bool {
_, err := os.Stat(path)
func fileExists(fs KeyFS, path string) bool {
_, err := fs.Stat(path)
if os.IsNotExist(err) {
return false
}
Expand Down
2 changes: 2 additions & 0 deletions keygen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ func TestGenerateEd25519Keys(t *testing.T) {
k := &KeyPair{
path: filepath.Join(dir, filename),
keyType: Ed25519,
fs: &RealFS{},
}

t.Run("test generate SSH keys", func(t *testing.T) {
Expand Down Expand Up @@ -167,6 +168,7 @@ func TestGenerateECDSAKeys(t *testing.T) {
path: filepath.Join(dir, filename),
keyType: ECDSA,
ec: elliptic.P384(),
fs: &RealFS{},
}

t.Run("test generate SSH keys", func(t *testing.T) {
Expand Down

0 comments on commit 4ce1454

Please sign in to comment.