diff --git a/flyteplugins/go/tasks/pluginmachinery/io/mocks/output_reader.go b/flyteplugins/go/tasks/pluginmachinery/io/mocks/output_reader.go index 285a79dad9..7120063ceb 100644 --- a/flyteplugins/go/tasks/pluginmachinery/io/mocks/output_reader.go +++ b/flyteplugins/go/tasks/pluginmachinery/io/mocks/output_reader.go @@ -5,8 +5,8 @@ package mocks import ( context "context" - core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" io "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" + core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" mock "github.com/stretchr/testify/mock" ) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/config.go b/flyteplugins/go/tasks/plugins/webapi/agent/config.go index a0014e9627..a25b28891c 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/config.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/config.go @@ -43,6 +43,11 @@ var ( Endpoint: "dns:///flyteagent.flyte.svc.cluster.local:80", Insecure: true, DefaultTimeout: config.Duration{Duration: 10 * time.Second}, + KeepAliveParameters: &KeepAliveParameters{ + Time: config.Duration{Duration: 10 * time.Second}, + Timeout: config.Duration{Duration: 5 * time.Second}, + PermitWithoutStream: true, + }, }, } @@ -67,6 +72,22 @@ type Config struct { AgentForTaskTypes map[string]string `json:"agentForTaskTypes" pflag:"-,"` } +// KeepAliveParameters defines keepalive parameters on the client-side; its fields mirror those of grpc's keepalive.ClientParameters. For more details, check https://pkg.go.dev/google.golang.org/grpc/keepalive#ClientParameters +type KeepAliveParameters struct { + // Time is the period of inactivity after which the client pings the + // server to check that the transport is still alive. + // If set below 10s, a minimum value of 10s will be used instead. + Time config.Duration `json:"time"` + // Timeout is how long the client waits after sending a keepalive ping; if + // no activity is seen within that window, the connection is + // closed. 
+ Timeout config.Duration `json:"timeout"` + // PermitWithoutStream controls pings while idle: if true, the client sends + // keepalive pings even with no active RPCs; if false, Time and Timeout are + // ignored while there are no active RPCs and no keepalive pings are sent. + PermitWithoutStream bool `json:"permitWithoutStream"` +} + type Agent struct { // Endpoint points to an agent gRPC endpoint Endpoint string `json:"endpoint"` @@ -82,6 +103,9 @@ type Agent struct { // DefaultTimeout gives the default RPC timeout if a more specific one is not defined in Timeouts; if neither DefaultTimeout nor Timeouts is defined for an operation, RPC timeout will not be enforced DefaultTimeout config.Duration `json:"defaultTimeout"` + + // KeepAliveParameters defines keepalive parameters for the gRPC client; when nil, no keepalive dial option is added + KeepAliveParameters *KeepAliveParameters `json:"keepAliveParameters"` } func GetConfig() *Config { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/config_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/config_test.go index b110897a47..b52b37ad6a 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/config_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/config_test.go @@ -27,6 +27,11 @@ func TestGetAndSetConfig(t *testing.T) { }, } cfg.DefaultAgent.DefaultTimeout = config.Duration{Duration: 10 * time.Second} + cfg.DefaultAgent.KeepAliveParameters = &KeepAliveParameters{ + Time: config.Duration{Duration: 10 * time.Second}, + Timeout: config.Duration{Duration: 5 * time.Second}, + PermitWithoutStream: true, + } cfg.Agents = map[string]*Agent{ "agent_1": { Insecure: cfg.DefaultAgent.Insecure, diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 9713ba90f3..aaf9827fe5 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -10,8 +10,8 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "google.golang.org/grpc/credentials" 
"google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/keepalive" pluginErrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" @@ -247,6 +247,15 @@ func getClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent opts = append(opts, grpc.WithDefaultServiceConfig(agent.DefaultServiceConfig)) } + if agent.KeepAliveParameters != nil { + + opts = append(opts, grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: agent.KeepAliveParameters.Time.Duration, + Timeout: agent.KeepAliveParameters.Timeout.Duration, + PermitWithoutStream: agent.KeepAliveParameters.PermitWithoutStream, + })) + } + var err error conn, err = grpc.Dial(agent.Endpoint, opts...) if err != nil { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 24e93bf1c8..692c7c9d69 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -65,7 +65,12 @@ func TestPlugin(t *testing.T) { }) t.Run("test getClientFunc more config", func(t *testing.T) { - client, err := getClientFunc(context.Background(), &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[*Agent]*grpc.ClientConn{}) + client, err := getClientFunc(context.Background(), &Agent{ + Endpoint: "localhost:80", + Insecure: true, + DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}", + KeepAliveParameters: &KeepAliveParameters{Time: config.Duration{Duration: 10 * time.Second}}}, + map[*Agent]*grpc.ClientConn{}) assert.NoError(t, err) assert.NotNil(t, client) })