diff --git a/internal/cmd/beta/beta_cmd.go b/internal/cmd/beta/beta_cmd.go index 919121763..a5847794b 100644 --- a/internal/cmd/beta/beta_cmd.go +++ b/internal/cmd/beta/beta_cmd.go @@ -37,7 +37,7 @@ All commands use the runme.yaml configuration file.`, cmd.SetErr(io.Discard) } - err := autoconfig.InvokeForCommand(func(cfg *config.Config, log *zap.Logger) error { + err := autoconfig.Invoke(func(cfg *config.Config, log *zap.Logger) error { // Override the filename if provided. if cFlags.filename != "" { cfg.Project.Filename = cFlags.filename diff --git a/internal/cmd/beta/list_cmd.go b/internal/cmd/beta/list_cmd.go index ebb008ff7..82cb04798 100644 --- a/internal/cmd/beta/list_cmd.go +++ b/internal/cmd/beta/list_cmd.go @@ -36,7 +36,7 @@ List all blocks from the "setup" and "teardown" tags: runme beta list --tag=setup,teardown `, RunE: func(cmd *cobra.Command, args []string) error { - return autoconfig.InvokeForCommand( + return autoconfig.Invoke( func( proj *project.Project, filters []project.Filter, diff --git a/internal/cmd/beta/print_cmd.go b/internal/cmd/beta/print_cmd.go index aa1407f4c..bbee21f5a 100644 --- a/internal/cmd/beta/print_cmd.go +++ b/internal/cmd/beta/print_cmd.go @@ -25,7 +25,7 @@ Print content of commands from the "setup" and "teardown" tags: runme beta print --tag=setup,teardown `, RunE: func(cmd *cobra.Command, args []string) error { - return autoconfig.InvokeForCommand( + return autoconfig.Invoke( func( proj *project.Project, filters []project.Filter, diff --git a/internal/cmd/beta/run_cmd.go b/internal/cmd/beta/run_cmd.go index a6572b464..59afe2309 100644 --- a/internal/cmd/beta/run_cmd.go +++ b/internal/cmd/beta/run_cmd.go @@ -40,7 +40,7 @@ Run all blocks from the "setup" and "teardown" tags: runme beta run --tag=setup,teardown `, RunE: func(cmd *cobra.Command, args []string) error { - return autoconfig.InvokeForCommand( + return autoconfig.Invoke( func( clientFactory autoconfig.ClientFactory, cmdFactory command.Factory, diff --git a/internal/cmd/beta/server/server_cmd.go b/internal/cmd/beta/server/server_cmd.go index 127a6e5b2..726b29848 100644 --- a/internal/cmd/beta/server/server_cmd.go +++ b/internal/cmd/beta/server/server_cmd.go @@ -13,7 +13,7 @@ func Cmd() *cobra.Command { Short: "Commands to manage and call a runme server.", Hidden: true, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - return autoconfig.InvokeForCommand( + return autoconfig.Invoke( func( cfg *config.Config, ) error { diff --git a/internal/cmd/beta/server/server_grpcurl_describe_cmd.go b/internal/cmd/beta/server/server_grpcurl_describe_cmd.go index 855e3ebbb..465f3d6a8 100644 --- a/internal/cmd/beta/server/server_grpcurl_describe_cmd.go +++ b/internal/cmd/beta/server/server_grpcurl_describe_cmd.go @@ -22,7 +22,7 @@ func serverGRPCurlDescribeCmd() *cobra.Command { Short: "Describe gRPC services and methods exposed by the server.", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return autoconfig.InvokeForCommand( + return autoconfig.Invoke( func( cfg *config.Config, logger *zap.Logger, diff --git a/internal/cmd/beta/server/server_grpcurl_invoke_cmd.go b/internal/cmd/beta/server/server_grpcurl_invoke_cmd.go index aee9a3a1e..4630206b3 100644 --- a/internal/cmd/beta/server/server_grpcurl_invoke_cmd.go +++ b/internal/cmd/beta/server/server_grpcurl_invoke_cmd.go @@ -26,7 +26,7 @@ func serverGRPCurlInvokeCmd() *cobra.Command { Short: "Invoke gRPC command to the server.", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return autoconfig.InvokeForCommand( + return autoconfig.Invoke( func( cfg *config.Config, logger *zap.Logger, diff --git a/internal/cmd/beta/server/server_grpcurl_list_cmd.go b/internal/cmd/beta/server/server_grpcurl_list_cmd.go index bc8cc1d22..1679ae8cb 100644 --- a/internal/cmd/beta/server/server_grpcurl_list_cmd.go +++ b/internal/cmd/beta/server/server_grpcurl_list_cmd.go @@ -19,7 +19,7 @@ func serverGRPCurlListCmd() *cobra.Command { Short: "List gRPC services exposed by the server.", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return autoconfig.InvokeForCommand( + return autoconfig.Invoke( func( cfg *config.Config, logger *zap.Logger, diff --git a/internal/cmd/beta/server/server_start_cmd.go b/internal/cmd/beta/server/server_start_cmd.go index d1831a2c1..6c22b3101 100644 --- a/internal/cmd/beta/server/server_start_cmd.go +++ b/internal/cmd/beta/server/server_start_cmd.go @@ -7,7 +7,6 @@ import ( "github.com/spf13/cobra" "go.uber.org/zap" - "github.com/stateful/runme/v3/internal/command" "github.com/stateful/runme/v3/internal/config" "github.com/stateful/runme/v3/internal/config/autoconfig" "github.com/stateful/runme/v3/internal/server" @@ -19,30 +18,16 @@ func serverStartCmd() *cobra.Command { Use: "start", Short: "Start a server.", RunE: func(cmd *cobra.Command, args []string) error { - return autoconfig.InvokeForCommand( + return autoconfig.Invoke( func( cfg *config.Config, - cmdFactory command.Factory, + server *server.Server, logger *zap.Logger, ) error { defer logger.Sync() - serverCfg := &server.Config{ - Address: cfg.Server.Address, - CertFile: *cfg.Server.Tls.CertFile, // guaranteed by autoconfig - KeyFile: *cfg.Server.Tls.KeyFile, // guaranteed by autoconfig - TLSEnabled: cfg.Server.Tls.Enabled, - } - _ = telemetry.ReportUnlessNoTracking(logger) - logger.Debug("server config", zap.Any("config", serverCfg)) - - s, err := server.New(serverCfg, cmdFactory, logger) - if err != nil { - return err - } - // When using a unix socket, we want to create a file with server's PID. if path := pidFileNameFromAddr(cfg.Server.Address); path != "" { logger.Debug("creating PID file", zap.String("path", path)) @@ -52,9 +37,7 @@ func serverStartCmd() *cobra.Command { defer os.Remove(cfg.Server.Address) } - logger.Debug("starting the server") - - return errors.WithStack(s.Serve()) + return errors.WithStack(server.Serve()) }, ) }, diff --git a/internal/cmd/beta/server/server_stop_cmd.go b/internal/cmd/beta/server/server_stop_cmd.go index bbe93be37..25cdda3cc 100644 --- a/internal/cmd/beta/server/server_stop_cmd.go +++ b/internal/cmd/beta/server/server_stop_cmd.go @@ -17,7 +17,7 @@ func serverStopCmd() *cobra.Command { Use: "stop", Short: "Stop a server.", RunE: func(cmd *cobra.Command, args []string) error { - return autoconfig.InvokeForCommand( + return autoconfig.Invoke( func( cfg *config.Config, logger *zap.Logger, diff --git a/internal/cmd/beta/session_cmd.go b/internal/cmd/beta/session_cmd.go index e1278de63..f535da8cb 100644 --- a/internal/cmd/beta/session_cmd.go +++ b/internal/cmd/beta/session_cmd.go @@ -26,7 +26,7 @@ func sessionCmd(*commonFlags) *cobra.Command { All exported variables during the session will be available to the subsequent commands. `, RunE: func(cmd *cobra.Command, args []string) error { - return autoconfig.InvokeForCommand( + return autoconfig.Invoke( func( cmdFactory command.Factory, logger *zap.Logger, @@ -133,7 +133,7 @@ func sessionSetupCmd() *cobra.Command { Use: "setup", Hidden: true, RunE: func(cmd *cobra.Command, args []string) error { - return autoconfig.InvokeForCommand( + return autoconfig.Invoke( func( cmdFactory command.Factory, logger *zap.Logger, diff --git a/internal/command/config.go b/internal/command/config.go index 96ff2ecb7..a2efc8db0 100644 --- a/internal/command/config.go +++ b/internal/command/config.go @@ -25,10 +25,10 @@ func redactConfig(cfg *ProgramConfig) *ProgramConfig { } func isShell(cfg *ProgramConfig) bool { - return IsShellProgram(filepath.Base(cfg.ProgramName)) || IsShellLanguage(cfg.LanguageId) + return isShellProgram(filepath.Base(cfg.ProgramName)) || IsShellLanguage(cfg.LanguageId) } -func IsShellProgram(programName string) bool { +func isShellProgram(programName string) bool { switch strings.ToLower(programName) { case "sh", "bash", "zsh", "ksh", "shell": return true diff --git a/internal/config/autoconfig/autoconfig.go b/internal/config/autoconfig/autoconfig.go index 7d9353e1e..f39522011 100644 --- a/internal/config/autoconfig/autoconfig.go +++ b/internal/config/autoconfig/autoconfig.go @@ -24,47 +24,39 @@ import ( "github.com/stateful/runme/v3/internal/command" "github.com/stateful/runme/v3/internal/config" "github.com/stateful/runme/v3/internal/dockerexec" + "github.com/stateful/runme/v3/internal/project/projectservice" "github.com/stateful/runme/v3/internal/runnerv2client" + "github.com/stateful/runme/v3/internal/runnerv2service" + "github.com/stateful/runme/v3/internal/server" runmetls "github.com/stateful/runme/v3/internal/tls" + parserv1 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/parser/v1" + projectv1 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/project/v1" + runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2" + "github.com/stateful/runme/v3/pkg/document/editor/editorservice" "github.com/stateful/runme/v3/pkg/project" ) -var ( - container = dig.New() - commandScope = container.Scope("command") - serverScope = container.Scope("server") -) - -func DecorateRoot(decorator interface{}, opts ...dig.DecorateOption) error { - return container.Decorate(decorator, opts...) -} +var defaultBuilder = NewBuilder() -// InvokeForCommand is used to invoke the function with the given dependencies. -// The package will automatically figure out how to instantiate them -// using the available configuration. -// -// Use it only for commands because it supports only singletons -// created during the program initialization. -func InvokeForCommand(function interface{}, opts ...dig.InvokeOption) error { - err := commandScope.Invoke(function, opts...) - return dig.RootCause(err) +type Builder struct { + *dig.Container } -// InvokeForServer is similar to InvokeForCommand, but it does not provide -// all the dependencies, in particular, it does not provide dependencies -// that differ per request. -func InvokeForServer(function interface{}, opts ...dig.InvokeOption) error { - err := serverScope.Invoke(function, opts...) - return dig.RootCause(err) +func NewBuilder() *Builder { + b := Builder{Container: dig.New()} + b.init() + return &b } -func mustProvide(err error) { - if err != nil { - panic("failed to provide: " + err.Error()) +func (b *Builder) init() { + mustProvide := func(err error) { + if err != nil { + panic("failed to provide: " + err.Error()) + } } -} -func init() { + container := b + mustProvide(container.Provide(getClient)) mustProvide(container.Provide(getClientFactory)) mustProvide(container.Provide(getCommandFactory)) @@ -74,7 +66,19 @@ func init() { mustProvide(container.Provide(getProject)) mustProvide(container.Provide(getProjectFilters)) mustProvide(container.Provide(getRootConfig)) - mustProvide(container.Provide(getUserConfigDir)) + mustProvide(container.Provide(getServer)) +} + +func Decorate(decorator interface{}, opts ...dig.DecorateOption) error { + return defaultBuilder.Decorate(decorator, opts...) +} + +// Invoke is used to invoke the function with the given dependencies. +// The package will automatically figure out how to instantiate them +// using the available configuration. +func Invoke(function interface{}, opts ...dig.InvokeOption) error { + err := defaultBuilder.Invoke(function, opts...) + return dig.RootCause(err) } func getClient(cfg *config.Config, logger *zap.Logger) (*runnerv2client.Client, error) { @@ -82,9 +86,11 @@ func getClient(cfg *config.Config, logger *zap.Logger) (*runnerv2client.Client, return nil, nil } - var opts []grpc.DialOption + opts := []grpc.DialOption{ + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.Server.MaxMessageSize)), + } - if cfg.Server.Tls != nil && cfg.Server.Tls.Enabled { + if tls := cfg.Server.Tls; tls != nil && tls.Enabled { // It's ok to dereference TLS fields because they are checked in [getRootConfig]. tlsConfig, err := runmetls.LoadClientConfig(*cfg.Server.Tls.CertFile, *cfg.Server.Tls.KeyFile) if err != nil { @@ -96,19 +102,22 @@ func getClient(cfg *config.Config, logger *zap.Logger) (*runnerv2client.Client, opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } - return runnerv2client.New( - cfg.Server.Address, - logger, - opts..., - ) + conn, err := grpc.NewClient(cfg.Server.Address, opts...) + if err != nil { + return nil, errors.WithStack(err) + } + if conn == nil { + return nil, errors.New("client connection is not configured") + } + return runnerv2client.New(conn, logger), nil } type ClientFactory func() (*runnerv2client.Client, error) -func getClientFactory(cfg *config.Config, logger *zap.Logger) ClientFactory { +func getClientFactory(cfg *config.Config, logger *zap.Logger) (ClientFactory, error) { return func() (*runnerv2client.Client, error) { return getClient(cfg, logger) - } + }, nil } func getCommandFactory(docker *dockerexec.Docker, logger *zap.Logger, proj *project.Project) command.Factory { @@ -166,7 +175,7 @@ func getLogger(c *config.Config) (*zap.Logger, error) { } if c.Log.Verbose { - zapConfig.Level = zap.NewAtomicLevelAt(zap.DebugLevel) + zapConfig.Level = zap.NewAtomicLevelAt(zap.InfoLevel) zapConfig.Development = true zapConfig.Encoding = "console" zapConfig.EncoderConfig = zap.NewDevelopmentEncoderConfig() @@ -266,7 +275,7 @@ func getProjectFilters(c *config.Config) ([]project.Filter, error) { return filters, nil } -func getRootConfig(cfgLoader *config.Loader, userCfgDir UserConfigDir) (*config.Config, error) { +func getRootConfig(cfgLoader *config.Loader) (*config.Config, error) { var cfg *config.Config items, err := cfgLoader.RootConfigs() @@ -284,6 +293,11 @@ func getRootConfig(cfgLoader *config.Loader, userCfgDir UserConfigDir) (*config. if cfg.Server != nil && cfg.Server.Tls != nil && cfg.Server.Tls.Enabled { tls := cfg.Server.Tls + userCfgDir, err := os.UserConfigDir() + if err != nil { + return nil, errors.WithMessage(err, "failed to get user config directory") + } + if tls.CertFile == nil { path := filepath.Join(string(userCfgDir), "runme", "tls", "cert.pem") tls.CertFile = &path @@ -297,9 +311,25 @@ func getRootConfig(cfgLoader *config.Loader, userCfgDir UserConfigDir) (*config. return cfg, nil } -type UserConfigDir string +func getServer(cfg *config.Config, cmdFactory command.Factory, logger *zap.Logger) (*server.Server, error) { + if cfg.Server == nil { + return nil, nil + } -func getUserConfigDir() (UserConfigDir, error) { - dir, err := os.UserConfigDir() - return UserConfigDir(dir), errors.WithStack(err) + parserService := editorservice.NewParserServiceServer(logger) + projectService := projectservice.NewProjectServiceServer(logger) + runnerService, err := runnerv2service.NewRunnerService(cmdFactory, logger) + if err != nil { + return nil, err + } + + return server.New( + cfg, + logger, + func(sr grpc.ServiceRegistrar) { + parserv1.RegisterParserServiceServer(sr, parserService) + projectv1.RegisterProjectServiceServer(sr, projectService) + runnerv2.RegisterRunnerServiceServer(sr, runnerService) + }, + ) } diff --git a/internal/config/autoconfig/autoconfig_test.go b/internal/config/autoconfig/autoconfig_test.go index 534d57576..e2fa5ea12 100644 --- a/internal/config/autoconfig/autoconfig_test.go +++ b/internal/config/autoconfig/autoconfig_test.go @@ -1,35 +1,185 @@ package autoconfig import ( + "context" "fmt" + "os" + "path/filepath" "testing" "testing/fstest" + "time" + "github.com/pkg/errors" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + healthv1 "google.golang.org/grpc/health/grpc_health_v1" "github.com/stateful/runme/v3/internal/config" + "github.com/stateful/runme/v3/internal/runnerv2client" + "github.com/stateful/runme/v3/internal/server" ) func TestInvokeForCommand_Config(t *testing.T) { - // Create a fake filesystem and set it in [config.Loader]. - err := InvokeForCommand(func(loader *config.Loader) error { - fsys := fstest.MapFS{ - "README.md": { - Data: []byte("Hello, World!"), + builder := NewBuilder() + configRootFS := fstest.MapFS{ + "runme.yaml": { + // It's ok that README.md does not exist as it's not used in this test. + Data: []byte(fmt.Sprintf("version: v1alpha1\nproject:\n filename: %s\n", "README.md")), + }, + } + err := builder.Decorate( + func() (*config.Loader, error) { + return config.NewLoader([]string{"runme.yaml"}, configRootFS), nil + }, + ) + require.NoError(t, err) + err = builder.Invoke(func(*config.Config) error { return nil }) + require.NoError(t, err) +} + +func TestInvokeForCommand_ServerClient(t *testing.T) { + t.Run("NoServerInConfig", func(t *testing.T) { + builder := NewBuilder() + temp := t.TempDir() + + err := os.WriteFile(filepath.Join(temp, "README.md"), []byte("Hello, World!"), 0o644) + require.NoError(t, err) + + configRootFS := fstest.MapFS{ + "runme.yaml": { + Data: []byte(`version: v1alpha1 +project: + filename: ` + filepath.Join(temp, "README.md") + ` +server: null +`), }, + } + err = builder.Decorate( + func() (*config.Loader, error) { + return config.NewLoader([]string{"runme.yaml"}, configRootFS), nil + }, + ) + require.NoError(t, err) + + err = builder.Invoke(func( + server *server.Server, + client *runnerv2client.Client, + ) error { + require.Nil(t, server) + require.Nil(t, client) + return nil + }) + require.NoError(t, err) + }) + + t.Run("ServerInConfigWithoutTLS", func(t *testing.T) { + builder := NewBuilder() + temp := t.TempDir() + + err := os.WriteFile(filepath.Join(temp, "README.md"), []byte("Hello, World!"), 0o644) + require.NoError(t, err) + + configRootFS := fstest.MapFS{ "runme.yaml": { - Data: []byte(fmt.Sprintf("version: v1alpha1\nproject:\n filename: %s\n", "README.md")), + Data: []byte(`version: v1alpha1 +project: + filename: ` + filepath.Join(temp, "README.md") + ` +`), }, } - loader.SetConfigRootPath(fsys) - return nil + err = builder.Decorate( + func() (*config.Loader, error) { + return config.NewLoader([]string{"runme.yaml"}, configRootFS), nil + }, + ) + require.NoError(t, err) + + err = builder.Invoke(func( + server *server.Server, + client *runnerv2client.Client, + ) error { + require.NotNil(t, server) + require.NotNil(t, client) + + var g errgroup.Group + + g.Go(func() error { + return server.Serve() + }) + + g.Go(func() error { + defer server.Shutdown() + return checkHealth(client) + }) + + return g.Wait() + }) + require.NoError(t, err) }) - require.NoError(t, err) - err = InvokeForCommand(func( - *config.Config, - ) error { - return nil + t.Run("ServerInConfigWithTLS", func(t *testing.T) { + builder := NewBuilder() + temp := t.TempDir() + + err := os.WriteFile(filepath.Join(temp, "README.md"), []byte("Hello, World!"), 0o644) + require.NoError(t, err) + + configRootFS := fstest.MapFS{ + "runme.yaml": { + Data: []byte(`version: v1alpha1 +project: + filename: ` + filepath.Join(temp, "README.md") + ` +`), + }, + } + err = builder.Decorate( + func() (*config.Loader, error) { + return config.NewLoader([]string{"runme.yaml"}, configRootFS), nil + }, + ) + require.NoError(t, err) + + err = builder.Invoke(func( + server *server.Server, + client *runnerv2client.Client, + ) error { + require.NotNil(t, server) + require.NotNil(t, client) + + var g errgroup.Group + + g.Go(func() error { + return server.Serve() + }) + + g.Go(func() error { + defer server.Shutdown() + return errors.WithMessage(checkHealth(client), "failed to check health") + }) + + return g.Wait() + }) + require.NoError(t, err) }) - require.NoError(t, err) +} + +func checkHealth(client healthv1.HealthClient) error { + var ( + resp *healthv1.HealthCheckResponse + err error + ) + + for i := 0; i < 5; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + resp, err = client.Check(ctx, &healthv1.HealthCheckRequest{}) + if err != nil || resp.Status != healthv1.HealthCheckResponse_SERVING { + cancel() + time.Sleep(time.Millisecond * 100) + continue + } + cancel() + break + } + + return err } diff --git a/internal/config/config.go b/internal/config/config.go index 676d67bfa..bd02dd6b1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -32,11 +32,9 @@ func Default() *Config { } // ParseYAML parses the given YAML items and returns a configuration object. -// Multiple items are merged into a single configuration. It uses a default -// configuration as a base. +// Multiple items are merged into a single configuration. func ParseYAML(items ...[]byte) (*Config, error) { - items = append([][]byte{defaultRunmeYAML}, items...) - return parseYAML(items...) + return parseYAML(append([][]byte{defaultRunmeYAML}, items...)...) } func parseYAML(items ...[]byte) (*Config, error) { diff --git a/internal/config/config.schema.json b/internal/config/config.schema.json index d8a21f4c0..75da9334f 100644 --- a/internal/config/config.schema.json +++ b/internal/config/config.schema.json @@ -116,6 +116,10 @@ "address": { "type": "string" }, + "max_message_size": { + "type": "integer", + "default": 33554432 + }, "tls": { "type": "object", "properties": { diff --git a/internal/config/config_schema.go b/internal/config/config_schema.go index 3e81b777f..07a6ab94c 100644 --- a/internal/config/config_schema.go +++ b/internal/config/config_schema.go @@ -295,6 +295,9 @@ type ConfigServer struct { // Address corresponds to the JSON schema field "address". Address string `json:"address" yaml:"address"` + // MaxMessageSize corresponds to the JSON schema field "max_message_size". + MaxMessageSize int `json:"max_message_size,omitempty" yaml:"max_message_size,omitempty"` + // Tls corresponds to the JSON schema field "tls". Tls *ConfigServerTls `json:"tls,omitempty" yaml:"tls,omitempty"` } @@ -342,6 +345,9 @@ func (j *ConfigServer) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(b, &plain); err != nil { return err } + if v, ok := raw["max_message_size"]; !ok || v == nil { + plain.MaxMessageSize = 33554432.0 + } *j = ConfigServer(plain) return nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 300d2b6ae..ff72a94e2 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -145,6 +145,7 @@ func TestParseYAML(t *testing.T) { Tls: &ConfigServerTls{ Enabled: true, }, + MaxMessageSize: 32 * 1024 * 1024, } require.True( t, diff --git a/internal/config/loader.go b/internal/config/loader.go index fb29b75b0..59fa435cc 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -67,10 +67,6 @@ func NewLoader(configNames []string, configRootPath fs.FS, opts ...LoaderOption) return l } -func (l *Loader) SetConfigRootPath(configRootPath fs.FS) { - l.configRootPath = configRootPath -} - func (l *Loader) FindConfigChain(path string) ([][]byte, error) { paths, err := l.findConfigFilesOnPath(path) if err != nil { diff --git a/internal/config/runme.default.yaml b/internal/config/runme.default.yaml index 877375263..59997e3f2 100644 --- a/internal/config/runme.default.yaml +++ b/internal/config/runme.default.yaml @@ -28,6 +28,7 @@ server: # If not specified, default paths will be used. # cert_file: "/path/to/cert.pem" # key_file: "/path/to/key.pem" + max_message_size: 33554432 # 32 MiB log: enabled: false diff --git a/internal/runnerv2client/client.go b/internal/runnerv2client/client.go index d0e6305e3..849eaabac 100644 --- a/internal/runnerv2client/client.go +++ b/internal/runnerv2client/client.go @@ -8,27 +8,27 @@ import ( "github.com/pkg/errors" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/health/grpc_health_v1" runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2" ) +const MaxMsgSize = 32 * 1024 * 1024 // 32 MiB + type Client struct { runnerv2.RunnerServiceClient + grpc_health_v1.HealthClient + conn *grpc.ClientConn logger *zap.Logger } -func New(target string, logger *zap.Logger, opts ...grpc.DialOption) (*Client, error) { - client, err := grpc.NewClient(target, opts...) - if err != nil { - return nil, errors.WithStack(err) - } - serviceClient := &Client{ - RunnerServiceClient: runnerv2.NewRunnerServiceClient(client), - conn: client, - logger: logger, +func New(clientConn *grpc.ClientConn, logger *zap.Logger) *Client { + return &Client{ + RunnerServiceClient: runnerv2.NewRunnerServiceClient(clientConn), + HealthClient: grpc_health_v1.NewHealthClient(clientConn), + logger: logger.Named("runnerv2client.Client"), } - return serviceClient, nil } func (c *Client) Close() error { diff --git a/internal/runnerv2client/client_test.go b/internal/runnerv2client/client_test.go index d2d8cb960..795963462 100644 --- a/internal/runnerv2client/client_test.go +++ b/internal/runnerv2client/client_test.go @@ -108,15 +108,18 @@ func TestClient_ExecuteProgram(t *testing.T) { func createClient(t *testing.T, lis *bufconn.Listener) *Client { t.Helper() - logger := zaptest.NewLogger(t) - client, err := New( + + clientConn, err := grpc.NewClient( "passthrough://bufconn", - logger, grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return lis.Dial() }), grpc.WithTransportCredentials(insecure.NewCredentials()), ) require.NoError(t, err) - return client + + return New( + clientConn, + zaptest.NewLogger(t), + ) } diff --git a/internal/runnerv2service/service_test.go b/internal/runnerv2service/service_test.go index c661dd5ec..c63ac2f79 100644 --- a/internal/runnerv2service/service_test.go +++ b/internal/runnerv2service/service_test.go @@ -1,29 +1,10 @@ package runnerv2service -import ( - "testing/fstest" - - "github.com/stateful/runme/v3/internal/command" - "github.com/stateful/runme/v3/internal/config" - "github.com/stateful/runme/v3/internal/config/autoconfig" -) +import "github.com/stateful/runme/v3/internal/command" func init() { + // SetEnvDumpCommandForTesting overrides the default command that dumps the environment variables. + // Without this line, running tests results in a fork bomb. + // More: https://github.com/stateful/runme/issues/730 command.SetEnvDumpCommandForTesting() - - // Server uses autoconfig to get necessary dependencies. - // One of them, implicit, is [config.Config]. With the default - // [config.Loader] it won't be found during testing, so - // we need to provide an override. - if err := autoconfig.DecorateRoot(func(loader *config.Loader) *config.Loader { - fsys := fstest.MapFS{ - "runme.yaml": { - Data: []byte("version: v1alpha1\n"), - }, - } - loader.SetConfigRootPath(fsys) - return loader - }); err != nil { - panic(err) - } } diff --git a/internal/server/server.go b/internal/server/server.go index c02a4c195..26b8fa13e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,113 +9,116 @@ import ( "github.com/pkg/errors" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/health" - healthv1 "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/reflection" - "github.com/stateful/runme/v3/internal/command" - "github.com/stateful/runme/v3/internal/project/projectservice" - "github.com/stateful/runme/v3/internal/runnerv2service" + "github.com/stateful/runme/v3/internal/config" runmetls "github.com/stateful/runme/v3/internal/tls" - parserv1 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/parser/v1" - projectv1 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/project/v1" - runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2" - "github.com/stateful/runme/v3/pkg/document/editor/editorservice" ) -const maxMsgSize = 32 * 1024 * 1024 // 32 MiB - -type Config struct { - Address string - CertFile string - KeyFile string - TLSEnabled bool -} - type Server struct { - grpcServer *grpc.Server - lis net.Listener - logger *zap.Logger + cfg *config.Config + gs *grpc.Server + lis net.Listener + logger *zap.Logger } +type ServiceRegistrar func(grpc.ServiceRegistrar) + func New( - c *Config, - cmdFactory command.Factory, + cfg *config.Config, logger *zap.Logger, -) (_ *Server, err error) { - var tlsConfig *tls.Config - - if c.TLSEnabled { - // TODO(adamb): redesign runmetls API. - tlsConfig, err = runmetls.LoadOrGenerateConfig(c.CertFile, c.KeyFile, logger) - if err != nil { - return nil, err - } - } - - addr := c.Address - protocol := "tcp" + registrar ServiceRegistrar, +) (*Server, error) { + logger = logger.Named("Server") - var lis net.Listener - - if strings.HasPrefix(addr, "unix://") { - protocol = "unix" - addr = strings.TrimPrefix(addr, "unix://") - - if _, err := os.Stat(addr); !os.IsNotExist(err) { - return nil, err - } - } - - if tlsConfig == nil { - lis, err = net.Listen(protocol, addr) - } else { - lis, err = tls.Listen(protocol, addr, tlsConfig) - } + tlsCfg, err := createTLSConfig(cfg, logger) if err != nil { - return nil, errors.WithStack(err) + return nil, err } - logger.Info("server listening", zap.String("address", addr)) - - grpcServer := grpc.NewServer( - grpc.MaxRecvMsgSize(maxMsgSize), - grpc.MaxSendMsgSize(maxMsgSize), - ) + grpcServer := createGRPCServer(cfg, tlsCfg) // Register runme services. - parserv1.RegisterParserServiceServer(grpcServer, editorservice.NewParserServiceServer(logger)) - projectv1.RegisterProjectServiceServer(grpcServer, projectservice.NewProjectServiceServer(logger)) - runnerService, err := runnerv2service.NewRunnerService(cmdFactory, logger) - if err != nil { - return nil, err - } - runnerv2.RegisterRunnerServiceServer(grpcServer, runnerService) + registrar(grpcServer) // Register health service. healthcheck := health.NewServer() - healthv1.RegisterHealthServer(grpcServer, healthcheck) + grpc_health_v1.RegisterHealthServer(grpcServer, healthcheck) // Setting SERVING for the whole system. - healthcheck.SetServingStatus("", healthv1.HealthCheckResponse_SERVING) + healthcheck.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING) // Register reflection service. reflection.Register(grpcServer) - return &Server{ - lis: lis, - grpcServer: grpcServer, - logger: logger, - }, nil + s := Server{ + cfg: cfg, + gs: grpcServer, + logger: logger, + } + + return &s, nil } func (s *Server) Addr() string { + if s.lis == nil { + return "" + } return s.lis.Addr().String() } -func (s *Server) Serve() error { - return s.grpcServer.Serve(s.lis) +func (s *Server) Serve() (err error) { + s.lis, err = createListener(s.cfg.Server.Address) + if err != nil { + return err + } + s.logger.Info("starting gRPC server", zap.String("address", s.Addr())) + return s.gs.Serve(s.lis) } func (s *Server) Shutdown() { - s.grpcServer.GracefulStop() + s.logger.Info("stopping gRPC server") + s.gs.GracefulStop() +} + +func createListener(addr string) (net.Listener, error) { + protocol := "tcp" + + if strings.HasPrefix(addr, "unix://") { + protocol = "unix" + addr = strings.TrimPrefix(addr, "unix://") + if _, err := os.Stat(addr); !os.IsNotExist(err) { + return nil, err + } + } + + lis, err := net.Listen(protocol, addr) + return lis, errors.WithStack(err) +} + +func createTLSConfig(cfg *config.Config, logger *zap.Logger) (*tls.Config, error) { + if tls := cfg.Server.Tls; tls != nil && tls.Enabled { + // TODO(adamb): redesign runmetls API. + return runmetls.LoadOrGenerateConfig( + *tls.CertFile, // guaranteed in [getRootConfig] + *tls.KeyFile, // guaranteed in [getRootConfig] + logger, + ) + } + return nil, nil +} + +func createGRPCServer(cfg *config.Config, tlsCfg *tls.Config) *grpc.Server { + opts := []grpc.ServerOption{ + grpc.MaxRecvMsgSize(cfg.Server.MaxMessageSize), + grpc.MaxSendMsgSize(cfg.Server.MaxMessageSize), + } + + if tlsCfg != nil { + opts = append(opts, grpc.Creds(credentials.NewTLS(tlsCfg))) + } + + return grpc.NewServer(opts...) } diff --git a/internal/server/server_test.go b/internal/server/server_test.go deleted file mode 100644 index 22f79b0b6..000000000 --- a/internal/server/server_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package server - -import ( - "context" - "path/filepath" - "runtime" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/require" - "go.uber.org/zap/zaptest" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - healthv1 "google.golang.org/grpc/health/grpc_health_v1" - - "github.com/stateful/runme/v3/internal/command" - runmetls "github.com/stateful/runme/v3/internal/tls" -) - -func TestServer(t *testing.T) { - logger := zaptest.NewLogger(t) - factory := command.NewFactory(command.WithLogger(logger)) - - t.Run("tcp", func(t *testing.T) { - cfg := &Config{ - Address: "localhost:0", - } - s, err := New(cfg, factory, logger) - require.NoError(t, err) - errc := make(chan error, 1) - go func() { - errc <- s.Serve() - }() - - testConnectivity(t, s.Addr(), insecure.NewCredentials()) - - s.Shutdown() - require.NoError(t, <-errc) - }) - - t.Run("tcp with tls", func(t *testing.T) { - dir := t.TempDir() - cfg := &Config{ - Address: "localhost:0", - CertFile: filepath.Join(dir, "cert.pem"), - KeyFile: filepath.Join(dir, "key.pem"), - TLSEnabled: true, - } - s, err := New(cfg, factory, logger) - require.NoError(t, err) - errc := make(chan error, 1) - go func() { - errc <- s.Serve() - }() - - tlsConfig, err := runmetls.LoadClientConfig(cfg.CertFile, cfg.KeyFile) - require.NoError(t, err) - - addr := s.Addr() - if runtime.GOOS == "windows" { - addr = strings.TrimPrefix(addr, "unix://") - } - testConnectivity(t, addr, credentials.NewTLS(tlsConfig)) - - s.Shutdown() - require.NoError(t, <-errc) - }) -} - -func testConnectivity(t *testing.T, addr string, creds credentials.TransportCredentials) { - t.Helper() - - var err error - - for i := 0; i < 5; i++ { - var ( - conn *grpc.ClientConn - resp *healthv1.HealthCheckResponse - ) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - conn, err := grpc.NewClient( - addr, - grpc.WithTransportCredentials(creds), - ) - if err != nil { - goto wait - } - - resp, err = healthv1.NewHealthClient(conn).Check(ctx, &healthv1.HealthCheckRequest{}) - if err != nil || resp.Status != healthv1.HealthCheckResponse_SERVING { - goto wait - } - - cancel() - break - - wait: - cancel() - <-time.After(time.Millisecond * 100) - } - - require.NoError(t, err) -} diff --git a/internal/server/server_unix_test.go b/internal/server/server_unix_test.go deleted file mode 100644 index eee7f0865..000000000 --- a/internal/server/server_unix_test.go +++ /dev/null @@ -1,36 +0,0 @@ -//go:build !windows - -package server - -import ( - "path/filepath" - "testing" - - "github.com/stretchr/testify/require" - "go.uber.org/zap/zaptest" - "google.golang.org/grpc/credentials/insecure" - - "github.com/stateful/runme/v3/internal/command" -) - -func TestServerUnixSocket(t *testing.T) { - dir := t.TempDir() - sock := filepath.Join(dir, "runme.sock") - cfg := &Config{ - Address: "unix://" + sock, - } - logger := zaptest.NewLogger(t) - factory := command.NewFactory(command.WithLogger(logger)) - s, err := New(cfg, factory, logger) - require.NoError(t, err) - errc := make(chan error, 1) - go func() { - err := s.Serve() - errc <- err - }() - - testConnectivity(t, cfg.Address, insecure.NewCredentials()) - - s.Shutdown() - require.NoError(t, <-errc) -} diff --git a/testdata/beta/server.txtar b/testdata/beta/server.txtar index 3b2457e4f..746a83874 100644 --- a/testdata/beta/server.txtar +++ b/testdata/beta/server.txtar @@ -4,7 +4,7 @@ exec sleep 8 exec runme beta server stop wait ! stdout . -stderr '(?sm)server listening' +stderr '(?sm)starting gRPC server' -- experimental/runme.yaml -- version: v1alpha1