diff --git a/cmd/chains.go b/cmd/chains.go index 83d2ac5..61df9b2 100644 --- a/cmd/chains.go +++ b/cmd/chains.go @@ -84,7 +84,7 @@ $ %s chains list $ %s ch l`, appName, appName)), RunE: func(cmd *cobra.Command, args []string) error { if app.Config == nil { - return fmt.Errorf("config does not exist: %s", app.HomePath) + return fmt.Errorf("config is not initialized") } i := 1 diff --git a/cmd/config.go b/cmd/config.go index 787ccb2..373869e 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -2,12 +2,14 @@ package cmd import ( "fmt" + "os" "strings" "github.com/pelletier/go-toml/v2" "github.com/spf13/cobra" "github.com/bandprotocol/falcon/relayer" + "github.com/bandprotocol/falcon/relayer/config" ) // configCmd returns a command that manages global configuration file @@ -37,7 +39,7 @@ $ %s config show --home %s $ %s cfg list`, appName, defaultHome, appName)), RunE: func(cmd *cobra.Command, args []string) error { if app.Config == nil { - return fmt.Errorf("config does not exist: %s", app.HomePath) + return fmt.Errorf("config is not initialized") } b, err := toml.Marshal(app.Config) @@ -63,23 +65,33 @@ func configInitCmd(app *relayer.App) *cobra.Command { $ %s config init --home %s $ %s cfg i`, appName, defaultHome, appName)), RunE: func(cmd *cobra.Command, args []string) error { - home, err := cmd.Flags().GetString(flagHome) + filePath, err := cmd.Flags().GetString(flagFile) if err != nil { return err } - file, err := cmd.Flags().GetString(flagFile) - if err != nil { - return err + // Parse the config from the file if file's path is given + var cfg *config.Config + if filePath != "" { + b, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("cannot read a config file %s: %w", filePath, err) + } + + cfg, err = config.ParseConfig(b) + if err != nil { + return fmt.Errorf("parsing config error %w", err) + } } - if err := app.InitConfigFile(home, file); err != nil { + if err := app.SaveConfig(cfg); err != nil { return err } - return app.InitPassphrase() + passphrase := os.Getenv(PassphraseEnvKey) + return app.SavePassphrase(passphrase) }, } - cmd.Flags().StringP(flagFile, "f", "", "fetch toml data from specified file") + cmd.Flags().StringP(flagFile, "f", "", "input config .toml file path") return cmd } diff --git a/cmd/config_test.go b/cmd/config_test.go index 4e303be..6a83aef 100644 --- a/cmd/config_test.go +++ b/cmd/config_test.go @@ -27,7 +27,7 @@ func TestConfigShowNotInit(t *testing.T) { sys := relayertest.NewSystem(t) res := sys.RunWithInput(t, "config", "show") - require.ErrorContains(t, res.Err, "config does not exist:") + require.ErrorContains(t, res.Err, "config is not initialized") } func TestConfigInitDefault(t *testing.T) { diff --git a/cmd/keys.go b/cmd/keys.go index 37112d8..1f1d6ed 100644 --- a/cmd/keys.go +++ b/cmd/keys.go @@ -10,6 +10,7 @@ import ( "github.com/spf13/cobra" "github.com/bandprotocol/falcon/relayer" + chainstypes "github.com/bandprotocol/falcon/relayer/chains/types" ) const ( @@ -99,21 +100,26 @@ $ %s k a eth test-key`, appName, appName)), } } - // Add the key to the app - keyOutput, err := app.AddKey( - chainName, - keyName, - input.Mnemonic, - input.PrivateKey, - uint32(input.CoinType), - uint(input.Account), - uint(input.Index), - ) - if err != nil { - return err + var key *chainstypes.Key + if input.PrivateKey != "" { + key, err = app.AddKeyByPrivateKey(chainName, keyName, input.PrivateKey) + if err != nil { + return err + } + } else { + key, err = app.AddKeyByMnemonic( + chainName, keyName, + input.Mnemonic, + uint32(input.CoinType), + uint(input.Account), + uint(input.Index), + ) + if err != nil { + return err + } } - out, err := json.MarshalIndent(keyOutput, "", " ") + out, err := json.MarshalIndent(key, "", " ") if err != nil { return err } diff --git a/cmd/root.go b/cmd/root.go index 79e20bd..9a88ef5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -32,7 +32,8 @@ var defaultHome = filepath.Join(os.Getenv("HOME"), ".falcon") // NewRootCmd returns the root command for falcon. func NewRootCmd(log *zap.Logger) *cobra.Command { passphrase := os.Getenv(PassphraseEnvKey) - app := falcon.NewApp(log, defaultHome, false, nil, passphrase, nil) + homePath := defaultHome + app := falcon.NewApp(log, false, nil, passphrase, nil) // RootCmd represents the base command when called without any subcommands rootCmd := &cobra.Command{ @@ -50,7 +51,7 @@ func NewRootCmd(log *zap.Logger) *cobra.Command { rootCmd.PersistentPreRunE = func(cmd *cobra.Command, _ []string) (err error) { // set up store - app.Store, err = store.NewFileSystem(app.HomePath) + app.Store, err = store.NewFileSystem(homePath) if err != nil { return err } @@ -85,7 +86,7 @@ func NewRootCmd(log *zap.Logger) *cobra.Command { } // Register --home flag - rootCmd.PersistentFlags().StringVar(&app.HomePath, flagHome, defaultHome, "set home directory") + rootCmd.PersistentFlags().StringVar(&homePath, flagHome, defaultHome, "set home directory") if err := viper.BindPFlag(flagHome, rootCmd.PersistentFlags().Lookup(flagHome)); err != nil { panic(err) } diff --git a/cmd/start.go b/cmd/start.go index 651bc6a..440960d 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/cobra" + "github.com/bandprotocol/falcon/internal/relayermetrics" "github.com/bandprotocol/falcon/relayer" ) @@ -31,12 +32,22 @@ $ %s start 1 12 # start relaying data from specific tunnelIDs.`, appName, a tunnelIDs = append(tunnelIDs, tunnelID) } - metricsListenAddrFlag, err := cmd.Flags().GetString(flagMetricsListenAddr) + metricsListenAddr, err := cmd.Flags().GetString(flagMetricsListenAddr) if err != nil { return err } - return app.Start(cmd.Context(), tunnelIDs, metricsListenAddrFlag) + // setup metrics server + if metricsListenAddr == "" { + metricsListenAddr = app.Config.Global.MetricsListenAddr + } + if metricsListenAddr != "" { + if err := relayermetrics.StartMetricsServer(cmd.Context(), app.Log, metricsListenAddr); err != nil { + return err + } + } + + return app.Start(cmd.Context(), tunnelIDs) }, } diff --git a/internal/relayertest/mocks/chain_provider.go b/internal/relayertest/mocks/chain_provider.go index 32acc69..c9cca69 100644 --- a/internal/relayertest/mocks/chain_provider.go +++ b/internal/relayertest/mocks/chain_provider.go @@ -43,19 +43,34 @@ func (m *MockChainProvider) EXPECT() *MockChainProviderMockRecorder { return m.recorder } -// AddKey mocks base method. -func (m *MockChainProvider) AddKey(keyName, mnemonic, privateKeyHex string, coinType uint32, account, index uint) (*types0.Key, error) { +// AddKeyByMnemonic mocks base method. +func (m *MockChainProvider) AddKeyByMnemonic(keyName, mnemonic string, coinType uint32, account, index uint) (*types0.Key, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddKey", keyName, mnemonic, privateKeyHex, coinType, account, index) + ret := m.ctrl.Call(m, "AddKeyByMnemonic", keyName, mnemonic, coinType, account, index) ret0, _ := ret[0].(*types0.Key) ret1, _ := ret[1].(error) return ret0, ret1 } -// AddKey indicates an expected call of AddKey. -func (mr *MockChainProviderMockRecorder) AddKey(keyName, mnemonic, privateKeyHex, coinType, account, index any) *gomock.Call { +// AddKeyByMnemonic indicates an expected call of AddKeyByMnemonic. +func (mr *MockChainProviderMockRecorder) AddKeyByMnemonic(keyName, mnemonic, coinType, account, index any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddKey", reflect.TypeOf((*MockChainProvider)(nil).AddKey), keyName, mnemonic, privateKeyHex, coinType, account, index) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddKeyByMnemonic", reflect.TypeOf((*MockChainProvider)(nil).AddKeyByMnemonic), keyName, mnemonic, coinType, account, index) +} + +// AddKeyByPrivateKey mocks base method. +func (m *MockChainProvider) AddKeyByPrivateKey(keyName, privateKeyHex string) (*types0.Key, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddKeyByPrivateKey", keyName, privateKeyHex) + ret0, _ := ret[0].(*types0.Key) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddKeyByPrivateKey indicates an expected call of AddKeyByPrivateKey. +func (mr *MockChainProviderMockRecorder) AddKeyByPrivateKey(keyName, privateKeyHex any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddKeyByPrivateKey", reflect.TypeOf((*MockChainProvider)(nil).AddKeyByPrivateKey), keyName, privateKeyHex) } // DeleteKey mocks base method. @@ -101,20 +116,6 @@ func (mr *MockChainProviderMockRecorder) Init(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockChainProvider)(nil).Init), ctx) } -// IsKeyNameExist mocks base method. -func (m *MockChainProvider) IsKeyNameExist(keyName string) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsKeyNameExist", keyName) - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsKeyNameExist indicates an expected call of IsKeyNameExist. -func (mr *MockChainProviderMockRecorder) IsKeyNameExist(keyName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsKeyNameExist", reflect.TypeOf((*MockChainProvider)(nil).IsKeyNameExist), keyName) -} - // ListKeys mocks base method. func (m *MockChainProvider) ListKeys() []*types0.Key { m.ctrl.T.Helper() @@ -226,19 +227,34 @@ func (m *MockKeyProvider) EXPECT() *MockKeyProviderMockRecorder { return m.recorder } -// AddKey mocks base method. -func (m *MockKeyProvider) AddKey(keyName, mnemonic, privateKeyHex string, coinType uint32, account, index uint) (*types0.Key, error) { +// AddKeyByMnemonic mocks base method. +func (m *MockKeyProvider) AddKeyByMnemonic(keyName, mnemonic string, coinType uint32, account, index uint) (*types0.Key, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddKey", keyName, mnemonic, privateKeyHex, coinType, account, index) + ret := m.ctrl.Call(m, "AddKeyByMnemonic", keyName, mnemonic, coinType, account, index) ret0, _ := ret[0].(*types0.Key) ret1, _ := ret[1].(error) return ret0, ret1 } -// AddKey indicates an expected call of AddKey. -func (mr *MockKeyProviderMockRecorder) AddKey(keyName, mnemonic, privateKeyHex, coinType, account, index any) *gomock.Call { +// AddKeyByMnemonic indicates an expected call of AddKeyByMnemonic. +func (mr *MockKeyProviderMockRecorder) AddKeyByMnemonic(keyName, mnemonic, coinType, account, index any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddKey", reflect.TypeOf((*MockKeyProvider)(nil).AddKey), keyName, mnemonic, privateKeyHex, coinType, account, index) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddKeyByMnemonic", reflect.TypeOf((*MockKeyProvider)(nil).AddKeyByMnemonic), keyName, mnemonic, coinType, account, index) +} + +// AddKeyByPrivateKey mocks base method. +func (m *MockKeyProvider) AddKeyByPrivateKey(keyName, privateKeyHex string) (*types0.Key, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddKeyByPrivateKey", keyName, privateKeyHex) + ret0, _ := ret[0].(*types0.Key) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddKeyByPrivateKey indicates an expected call of AddKeyByPrivateKey. +func (mr *MockKeyProviderMockRecorder) AddKeyByPrivateKey(keyName, privateKeyHex any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddKeyByPrivateKey", reflect.TypeOf((*MockKeyProvider)(nil).AddKeyByPrivateKey), keyName, privateKeyHex) } // DeleteKey mocks base method. @@ -270,20 +286,6 @@ func (mr *MockKeyProviderMockRecorder) ExportPrivateKey(keyName any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExportPrivateKey", reflect.TypeOf((*MockKeyProvider)(nil).ExportPrivateKey), keyName) } -// IsKeyNameExist mocks base method. -func (m *MockKeyProvider) IsKeyNameExist(keyName string) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsKeyNameExist", keyName) - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsKeyNameExist indicates an expected call of IsKeyNameExist. -func (mr *MockKeyProviderMockRecorder) IsKeyNameExist(keyName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsKeyNameExist", reflect.TypeOf((*MockKeyProvider)(nil).IsKeyNameExist), keyName) -} - // ListKeys mocks base method. func (m *MockKeyProvider) ListKeys() []*types0.Key { m.ctrl.T.Helper() diff --git a/internal/relayertest/mocks/chain_provider_config.go b/internal/relayertest/mocks/chain_provider_config.go index 573b809..9cfa2d6 100644 --- a/internal/relayertest/mocks/chain_provider_config.go +++ b/internal/relayertest/mocks/chain_provider_config.go @@ -57,18 +57,18 @@ func (mr *MockChainProviderConfigMockRecorder) GetChainType() *gomock.Call { } // NewChainProvider mocks base method. -func (m *MockChainProviderConfig) NewChainProvider(chainName string, log *zap.Logger, homePath string, debug bool, wallet wallet.Wallet) (chains.ChainProvider, error) { +func (m *MockChainProviderConfig) NewChainProvider(chainName string, log *zap.Logger, debug bool, wallet wallet.Wallet) (chains.ChainProvider, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewChainProvider", chainName, log, homePath, debug, wallet) + ret := m.ctrl.Call(m, "NewChainProvider", chainName, log, debug, wallet) ret0, _ := ret[0].(chains.ChainProvider) ret1, _ := ret[1].(error) return ret0, ret1 } // NewChainProvider indicates an expected call of NewChainProvider. -func (mr *MockChainProviderConfigMockRecorder) NewChainProvider(chainName, log, homePath, debug, wallet any) *gomock.Call { +func (mr *MockChainProviderConfigMockRecorder) NewChainProvider(chainName, log, debug, wallet any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewChainProvider", reflect.TypeOf((*MockChainProviderConfig)(nil).NewChainProvider), chainName, log, homePath, debug, wallet) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewChainProvider", reflect.TypeOf((*MockChainProviderConfig)(nil).NewChainProvider), chainName, log, debug, wallet) } // Validate mocks base method. diff --git a/internal/relayertest/mocks/store.go b/internal/relayertest/mocks/store.go index 8af0f94..6d0ae7d 100644 --- a/internal/relayertest/mocks/store.go +++ b/internal/relayertest/mocks/store.go @@ -116,16 +116,30 @@ func (mr *MockStoreMockRecorder) SaveConfig(cfg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveConfig", reflect.TypeOf((*MockStore)(nil).SaveConfig), cfg) } -// SaveHashedPassphrase mocks base method. -func (m *MockStore) SaveHashedPassphrase(hashedPassphrase []byte) error { +// SavePassphrase mocks base method. +func (m *MockStore) SavePassphrase(passphrase string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SaveHashedPassphrase", hashedPassphrase) + ret := m.ctrl.Call(m, "SavePassphrase", passphrase) ret0, _ := ret[0].(error) return ret0 } -// SaveHashedPassphrase indicates an expected call of SaveHashedPassphrase. -func (mr *MockStoreMockRecorder) SaveHashedPassphrase(hashedPassphrase any) *gomock.Call { +// SavePassphrase indicates an expected call of SavePassphrase. +func (mr *MockStoreMockRecorder) SavePassphrase(passphrase any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveHashedPassphrase", reflect.TypeOf((*MockStore)(nil).SaveHashedPassphrase), hashedPassphrase) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePassphrase", reflect.TypeOf((*MockStore)(nil).SavePassphrase), passphrase) +} + +// ValidatePassphrase mocks base method. +func (m *MockStore) ValidatePassphrase(passphrase string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidatePassphrase", passphrase) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidatePassphrase indicates an expected call of ValidatePassphrase. +func (mr *MockStoreMockRecorder) ValidatePassphrase(passphrase any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidatePassphrase", reflect.TypeOf((*MockStore)(nil).ValidatePassphrase), passphrase) } diff --git a/relayer/app.go b/relayer/app.go index 06367c6..aa2f331 100644 --- a/relayer/app.go +++ b/relayer/app.go @@ -1,16 +1,12 @@ package relayer import ( - "bytes" "context" - "crypto/sha256" "fmt" "math/big" - "os" "go.uber.org/zap" - "github.com/bandprotocol/falcon/internal/relayermetrics" "github.com/bandprotocol/falcon/relayer/band" bandtypes "github.com/bandprotocol/falcon/relayer/band/types" "github.com/bandprotocol/falcon/relayer/chains" @@ -20,19 +16,12 @@ import ( "github.com/bandprotocol/falcon/relayer/types" ) -const ( - ConfigFolderName = "config" - ConfigFileName = "config.toml" - PassphraseFileName = "passphrase.hash" -) - // App is the main application struct. type App struct { - Log *zap.Logger - HomePath string - Debug bool - Config *config.Config - Store store.Store + Log *zap.Logger + Debug bool + Config *config.Config + Store store.Store TargetChains chains.ChainProviders BandClient band.Client @@ -42,7 +31,6 @@ type App struct { // NewApp creates a new App instance. func NewApp( log *zap.Logger, - homePath string, debug bool, config *config.Config, passphrase string, @@ -50,7 +38,6 @@ func NewApp( ) *App { app := App{ Log: log, - HomePath: homePath, Debug: debug, Config: config, Store: store, @@ -106,7 +93,7 @@ func (a *App) initTargetChains() error { return err } - cp, err := chainConfig.NewChainProvider(chainName, a.Log, a.HomePath, a.Debug, wallet) + cp, err := chainConfig.NewChainProvider(chainName, a.Log, a.Debug, wallet) if err != nil { a.Log.Error("Cannot create chain provider", zap.Error(err), @@ -121,8 +108,8 @@ func (a *App) initTargetChains() error { return nil } -// InitConfigFile initializes the configuration to the given path. -func (a *App) InitConfigFile(homePath string, customFilePath string) error { +// SaveConfig saves the configuration into the application's store. +func (a *App) SaveConfig(cfg *config.Config) error { // Check if config already exists if ok, err := a.Store.HasConfig(); err != nil { return err @@ -130,34 +117,18 @@ func (a *App) InitConfigFile(homePath string, customFilePath string) error { return fmt.Errorf("config already exists") } - // Load config from given custom file path if exists - var cfg *config.Config - switch { - case customFilePath != "": - b, err := os.ReadFile(customFilePath) - if err != nil { - return fmt.Errorf("cannot read a config file %s: %w", customFilePath, err) - } - - cfg, err = config.ParseConfig(b) - if err != nil { - return fmt.Errorf("parsing config error %w", err) - } - default: + if cfg == nil { cfg = config.DefaultConfig() // Initialize with DefaultConfig if no file is provided } + a.Config = cfg return a.Store.SaveConfig(cfg) } -// InitPassphrase hashes the provided passphrase and saves it to the given path. -func (a *App) InitPassphrase() error { - // Load and hash the passphrase - h := sha256.New() - h.Write([]byte(a.Passphrase)) - hashedPassphrase := h.Sum(nil) - - return a.Store.SaveHashedPassphrase(hashedPassphrase) +// SavePassphrase hash the provided passphrase and save it into the application's store. +func (a *App) SavePassphrase(passphrase string) error { + a.Passphrase = passphrase + return a.Store.SavePassphrase(passphrase) } // QueryTunnelInfo queries tunnel information by given tunnel ID @@ -208,7 +179,7 @@ func (a *App) QueryTunnelPacketInfo(ctx context.Context, tunnelID uint64, sequen // AddChainConfig adds a new chain configuration to the config file. func (a *App) AddChainConfig(chainName string, filePath string) error { if a.Config == nil { - return fmt.Errorf("config does not exist: %s", a.HomePath) + return fmt.Errorf("config is not initialized") } if _, ok := a.Config.TargetChains[chainName]; ok { @@ -227,7 +198,7 @@ func (a *App) AddChainConfig(chainName string, filePath string) error { // DeleteChainConfig deletes the chain configuration from the config file. func (a *App) DeleteChainConfig(chainName string) error { if a.Config == nil { - return fmt.Errorf("config does not exist: %s", a.HomePath) + return fmt.Errorf("config is not initialized") } if _, ok := a.Config.TargetChains[chainName]; !ok { @@ -241,7 +212,7 @@ func (a *App) DeleteChainConfig(chainName string) error { // GetChainConfig retrieves the chain configuration by given chain name. func (a *App) GetChainConfig(chainName string) (chains.ChainProviderConfig, error) { if a.Config == nil { - return nil, fmt.Errorf("config does not exist: %s", a.HomePath) + return nil, fmt.Errorf("config is not initialized") } chainProviders := a.Config.TargetChains @@ -253,50 +224,50 @@ func (a *App) GetChainConfig(chainName string) (chains.ChainProviderConfig, erro return chainProviders[chainName], nil } -// AddKey adds a new key to the chain provider. -func (a *App) AddKey( +// AddKeyByPrivateKey adds a new key to the chain provider using a private key. +func (a *App) AddKeyByPrivateKey(chainName string, keyName string, privateKey string) (*chainstypes.Key, error) { + if err := a.Store.ValidatePassphrase(a.Passphrase); err != nil { + return nil, err + } + + cp, err := a.getChainProvider(chainName) + if err != nil { + return nil, err + } + + return cp.AddKeyByPrivateKey(keyName, privateKey) +} + +// AddKeyByMnemonic adds a new key to the chain provider using a mnemonic phrase. +func (a *App) AddKeyByMnemonic( chainName string, keyName string, mnemonic string, - privateKey string, coinType uint32, account uint, index uint, ) (*chainstypes.Key, error) { - if a.Config == nil { - return nil, fmt.Errorf("config does not exist: %s", a.HomePath) - } - - if err := a.ValidatePassphrase(a.Passphrase); err != nil { + if err := a.Store.ValidatePassphrase(a.Passphrase); err != nil { return nil, err } - cp, exist := a.TargetChains[chainName] - if !exist { - return nil, fmt.Errorf("chain name does not exist: %s", chainName) - } - - keyOutput, err := cp.AddKey(keyName, mnemonic, privateKey, coinType, account, index) + cp, err := a.getChainProvider(chainName) if err != nil { return nil, err } - return keyOutput, nil + return cp.AddKeyByMnemonic(keyName, mnemonic, coinType, account, index) } // DeleteKey deletes the key from the chain provider. func (a *App) DeleteKey(chainName string, keyName string) error { - if a.Config == nil { - return fmt.Errorf("config does not exist: %s", a.HomePath) - } - - if err := a.ValidatePassphrase(a.Passphrase); err != nil { + if err := a.Store.ValidatePassphrase(a.Passphrase); err != nil { return err } - cp, exist := a.TargetChains[chainName] - if !exist { - return fmt.Errorf("chain name does not exist: %s", chainName) + cp, err := a.getChainProvider(chainName) + if err != nil { + return err } return cp.DeleteKey(keyName) @@ -304,36 +275,23 @@ func (a *App) DeleteKey(chainName string, keyName string) error { // ExportKey exports the private key from the chain provider. func (a *App) ExportKey(chainName string, keyName string) (string, error) { - if a.Config == nil { - return "", fmt.Errorf("config does not exist: %s", a.HomePath) - } - - if err := a.ValidatePassphrase(a.Passphrase); err != nil { + if err := a.Store.ValidatePassphrase(a.Passphrase); err != nil { return "", err } - cp, exist := a.TargetChains[chainName] - if !exist { - return "", fmt.Errorf("chain name does not exist: %s", chainName) - } - - privateKey, err := cp.ExportPrivateKey(keyName) + cp, err := a.getChainProvider(chainName) if err != nil { return "", err } - return privateKey, nil + return cp.ExportPrivateKey(keyName) } // ListKeys retrieves the list of keys from the chain provider. func (a *App) ListKeys(chainName string) ([]*chainstypes.Key, error) { - if a.Config == nil { - return nil, fmt.Errorf("config does not exist: %s", a.HomePath) - } - - cp, exist := a.TargetChains[chainName] - if !exist { - return nil, fmt.Errorf("chain name does not exist: %s", chainName) + cp, err := a.getChainProvider(chainName) + if err != nil { + return nil, err } return cp.ListKeys(), nil @@ -341,13 +299,9 @@ func (a *App) ListKeys(chainName string) ([]*chainstypes.Key, error) { // ShowKey retrieves the key information from the chain provider. func (a *App) ShowKey(chainName string, keyName string) (string, error) { - if a.Config == nil { - return "", fmt.Errorf("config does not exist: %s", a.HomePath) - } - - cp, exist := a.TargetChains[chainName] - if !exist { - return "", fmt.Errorf("chain name does not exist: %s", chainName) + cp, err := a.getChainProvider(chainName) + if err != nil { + return "", err } return cp.ShowKey(keyName) @@ -355,59 +309,20 @@ func (a *App) ShowKey(chainName string, keyName string) (string, error) { // QueryBalance retrieves the balance of the key from the chain provider. func (a *App) QueryBalance(ctx context.Context, chainName string, keyName string) (*big.Int, error) { - if a.Config == nil { - return nil, fmt.Errorf("config does not exist: %s", a.HomePath) - } - - cp, exist := a.TargetChains[chainName] - - if !exist { - return nil, fmt.Errorf("chain name does not exist: %s", chainName) - } - - return cp.QueryBalance(ctx, keyName) -} - -// ValidatePassphrase checks if the provided passphrase (from the environment) -// matches the hashed passphrase stored on disk. -func (a *App) ValidatePassphrase(envPassphrase string) error { - // prepare bytes slices of hashed env passphrase - h := sha256.New() - h.Write([]byte(envPassphrase)) - hashedPassphrase := h.Sum(nil) - - // load passphrase from local disk - storedHashedPassphrase, err := a.Store.GetHashedPassphrase() + cp, err := a.getChainProvider(chainName) if err != nil { - return err - } - - if !bytes.Equal(hashedPassphrase, storedHashedPassphrase) { - return fmt.Errorf("invalid passphrase: the provided passphrase does not match the stored hashed passphrase") + return nil, err } - return nil + return cp.QueryBalance(ctx, keyName) } // Start starts the tunnel relayer program. -func (a *App) Start( - ctx context.Context, - tunnelIDs []uint64, - metricsListenAddrFlag string, -) error { +func (a *App) Start(ctx context.Context, tunnelIDs []uint64) error { a.Log.Info("Starting tunnel relayer") // validate passphrase - if err := a.ValidatePassphrase(a.Passphrase); err != nil { - return err - } - - // setup metrics server - metricsListenAddr := a.Config.Global.MetricsListenAddr - if metricsListenAddrFlag != "" { - metricsListenAddr = metricsListenAddrFlag - } - if err := a.setupMetricsServer(ctx, metricsListenAddr); err != nil { + if err := a.Store.ValidatePassphrase(a.Passphrase); err != nil { return err } @@ -451,7 +366,7 @@ func (a *App) Relay(ctx context.Context, tunnelID uint64) error { return err } - if err := a.ValidatePassphrase(a.Passphrase); err != nil { + if err := a.Store.ValidatePassphrase(a.Passphrase); err != nil { return err } @@ -482,18 +397,16 @@ func (a *App) Relay(ctx context.Context, tunnelID uint64) error { return err } -// setupMetricsServer starts the metrics server if enabled. -func (a *App) setupMetricsServer( - ctx context.Context, - metricsListenAddr string, -) error { - if metricsListenAddr == "" { - a.Log.Warn( - "Metrics server is disabled. It is controlled by the global config, and setting --metrics-listen-addr will override it and enable the server.", - ) - return nil +// getChainProvider retrieves the chain provider by given chain name. +func (a *App) getChainProvider(chainName string) (chains.ChainProvider, error) { + if a.Config == nil { + return nil, fmt.Errorf("config is not initialized") + } + + cp, exist := a.TargetChains[chainName] + if !exist { + return nil, fmt.Errorf("chain name does not exist: %s", chainName) } - // start server - return relayermetrics.StartMetricsServer(ctx, a.Log, metricsListenAddr) + return cp, nil } diff --git a/relayer/app_test.go b/relayer/app_test.go index 40234bc..4d24a0a 100644 --- a/relayer/app_test.go +++ b/relayer/app_test.go @@ -34,7 +34,7 @@ type AppTestSuite struct { chainProviderConfig *mocks.MockChainProviderConfig chainProvider *mocks.MockChainProvider client *mocks.MockClient - mockStore *mocks.MockStore + store *mocks.MockStore passphrase string hashedPassphrase []byte @@ -42,7 +42,6 @@ type AppTestSuite struct { // SetupTest sets up the test suite by creating a temporary directory and declare mock objects. func (s *AppTestSuite) SetupTest() { - tmpDir := s.T().TempDir() ctrl := gomock.NewController(s.T()) log := zap.NewNop() @@ -50,7 +49,7 @@ func (s *AppTestSuite) SetupTest() { s.chainProviderConfig = mocks.NewMockChainProviderConfig(ctrl) s.chainProvider = mocks.NewMockChainProvider(ctrl) s.client = mocks.NewMockClient(ctrl) - s.mockStore = mocks.NewMockStore(ctrl) + s.store = mocks.NewMockStore(ctrl) cfg := config.Config{ BandChain: band.Config{ @@ -68,13 +67,12 @@ func (s *AppTestSuite) SetupTest() { h := sha256.New() h.Write([]byte(s.passphrase)) s.hashedPassphrase = h.Sum(nil) - s.mockStore.EXPECT().GetHashedPassphrase().Return(s.hashedPassphrase, nil).AnyTimes() + s.store.EXPECT().GetHashedPassphrase().Return(s.hashedPassphrase, nil).AnyTimes() s.app = &relayer.App{ - Log: log, - HomePath: tmpDir, - Config: &cfg, - Store: s.mockStore, + Log: log, + Config: &cfg, + Store: s.store, TargetChains: map[string]chains.ChainProvider{ "testnet_evm": s.chainProvider, }, @@ -88,62 +86,46 @@ func TestAppTestSuite(t *testing.T) { } func (s *AppTestSuite) TestInitConfig() { + customCfg := &config.Config{ + BandChain: band.Config{ + RpcEndpoints: []string{"http://localhost:26659"}, + Timeout: 50, + }, + TargetChains: map[string]chains.ChainProviderConfig{}, + Global: config.GlobalConfig{ + CheckingPacketInterval: time.Minute, + }, + } + testcases := []struct { name string preprocess func() - in string - out *config.Config + in *config.Config err error }{ { name: "success - default", - in: "", + in: nil, preprocess: func() { - s.mockStore.EXPECT().HasConfig().Return(false, nil) - s.mockStore.EXPECT().SaveConfig(config.DefaultConfig()).Return(nil) + s.store.EXPECT().HasConfig().Return(false, nil) + s.store.EXPECT().SaveConfig(config.DefaultConfig()).Return(nil) }, }, { name: "config already exists", preprocess: func() { - s.mockStore.EXPECT().HasConfig().Return(true, nil) + s.store.EXPECT().HasConfig().Return(true, nil) }, - in: "", + in: nil, err: fmt.Errorf("config already exists"), }, { name: "init config from specific file", preprocess: func() { - customCfgPath := path.Join(s.app.HomePath, "custom.toml") - cfg := ` - [target_chains] - - [global] - checking_packet_interval = 60000000000 - - [bandchain] - rpc_endpoints = ['http://localhost:26659'] - timeout = 50 - ` - - err := os.WriteFile(customCfgPath, []byte(cfg), 0o600) - s.Require().NoError(err) - - expectCfg := &config.Config{ - BandChain: band.Config{ - RpcEndpoints: []string{"http://localhost:26659"}, - Timeout: 50, - }, - TargetChains: map[string]chains.ChainProviderConfig{}, - Global: config.GlobalConfig{ - CheckingPacketInterval: time.Minute, - }, - } - - s.mockStore.EXPECT().HasConfig().Return(false, nil) - s.mockStore.EXPECT().SaveConfig(expectCfg).Return(nil) + s.store.EXPECT().HasConfig().Return(false, nil) + s.store.EXPECT().SaveConfig(customCfg).Return(nil) }, - in: path.Join(s.app.HomePath, "custom.toml"), + in: customCfg, }, } @@ -153,7 +135,7 @@ func (s *AppTestSuite) TestInitConfig() { tc.preprocess() } - err := s.app.InitConfigFile(s.app.HomePath, tc.in) + err := s.app.SaveConfig(tc.in) if tc.err != nil { s.Require().ErrorContains(err, tc.err.Error()) @@ -165,7 +147,8 @@ func (s *AppTestSuite) TestInitConfig() { } func (s *AppTestSuite) TestAddChainConfig() { - newHomePath := path.Join(s.app.HomePath, "new_folder") + tmpDir := s.T().TempDir() + newHomePath := path.Join(tmpDir, "new_folder") err := os.Mkdir(newHomePath, os.ModePerm) s.Require().NoError(err) @@ -195,7 +178,7 @@ func (s *AppTestSuite) TestAddChainConfig() { cfg, err := config.ParseConfig([]byte(relayertest.DefaultCfgTextWithChainCfg)) s.Require().NoError(err) - s.mockStore.EXPECT().SaveConfig(cfg).Return(nil) + s.store.EXPECT().SaveConfig(cfg).Return(nil) }, }, { @@ -269,7 +252,7 @@ func (s *AppTestSuite) TestDeleteChainConfig() { name: "success", in: "testnet", preprocess: func() { - s.mockStore.EXPECT().SaveConfig(config.DefaultConfig()).Return(nil) + s.store.EXPECT().SaveConfig(config.DefaultConfig()).Return(nil) }, out: relayertest.DefaultCfgText, }, @@ -446,23 +429,19 @@ func (s *AppTestSuite) TestInitPassphrase() { ctrl := gomock.NewController(s.T()) newStoreMock := mocks.NewMockStore(ctrl) - s.app.Passphrase = "new_passphrase" s.app.Store = newStoreMock - newStoreMock.EXPECT(). - SaveHashedPassphrase([]byte{ - 194, 83, 183, 41, 238, 49, 98, 232, 230, 229, 194, - 192, 115, 133, 235, 215, 215, 206, 160, 68, 116, - 34, 59, 169, 179, 24, 231, 151, 191, 178, 90, 202, - }). - Return(nil) + newStoreMock.EXPECT().SavePassphrase("new_passphrase").Return(nil) // Call InitPassphrase - err := s.app.InitPassphrase() + err := s.app.SavePassphrase("new_passphrase") s.Require().NoError(err) + s.Require().Equal("new_passphrase", s.app.Passphrase) } func (s *AppTestSuite) TestAddKey() { + s.store.EXPECT().ValidatePassphrase(s.passphrase).Return(nil).AnyTimes() + testcases := []struct { name string chainName string @@ -485,13 +464,9 @@ func (s *AppTestSuite) TestAddKey() { out: chainstypes.NewKey("", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266", ""), preprocess: func() { s.chainProvider.EXPECT(). - AddKey( + AddKeyByPrivateKey( "testkey", - "", "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", - uint32(60), - uint(0), - uint(0), ). Return(chainstypes.NewKey("", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266", ""), nil) }, @@ -504,13 +479,9 @@ func (s *AppTestSuite) TestAddKey() { coinType: 60, preprocess: func() { s.chainProvider.EXPECT(). - AddKey( + AddKeyByPrivateKey( "testkey", - "", "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", - uint32(60), - uint(0), - uint(0), ). Return(nil, fmt.Errorf("add key error")) }, @@ -532,14 +503,10 @@ func (s *AppTestSuite) TestAddKey() { tc.preprocess() } - actual, err := s.app.AddKey( + actual, err := s.app.AddKeyByPrivateKey( tc.chainName, tc.keyName, - tc.mnemonic, tc.privateKey, - tc.coinType, - tc.account, - tc.index, ) if tc.err != nil { @@ -553,6 +520,8 @@ func (s *AppTestSuite) TestAddKey() { } func (s *AppTestSuite) TestDeleteKey() { + s.store.EXPECT().ValidatePassphrase(s.passphrase).Return(nil).AnyTimes() + testcases := []struct { name string chainName string @@ -607,6 +576,8 @@ func (s *AppTestSuite) TestDeleteKey() { } func (s *AppTestSuite) TestExportKey() { + s.store.EXPECT().ValidatePassphrase(s.passphrase).Return(nil).AnyTimes() + testcases := []struct { name string chainName string @@ -766,25 +737,3 @@ func (s *AppTestSuite) TestShowKey() { }) } } - -func (s *AppTestSuite) TestValidatePassphraseInvalidPassphrase() { - testcases := []struct { - name string - envPassphrase string - err error - }{ - {name: "valid", envPassphrase: "secret", err: nil}, - {name: "invalid", envPassphrase: "invalid", err: fmt.Errorf("invalid passphrase")}, - } - - for _, tc := range testcases { - s.Run(tc.name, func() { - err := s.app.ValidatePassphrase(tc.envPassphrase) - if tc.err != nil { - s.Require().ErrorContains(err, tc.err.Error()) - } else { - s.Require().NoError(err) - } - }) - } -} diff --git a/relayer/chains/config.go b/relayer/chains/config.go index 47e9073..81f2349 100644 --- a/relayer/chains/config.go +++ b/relayer/chains/config.go @@ -28,7 +28,6 @@ type ChainProviderConfig interface { NewChainProvider( chainName string, log *zap.Logger, - homePath string, debug bool, wallet wallet.Wallet, ) (ChainProvider, error) diff --git a/relayer/chains/evm/config.go b/relayer/chains/evm/config.go index daa25ad..19b690d 100644 --- a/relayer/chains/evm/config.go +++ b/relayer/chains/evm/config.go @@ -33,13 +33,12 @@ type EVMChainProviderConfig struct { func (cpc *EVMChainProviderConfig) NewChainProvider( chainName string, log *zap.Logger, - homePath string, debug bool, wallet wallet.Wallet, ) (chains.ChainProvider, error) { client := NewClient(chainName, cpc, log) - return NewEVMChainProvider(chainName, client, cpc, log, homePath, wallet) + return NewEVMChainProvider(chainName, client, cpc, log, wallet) } // Validate validates the EVM chain provider configuration. diff --git a/relayer/chains/evm/keys.go b/relayer/chains/evm/keys.go index ab108f5..4ed9ecb 100644 --- a/relayer/chains/evm/keys.go +++ b/relayer/chains/evm/keys.go @@ -25,39 +25,24 @@ const ( infoFileName = "info.toml" ) -func (cp *EVMChainProvider) AddKey( +// AddKeyByMnemonic adds a key using a mnemonic phrase. +func (cp *EVMChainProvider) AddKeyByMnemonic( keyName string, mnemonic string, - privateKey string, coinType uint32, account uint, index uint, ) (*chainstypes.Key, error) { - if privateKey != "" { - return cp.AddKeyWithPrivateKey(keyName, privateKey) - } - var err error - // Generate mnemonic if not provided if mnemonic == "" { mnemonic, err = hdwallet.NewMnemonic(mnemonicSize) if err != nil { return nil, err } } - return cp.AddKeyWithMnemonic(keyName, mnemonic, coinType, account, index) -} -// AddKeyWithMnemonic adds a key using a mnemonic phrase. -func (cp *EVMChainProvider) AddKeyWithMnemonic( - keyName string, - mnemonic string, - coinType uint32, - account uint, - index uint, -) (*chainstypes.Key, error) { // Generate private key using mnemonic - priv, err := cp.generatePrivateKey(mnemonic, coinType, account, index) + priv, err := generatePrivateKey(mnemonic, coinType, account, index) if err != nil { return nil, err } @@ -65,8 +50,8 @@ func (cp *EVMChainProvider) AddKeyWithMnemonic( return cp.finalizeKeyAddition(keyName, priv, mnemonic) } -// AddKeyWithPrivateKey adds a key using a raw private key. -func (cp *EVMChainProvider) AddKeyWithPrivateKey(keyName, privateKey string) (*chainstypes.Key, error) { +// AddKeyByPrivateKey adds a key using a raw private key. +func (cp *EVMChainProvider) AddKeyByPrivateKey(keyName, privateKey string) (*chainstypes.Key, error) { // Convert private key from hex priv, err := crypto.HexToECDSA(StripPrivateKeyPrefix(privateKey)) if err != nil { @@ -125,14 +110,8 @@ func (cp *EVMChainProvider) ShowKey(keyName string) (string, error) { return address, nil } -// IsKeyNameExist checks whether the given key name is already in use. -func (cp *EVMChainProvider) IsKeyNameExist(keyName string) bool { - _, ok := cp.Wallet.GetAddress(keyName) - return ok -} - // generatePrivateKey generates private key from given mnemonic. -func (cp *EVMChainProvider) generatePrivateKey( +func generatePrivateKey( mnemonic string, coinType uint32, account uint, @@ -142,16 +121,18 @@ func (cp *EVMChainProvider) generatePrivateKey( if err != nil { return nil, err } + hdPath := fmt.Sprintf(hdPathTemplate, coinType, account, index) path := hdwallet.MustParseDerivationPath(hdPath) - accs, err := wallet.Derive(path, true) if err != nil { return nil, err } + privatekey, err := wallet.PrivateKey(accs) if err != nil { return nil, err } + return privatekey, nil } diff --git a/relayer/chains/evm/keys_test.go b/relayer/chains/evm/keys_test.go index 726d4fa..77183ae 100644 --- a/relayer/chains/evm/keys_test.go +++ b/relayer/chains/evm/keys_test.go @@ -4,7 +4,6 @@ import ( "fmt" "testing" - "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/suite" "go.uber.org/zap" @@ -25,6 +24,7 @@ type KeysTestSuite struct { chainProvider *evm.EVMChainProvider log *zap.Logger homePath string + wallet wallet.Wallet } func TestKeysTestSuite(t *testing.T) { @@ -42,9 +42,11 @@ func (s *KeysTestSuite) SetupTest() { wallet, err := wallet.NewGethWallet("", s.homePath, chainName) s.Require().NoError(err) - chainProvider, err := evm.NewEVMChainProvider(chainName, client, evmCfg, s.log, s.homePath, wallet) + chainProvider, err := evm.NewEVMChainProvider(chainName, client, evmCfg, s.log, wallet) s.Require().NoError(err) + s.chainProvider = chainProvider + s.wallet = wallet } func (s *KeysTestSuite) TestAddKeyByPrivateKey() { @@ -86,14 +88,7 @@ func (s *KeysTestSuite) TestAddKeyByPrivateKey() { for _, tc := range testcases { s.T().Run(tc.name, func(t *testing.T) { - key, err := s.chainProvider.AddKey( - tc.input.keyName, - "", - tc.input.privKey, - 0, - 0, - 0, - ) + key, err := s.chainProvider.AddKeyByPrivateKey(tc.input.keyName, tc.input.privKey) if tc.err != nil { s.Require().ErrorContains(err, tc.err.Error()) @@ -102,11 +97,8 @@ func (s *KeysTestSuite) TestAddKeyByPrivateKey() { s.Require().Equal(tc.out, key) // check that key info actually stored in local disk - keyInfo, err := evm.LoadKeyInfo(s.homePath, s.chainProvider.ChainName) - s.Require().NoError(err) - - _, exist := keyInfo[tc.input.keyName] - s.Require().True(exist) + _, ok := s.wallet.GetAddress(tc.input.keyName) + s.Require().True(ok) } }) } @@ -184,10 +176,9 @@ func (s *KeysTestSuite) TestAddKeyByMnemonic() { for _, tc := range testcases { s.T().Run(tc.name, func(t *testing.T) { - key, err := s.chainProvider.AddKey( + key, err := s.chainProvider.AddKeyByMnemonic( tc.input.keyName, tc.input.mnemonic, - "", tc.input.coinType, tc.input.account, tc.input.index, @@ -203,11 +194,8 @@ func (s *KeysTestSuite) TestAddKeyByMnemonic() { } // check that key info actually stored in local disk - keyInfo, err := evm.LoadKeyInfo(s.homePath, s.chainProvider.ChainName) - s.Require().NoError(err) - - _, exist := keyInfo[tc.input.keyName] - s.Require().True(exist) + _, ok := s.wallet.GetAddress(tc.input.keyName) + s.Require().True(ok) } }) } @@ -218,7 +206,7 @@ func (s *KeysTestSuite) TestDeleteKey() { privatekeyHex := testPrivateKey // Add a key to delete - _, err := s.chainProvider.AddKeyWithPrivateKey(keyName, privatekeyHex) + _, err := s.chainProvider.AddKeyByPrivateKey(keyName, privatekeyHex) s.Require().NoError(err) // Delete the key @@ -226,7 +214,8 @@ func (s *KeysTestSuite) TestDeleteKey() { s.Require().NoError(err) // Ensure the key is no longer in the KeyInfo or KeyStore - s.Require().False(s.chainProvider.IsKeyNameExist(keyName)) + _, ok := s.chainProvider.Wallet.GetAddress(keyName) + s.Require().False(ok) // Delete the key again should return error err = s.chainProvider.DeleteKey(keyName) @@ -238,7 +227,7 @@ func (s *KeysTestSuite) TestExportPrivateKey() { privatekeyHex := testPrivateKey // Add a key to export - _, err := s.chainProvider.AddKeyWithPrivateKey(keyName, privatekeyHex) + _, err := s.chainProvider.AddKeyByPrivateKey(keyName, privatekeyHex) s.Require().NoError(err) // Export the private key @@ -253,25 +242,22 @@ func (s *KeysTestSuite) TestListKeys() { keyName1 := "key1" keyName2 := "key2" mnemonic := "" - privateKey := "" coinType := 60 account := 0 index := 0 - key1, err := s.chainProvider.AddKey( + key1, err := s.chainProvider.AddKeyByMnemonic( keyName1, mnemonic, - privateKey, uint32(coinType), uint(account), uint(index), ) s.Require().NoError(err) - key2, err := s.chainProvider.AddKey( + key2, err := s.chainProvider.AddKeyByMnemonic( keyName2, mnemonic, - privateKey, uint32(coinType), uint(account), uint(index), @@ -307,7 +293,7 @@ func (s *KeysTestSuite) TestShowKey() { privatekeyHex := testPrivateKey // Add a key to show - _, err := s.chainProvider.AddKeyWithPrivateKey(keyName, privatekeyHex) + _, err := s.chainProvider.AddKeyByPrivateKey(keyName, privatekeyHex) s.Require().NoError(err) // Show the key @@ -315,18 +301,3 @@ func (s *KeysTestSuite) TestShowKey() { s.Require().Equal(address, address) s.Require().NoError(err) } - -func (s *KeysTestSuite) TestIsKeyNameExist() { - priv, err := crypto.HexToECDSA(evm.StripPrivateKeyPrefix(testPrivateKey)) - s.Require().NoError(err) - - _, err = s.chainProvider.Wallet.SavePrivateKey("testkey1", priv) - s.Require().NoError(err) - - expected := s.chainProvider.IsKeyNameExist("testkey1") - - s.Require().Equal(expected, true) - - expected = s.chainProvider.IsKeyNameExist("testkey2") - s.Require().Equal(expected, false) -} diff --git a/relayer/chains/evm/provider.go b/relayer/chains/evm/provider.go index fa6637f..9b5b97b 100644 --- a/relayer/chains/evm/provider.go +++ b/relayer/chains/evm/provider.go @@ -47,7 +47,6 @@ func NewEVMChainProvider( client Client, cfg *EVMChainProviderConfig, log *zap.Logger, - homePath string, wallet wallet.Wallet, ) (*EVMChainProvider, error) { // load abis here diff --git a/relayer/chains/evm/provider_eip1559_test.go b/relayer/chains/evm/provider_eip1559_test.go index 9dcd21d..c5b72dd 100644 --- a/relayer/chains/evm/provider_eip1559_test.go +++ b/relayer/chains/evm/provider_eip1559_test.go @@ -52,7 +52,7 @@ func (s *EIP1559ProviderTestSuite) SetupTest() { wallet, err := wallet.NewGethWallet("", s.homePath, s.chainName) s.Require().NoError(err) - chainProvider, err := evm.NewEVMChainProvider(s.chainName, s.client, &evmConfig, zap.NewNop(), s.homePath, wallet) + chainProvider, err := evm.NewEVMChainProvider(s.chainName, s.client, &evmConfig, zap.NewNop(), wallet) s.Require().NoError(err) priv, err := crypto.HexToECDSA(evm.StripPrivateKeyPrefix(testPrivateKey)) diff --git a/relayer/chains/evm/provider_legacy_test.go b/relayer/chains/evm/provider_legacy_test.go index 9e1b6ed..bc28067 100644 --- a/relayer/chains/evm/provider_legacy_test.go +++ b/relayer/chains/evm/provider_legacy_test.go @@ -52,7 +52,7 @@ func (s *LegacyProviderTestSuite) SetupTest() { wallet, err := wallet.NewGethWallet("", s.homePath, s.chainName) s.Require().NoError(err) - chainProvider, err := evm.NewEVMChainProvider(s.chainName, s.client, &evmConfig, zap.NewNop(), s.homePath, wallet) + chainProvider, err := evm.NewEVMChainProvider(s.chainName, s.client, &evmConfig, zap.NewNop(), wallet) s.Require().NoError(err) priv, err := crypto.HexToECDSA(evm.StripPrivateKeyPrefix(testPrivateKey)) diff --git a/relayer/chains/evm/provider_test.go b/relayer/chains/evm/provider_test.go index 0a65502..2733c99 100644 --- a/relayer/chains/evm/provider_test.go +++ b/relayer/chains/evm/provider_test.go @@ -116,7 +116,7 @@ func (s *ProviderTestSuite) SetupTest() { wallet, err := wallet.NewGethWallet("", s.homePath, s.chainName) s.Require().NoError(err) - s.chainProvider, err = evm.NewEVMChainProvider(s.chainName, s.client, baseEVMCfg, s.log, s.homePath, wallet) + s.chainProvider, err = evm.NewEVMChainProvider(s.chainName, s.client, baseEVMCfg, s.log, wallet) s.Require().NoError(err) s.chainProvider.Client = s.client diff --git a/relayer/chains/evm/sender.go b/relayer/chains/evm/sender.go index 84c2eb7..6cb6269 100644 --- a/relayer/chains/evm/sender.go +++ b/relayer/chains/evm/sender.go @@ -2,16 +2,10 @@ package evm import ( "fmt" - "os" - "path" gethcommon "github.com/ethereum/go-ethereum/common" - "github.com/pelletier/go-toml/v2" ) -// KeyInfo struct is the struct that represents mapping of key name -> address -type KeyInfo map[string]string - // Sender is the struct that represents the sender of the transaction. type Sender struct { Name string @@ -49,29 +43,3 @@ func (cp *EVMChainProvider) LoadFreeSenders() error { cp.FreeSenders = freeSenders return nil } - -// loadKeyInfo loads key information from local disk. -func LoadKeyInfo(homePath, chainName string) (KeyInfo, error) { - keyInfo := make(KeyInfo) - - keyInfoDir := path.Join(homePath, keyDir, chainName, infoDir) - keyInfoPath := path.Join(keyInfoDir, infoFileName) - - if _, err := os.Stat(keyInfoPath); err != nil { - // don't return error if file doesn't exist - return keyInfo, nil - } - - b, err := os.ReadFile(keyInfoPath) - if err != nil { - return nil, err - } - - // unmarshal them with Config into struct - err = toml.Unmarshal(b, &keyInfo) - if err != nil { - return nil, err - } - - return keyInfo, nil -} diff --git a/relayer/chains/evm/sender_test.go b/relayer/chains/evm/sender_test.go index 022d52b..966875f 100644 --- a/relayer/chains/evm/sender_test.go +++ b/relayer/chains/evm/sender_test.go @@ -2,13 +2,9 @@ package evm_test import ( "context" - "os" - "path" "testing" "time" - "github.com/pelletier/go-toml/v2" - "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "go.uber.org/zap" @@ -74,52 +70,21 @@ func (s *SenderTestSuite) SetupTest() { wallet, err := wallet.NewGethWallet("", s.homePath, chainName) s.Require().NoError(err) - s.chainProvider, err = evm.NewEVMChainProvider(chainName, client, evmCfg, log, tmpDir, wallet) + s.chainProvider, err = evm.NewEVMChainProvider(chainName, client, evmCfg, log, wallet) s.Require().NoError(err) s.ctx = context.Background() } -func TestLoadKeyInfo(t *testing.T) { - tmpDir := t.TempDir() - chainName := "testnet" - - // write mock keyInfo at keyInfo's path - keyInfo := make(evm.KeyInfo) - keyInfo["key1"] = "" - keyInfo["key2"] = "" - b, err := toml.Marshal(&keyInfo) - require.NoError(t, err) - - keyInfoDir := path.Join(tmpDir, "keys", chainName, "info") - keyInfoPath := path.Join(keyInfoDir, "info.toml") - // Create the info folder if doesn't exist - err = os.MkdirAll(keyInfoDir, os.ModePerm) - require.NoError(t, err) - // Create the file and write the default config to the given location. - f, err := os.Create(keyInfoPath) - require.NoError(t, err) - defer f.Close() - - _, err = f.Write(b) - require.NoError(t, err) - - // load keyInfo - actual, err := evm.LoadKeyInfo(tmpDir, chainName) - require.NoError(t, err) - - require.Equal(t, keyInfo, actual) -} - func (s *SenderTestSuite) TestLoadFreeSenders() { keyName1 := "key1" keyName2 := "key2" // Add two mock keys to the chain provider - _, err := s.chainProvider.AddKeyWithPrivateKey(keyName1, privateKey1) + _, err := s.chainProvider.AddKeyByPrivateKey(keyName1, privateKey1) s.Require().NoError(err) - _, err = s.chainProvider.AddKeyWithPrivateKey(keyName2, privateKey2) + _, err = s.chainProvider.AddKeyByPrivateKey(keyName2, privateKey2) s.Require().NoError(err) // Load free senders diff --git a/relayer/chains/provider.go b/relayer/chains/provider.go index c73993e..f1aef2e 100644 --- a/relayer/chains/provider.go +++ b/relayer/chains/provider.go @@ -35,16 +35,18 @@ type ChainProvider interface { // KeyProvider defines the interface for the key interaction with destination chain type KeyProvider interface { - // AddKey stores the private key with a given mnemonic and key name on the user's local disk. - AddKey( + // AddKeyByMnemonic adds a key using a mnemonic phrase. + AddKeyByMnemonic( keyName string, mnemonic string, - privateKeyHex string, coinType uint32, account uint, index uint, ) (*chainstypes.Key, error) + // AddKeyByPrivateKey adds a key using a private key. + AddKeyByPrivateKey(keyName string, privateKeyHex string) (*chainstypes.Key, error) + // DeleteKey deletes the key information and private key DeleteKey(keyName string) error @@ -57,9 +59,6 @@ type KeyProvider interface { // ShowKey shows the address of the given key ShowKey(keyName string) (string, error) - // IsKeyNameExist checks whether a key with the specified keyName already exists in storage. - IsKeyNameExist(keyName string) bool - // LoadFreeSenders loads key info to prepare to relay the packet LoadFreeSenders() error } diff --git a/relayer/store/filesystem.go b/relayer/store/filesystem.go index ed8d2f1..9978387 100644 --- a/relayer/store/filesystem.go +++ b/relayer/store/filesystem.go @@ -1,6 +1,8 @@ package store import ( + "bytes" + "crypto/sha256" "fmt" "path" @@ -75,11 +77,26 @@ func (fs *FileSystem) GetHashedPassphrase() ([]byte, error) { return fs.hashedPassphrase, nil } -// SaveHashedPassphrase saves the hashedPassphrase to the filesystem. -func (fs *FileSystem) SaveHashedPassphrase(hashedPassphrase []byte) error { - fs.hashedPassphrase = hashedPassphrase +// SavePassphrase hashes and saves the hashedPassphrase to the filesystem. +func (fs *FileSystem) SavePassphrase(passphrase string) error { + fs.hashedPassphrase = hashPassphrase(passphrase) - return os.Write(hashedPassphrase, getPassphrasePath(fs.HomePath)) + return os.Write(fs.hashedPassphrase, getPassphrasePath(fs.HomePath)) +} + +// ValidatePassphrase validates the given passphrase with the stored hashed passphrase. +func (fs *FileSystem) ValidatePassphrase(passphrase string) error { + // load passphrase from local disk + storedHashedPassphrase, err := fs.GetHashedPassphrase() + if err != nil { + return err + } + + if !bytes.Equal(hashPassphrase(passphrase), storedHashedPassphrase) { + return fmt.Errorf("invalid passphrase: the provided passphrase does not match the stored hashed passphrase") + } + + return nil } // NewWallet creates a new wallet object based on the chain type and chain name. @@ -101,3 +118,10 @@ func getConfigPath(homePath string) []string { func getPassphrasePath(homePath string) []string { return []string{homePath, cfgDir, passphraseFileName} } + +// hashPassphrase hashes the given passphrase and returns the hashed bytes. +func hashPassphrase(passphrase string) []byte { + h := sha256.New() + h.Write([]byte(passphrase)) + return h.Sum(nil) +} diff --git a/relayer/store/filesystem_test.go b/relayer/store/filesystem_test.go index b72be6e..110714c 100644 --- a/relayer/store/filesystem_test.go +++ b/relayer/store/filesystem_test.go @@ -1,6 +1,7 @@ package store_test import ( + "fmt" "testing" "github.com/stretchr/testify/suite" @@ -60,18 +61,50 @@ func (s *FileSystemTestSuite) TestGetEmptyHashedPassphrase() { } func (s *FileSystemTestSuite) TestGetHashedPassphrase() { - err := s.store.SaveHashedPassphrase([]byte("test")) + err := s.store.SavePassphrase("test") s.NoError(err) // overwrite the passphrase shouldn't cause any error - err = s.store.SaveHashedPassphrase([]byte("new passphrase")) + err = s.store.SavePassphrase("new passphrase") s.NoError(err) // create a new store to read the new passphrase newStore, err := store.NewFileSystem(s.store.HomePath) s.NoError(err) + expect := []byte{ + 0x5c, 0xb5, 0xf0, 0x32, 0x6, 0x65, 0x34, 0x19, 0x2e, 0x6e, 0xda, 0xe1, 0x7, 0x3c, + 0xe9, 0x0, 0x37, 0x2e, 0x5c, 0x35, 0x69, 0x54, 0x65, 0x9d, 0xb9, 0x96, 0x92, 0xc6, + 0x1d, 0x1d, 0xc, 0xe7, + } + hashedPassphrase, err := newStore.GetHashedPassphrase() s.NoError(err) - s.Require().Equal([]byte("new passphrase"), hashedPassphrase) + s.Require().Equal(expect, hashedPassphrase) +} + +func (s *FileSystemTestSuite) TestValidatePassphraseInvalidPassphrase() { + // prepare bytes slices of hashed env passphrase + err := s.store.SavePassphrase("secret") + s.NoError(err) + + testcases := []struct { + name string + envPassphrase string + err error + }{ + {name: "valid", envPassphrase: "secret", err: nil}, + {name: "invalid", envPassphrase: "invalid", err: fmt.Errorf("invalid passphrase")}, + } + + for _, tc := range testcases { + s.Run(tc.name, func() { + err := s.store.ValidatePassphrase(tc.envPassphrase) + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + } + }) + } } diff --git a/relayer/store/store.go b/relayer/store/store.go index b7675d8..09c17ea 100644 --- a/relayer/store/store.go +++ b/relayer/store/store.go @@ -11,6 +11,7 @@ type Store interface { GetConfig() (*config.Config, error) SaveConfig(cfg *config.Config) error GetHashedPassphrase() ([]byte, error) - SaveHashedPassphrase(hashedPassphrase []byte) error + SavePassphrase(passphrase string) error + ValidatePassphrase(passphrase string) error NewWallet(chainType chains.ChainType, chainName, passphrase string) (wallet.Wallet, error) }