diff --git a/flyteadmin/pkg/async/notifications/factory.go b/flyteadmin/pkg/async/notifications/factory.go index 5698c3597e..483978238e 100644 --- a/flyteadmin/pkg/async/notifications/factory.go +++ b/flyteadmin/pkg/async/notifications/factory.go @@ -6,8 +6,6 @@ import ( "sync" "time" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/NYTimes/gizmo/pubsub" gizmoAWS "github.com/NYTimes/gizmo/pubsub/aws" gizmoGCP "github.com/NYTimes/gizmo/pubsub/gcp" @@ -20,6 +18,7 @@ import ( "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/interfaces" "github.com/flyteorg/flyte/flyteadmin/pkg/common" runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" ) @@ -29,6 +28,7 @@ const maxRetries = 3 var enable64decoding = false var msgChan chan []byte + var once sync.Once type PublisherConfig struct { @@ -37,222 +37,404 @@ type PublisherConfig struct { type ProcessorConfig struct { QueueName string + AccountID string } type EmailerConfig struct { SenderEmail string - BaseURL string + + BaseURL string } // For sandbox only + func CreateMsgChan() { + once.Do(func() { + msgChan = make(chan []byte) + }) + } func GetEmailer(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope, sm core.SecretManager) interfaces.Emailer { + // If an external email service is specified use that instead. + // TODO: Handling of this is messy, see https://github.com/flyteorg/flyte/issues/1063 + if config.NotificationsEmailerConfig.EmailerConfig.ServiceName != "" { + switch config.NotificationsEmailerConfig.EmailerConfig.ServiceName { + case implementations.Sendgrid: + return implementations.NewSendGridEmailer(config, scope) + case implementations.SMTP: + return implementations.NewSMTPEmailer(context.Background(), config, scope, sm) + default: + panic(fmt.Errorf("No matching email implementation for %s", config.NotificationsEmailerConfig.EmailerConfig.ServiceName)) + } + } switch config.Type { + case common.AWS: + region := config.AWSConfig.Region + if region == "" { + region = config.Region + } + awsConfig := aws.NewConfig().WithRegion(region).WithMaxRetries(maxRetries) + awsSession, err := session.NewSession(awsConfig) + if err != nil { + panic(err) + } + sesClient := ses.New(awsSession) + return implementations.NewAwsEmailer( + config, + scope, + sesClient, ) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), "Using default noop emailer implementation for config type [%s]", config.Type) + return implementations.NewNoopEmail() + } + } func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope, sm core.SecretManager) interfaces.Processor { + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + var sub pubsub.Subscriber + var emailer interfaces.Emailer + switch config.Type { + case common.AWS: + sqsConfig := gizmoAWS.SQSConfig{ - QueueName: config.NotificationsProcessorConfig.QueueName, + + QueueName: config.NotificationsProcessorConfig.QueueName, + QueueOwnerAccountID: config.NotificationsProcessorConfig.AccountID, + // The AWS configuration type uses SNS to SQS for notifications. + // Gizmo by default will decode the SQS message using Base64 decoding. + // However, the message body of SQS is the SNS message format which isn't Base64 encoded. + ConsumeBase64: &enable64decoding, } + if config.AWSConfig.Region != "" { + sqsConfig.Region = config.AWSConfig.Region + } else { + sqsConfig.Region = config.Region + } + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + sub, err = gizmoAWS.NewSubscriber(sqsConfig) + if err != nil { + logger.Warnf(context.TODO(), "Failed to initialize new gizmo aws subscriber with config [%+v] and err: %v", sqsConfig, err) + } + return err + }) if err != nil { + panic(err) + } + emailer = GetEmailer(config, scope, sm) + return implementations.NewProcessor(sub, emailer, scope) + case common.GCP: + projectID := config.GCPConfig.ProjectID + subscription := config.NotificationsProcessorConfig.QueueName + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + sub, err = gizmoGCP.NewSubscriber(context.TODO(), projectID, subscription) + if err != nil { + logger.Warnf(context.TODO(), "Failed to initialize new gizmo gcp subscriber with config [ProjectID: %s, Subscription: %s] and err: %v", projectID, subscription, err) + } + return err + }) + if err != nil { + panic(err) + } + emailer = GetEmailer(config, scope, sm) + return implementations.NewGcpProcessor(sub, emailer, scope) + case common.Sandbox: + emailer = GetEmailer(config, scope, sm) + return implementations.NewSandboxProcessor(msgChan, emailer) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop notifications processor implementation for config type [%s]", config.Type) + return implementations.NewNoopProcess() + } + } func NewNotificationsPublisher(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Publisher { + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + switch config.Type { + case common.AWS: + snsConfig := gizmoAWS.SNSConfig{ + Topic: config.NotificationsPublisherConfig.TopicName, } + if config.AWSConfig.Region != "" { + snsConfig.Region = config.AWSConfig.Region + } else { + snsConfig.Region = config.Region + } var publisher pubsub.Publisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoAWS.NewPublisher(snsConfig) + return err + }) // Any persistent errors initiating Publisher with Amazon configurations results in a failed start up. + if err != nil { + panic(err) + } + return implementations.NewPublisher(publisher, scope) + case common.GCP: + pubsubConfig := gizmoGCP.Config{ + Topic: config.NotificationsPublisherConfig.TopicName, } + pubsubConfig.ProjectID = config.GCPConfig.ProjectID + var publisher pubsub.MultiPublisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoGCP.NewPublisher(context.TODO(), pubsubConfig) + return err + }) if err != nil { + panic(err) + } + return implementations.NewPublisher(publisher, scope) + case common.Sandbox: + CreateMsgChan() + return implementations.NewSandboxPublisher(msgChan) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop notifications publisher implementation for config type [%s]", config.Type) + return implementations.NewNoopPublish() + } + } func NewEventsPublisher(config runtimeInterfaces.ExternalEventsConfig, scope promutils.Scope) interfaces.Publisher { + if !config.Enable { + return implementations.NewNoopPublish() + } + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + switch config.Type { + case common.AWS: + snsConfig := gizmoAWS.SNSConfig{ + Topic: config.EventsPublisherConfig.TopicName, } + snsConfig.Region = config.AWSConfig.Region var publisher pubsub.Publisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoAWS.NewPublisher(snsConfig) + return err + }) // Any persistent errors initiating Publisher with Amazon configurations results in a failed start up. + if err != nil { + panic(err) + } + return implementations.NewEventsPublisher(publisher, scope, config.EventsPublisherConfig.EventTypes) + case common.GCP: + pubsubConfig := gizmoGCP.Config{ + Topic: config.EventsPublisherConfig.TopicName, } + pubsubConfig.ProjectID = config.GCPConfig.ProjectID + var publisher pubsub.MultiPublisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoGCP.NewPublisher(context.TODO(), pubsubConfig) + return err + }) if err != nil { + panic(err) + } + return implementations.NewEventsPublisher(publisher, scope, config.EventsPublisherConfig.EventTypes) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop events publisher implementation for config type [%s]", config.Type) + return implementations.NewNoopPublish() + } + } diff --git a/flyteadmin/pkg/async/notifications/factory_test.go b/flyteadmin/pkg/async/notifications/factory_test.go index 72cdb88794..43602525a5 100644 --- a/flyteadmin/pkg/async/notifications/factory_test.go +++ b/flyteadmin/pkg/async/notifications/factory_test.go @@ -4,35 +4,50 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/implementations" runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyte/flytestdlib/promutils" - "github.com/stretchr/testify/assert" ) var ( - scope = promutils.NewScope("test_sandbox_processor") + scope = promutils.NewScope("test_sandbox_processor") + notificationsConfig = runtimeInterfaces.NotificationsConfig{ + Type: "sandbox", } + testEmail = admin.EmailMessage{ + RecipientsEmail: []string{ + "a@example.com", + "b@example.com", }, + SenderEmail: "no-reply@example.com", + SubjectLine: "Test email", - Body: "This is a sample email.", + + Body: "This is a sample email.", } ) func TestGetEmailer(t *testing.T) { + defer func() { r := recover(); assert.NotNil(t, r) }() + cfg := runtimeInterfaces.NotificationsConfig{ + NotificationsEmailerConfig: runtimeInterfaces.NotificationsEmailerConfig{ + EmailerConfig: runtimeInterfaces.EmailServerConfig{ + ServiceName: "unsupported", }, }, @@ -41,20 +56,29 @@ func TestGetEmailer(t *testing.T) { GetEmailer(cfg, promutils.NewTestScope(), &mocks.SecretManager{}) // shouldn't reach here + t.Errorf("did not panic") + } func TestNewNotificationPublisherAndProcessor(t *testing.T) { + testSandboxPublisher := NewNotificationsPublisher(notificationsConfig, scope) + assert.IsType(t, testSandboxPublisher, &implementations.SandboxPublisher{}) + testSandboxProcessor := NewNotificationsProcessor(notificationsConfig, scope, &mocks.SecretManager{}) + assert.IsType(t, testSandboxProcessor, &implementations.SandboxProcessor{}) go func() { + testSandboxProcessor.StartProcessing() + }() assert.Nil(t, testSandboxPublisher.Publish(context.Background(), "TEST_NOTIFICATION", &testEmail)) assert.Nil(t, testSandboxProcessor.StopProcessing()) + } diff --git a/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go index b28b0550d9..5a705bc0c1 100644 --- a/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go +++ b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go @@ -19,15 +19,16 @@ import ( ) type SMTPEmailer struct { - config *runtimeInterfaces.NotificationsEmailerConfig - systemMetrics emailMetrics - tlsConf *tls.Config - auth *smtp.Auth - smtpClient *smtp.Client + config *runtimeInterfaces.NotificationsEmailerConfig + systemMetrics emailMetrics + tlsConf *tls.Config + auth *smtp.Auth + smtpClient interfaces.SMTPClient + CreateSMTPClientFunc func(connectString string) (interfaces.SMTPClient, error) } -func (s *SMTPEmailer) createClient(ctx context.Context) (*smtp.Client, error) { - newClient, err := smtp.Dial(s.config.EmailerConfig.SMTPServer + ":" + s.config.EmailerConfig.SMTPPort) +func (s *SMTPEmailer) createClient(ctx context.Context) (interfaces.SMTPClient, error) { + newClient, err := s.CreateSMTPClientFunc(s.config.EmailerConfig.SMTPServer + ":" + s.config.EmailerConfig.SMTPPort) if err != nil { return nil, s.emailError(ctx, fmt.Sprintf("Error creating email client: %s", err)) @@ -39,7 +40,7 @@ func (s *SMTPEmailer) createClient(ctx context.Context) (*smtp.Client, error) { if ok, _ := newClient.Extension("STARTTLS"); ok { if err = newClient.StartTLS(s.tlsConf); err != nil { - return nil, err + return nil, s.emailError(ctx, fmt.Sprintf("Error initiating connection to SMTP server: %s", err)) } } @@ -77,7 +78,7 @@ func (s *SMTPEmailer) SendEmail(ctx context.Context, email *admin.EmailMessage) for _, recipient := range email.RecipientsEmail { if err := s.smtpClient.Rcpt(recipient); err != nil { - logger.Errorf(ctx, "Error adding email recipient: %s", err) + return s.emailError(ctx, fmt.Sprintf("Error adding email recipient: %s", err)) } } @@ -150,5 +151,8 @@ func NewSMTPEmailer(ctx context.Context, config runtimeInterfaces.NotificationsC systemMetrics: newEmailMetrics(scope.NewSubScope("smtp")), tlsConf: tlsConfiguration, auth: &auth, + CreateSMTPClientFunc: func(connectString string) (interfaces.SMTPClient, error) { + return smtp.Dial(connectString) + }, } } diff --git a/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go index 1fd3a75e00..924bd1278c 100644 --- a/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go +++ b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go @@ -2,17 +2,42 @@ package implementations import ( "context" + "crypto/tls" + "errors" + "google.golang.org/grpc/codes" + "net/smtp" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + notification_interfaces "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/interfaces" + notification_mocks "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/mocks" + + flyte_errors "github.com/flyteorg/flyte/flyteadmin/pkg/errors" + "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyte/flytestdlib/promutils" ) +type StringWriter struct { + buffer string + writeErr error + closeErr error +} + +func (s *StringWriter) Write(p []byte) (n int, err error) { + s.buffer = s.buffer + string(p) + return len(p), s.writeErr +} + +func (s *StringWriter) Close() error { + return s.closeErr +} + func getNotificationsEmailerConfig() interfaces.NotificationsConfig { return interfaces.NotificationsConfig{ Type: "", @@ -58,3 +83,414 @@ func TestNewSmtpEmailer(t *testing.T) { assert.NotNil(t, smtpEmailer) } + +func TestCreateClient(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := notification_mocks.NewSMTPClient(t) + smtpClient.EXPECT().Hello("localhost").Return(nil).Times(1) + smtpClient.EXPECT().Extension("STARTTLS").Return(true, "").Times(1) + smtpClient.EXPECT().StartTLS(&tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.EXPECT().Extension("AUTH").Return(true, "").Times(1) + smtpClient.EXPECT().Auth(auth).Return(nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Nil(t, err) + assert.NotNil(t, client) + +} + +func TestCreateClientErrorCreatingClient(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := notification_mocks.NewSMTPClient(t) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, errors.New("error creating client")) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestCreateClientErrorHello(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := notification_mocks.NewSMTPClient(t) + smtpClient.EXPECT().Hello("localhost").Return(errors.New("Error with hello")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestCreateClientErrorStartTLS(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := notification_mocks.NewSMTPClient(t) + smtpClient.EXPECT().Hello("localhost").Return(nil).Times(1) + smtpClient.EXPECT().Extension("STARTTLS").Return(true, "").Times(1) + smtpClient.EXPECT().StartTLS(&tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(errors.New("Error with startls")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestCreateClientErrorAuth(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := notification_mocks.NewSMTPClient(t) + smtpClient.EXPECT().Hello("localhost").Return(nil).Times(1) + smtpClient.EXPECT().Extension("STARTTLS").Return(true, "").Times(1) + smtpClient.EXPECT().StartTLS(&tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.EXPECT().Extension("AUTH").Return(true, "").Times(1) + smtpClient.EXPECT().Auth(auth).Return(errors.New("Error with hello")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestSendMail(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + stringWriter := StringWriter{buffer: ""} + + smtpClient := notification_mocks.NewSMTPClient(t) + smtpClient.EXPECT().Noop().Return(errors.New("no connection")).Times(1) + smtpClient.EXPECT().Close().Return(nil).Times(1) + smtpClient.EXPECT().Hello("localhost").Return(nil).Times(1) + smtpClient.EXPECT().Extension("STARTTLS").Return(true, "").Times(1) + smtpClient.EXPECT().StartTLS(&tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.EXPECT().Extension("AUTH").Return(true, "").Times(1) + smtpClient.EXPECT().Auth(auth).Return(nil).Times(1) + smtpClient.EXPECT().Mail("flyte@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Rcpt("alice@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Rcpt("bob@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Data().Return(&stringWriter, nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.True(t, strings.Contains(stringWriter.buffer, "From: sender")) + assert.True(t, strings.Contains(stringWriter.buffer, "To: alice@flyte.org,bob@flyte.org")) + assert.True(t, strings.Contains(stringWriter.buffer, "Subject: subject")) + assert.True(t, strings.Contains(stringWriter.buffer, "This is an email.")) + assert.Nil(t, err) + +} + +func TestSendMailCreateClientError(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := notification_mocks.NewSMTPClient(t) + smtpClient.EXPECT().Noop().Return(errors.New("no connection")).Times(1) + smtpClient.EXPECT().Close().Return(nil).Times(1) + smtpClient.EXPECT().Hello("localhost").Return(errors.New("error hello")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorMail(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := notification_mocks.NewSMTPClient(t) + smtpClient.EXPECT().Noop().Return(errors.New("no connection")).Times(1) + smtpClient.EXPECT().Close().Return(nil).Times(1) + smtpClient.EXPECT().Hello("localhost").Return(nil).Times(1) + smtpClient.EXPECT().Extension("STARTTLS").Return(true, "").Times(1) + smtpClient.EXPECT().StartTLS(&tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.EXPECT().Extension("AUTH").Return(true, "").Times(1) + smtpClient.EXPECT().Auth(auth).Return(nil).Times(1) + smtpClient.EXPECT().Mail("flyte@flyte.org").Return(errors.New("error sending mail")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorRecipient(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := notification_mocks.NewSMTPClient(t) + smtpClient.EXPECT().Noop().Return(errors.New("no connection")).Times(1) + smtpClient.EXPECT().Close().Return(nil).Times(1) + smtpClient.EXPECT().Hello("localhost").Return(nil).Times(1) + smtpClient.EXPECT().Extension("STARTTLS").Return(true, "").Times(1) + smtpClient.EXPECT().StartTLS(&tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.EXPECT().Extension("AUTH").Return(true, "").Times(1) + smtpClient.EXPECT().Auth(auth).Return(nil).Times(1) + smtpClient.EXPECT().Mail("flyte@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Rcpt("alice@flyte.org").Return(errors.New("error adding recipient")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorData(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := notification_mocks.NewSMTPClient(t) + smtpClient.EXPECT().Noop().Return(errors.New("no connection")).Times(1) + smtpClient.EXPECT().Close().Return(nil).Times(1) + smtpClient.EXPECT().Hello("localhost").Return(nil).Times(1) + smtpClient.EXPECT().Extension("STARTTLS").Return(true, "").Times(1) + smtpClient.EXPECT().StartTLS(&tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.EXPECT().Extension("AUTH").Return(true, "").Times(1) + smtpClient.EXPECT().Auth(auth).Return(nil).Times(1) + smtpClient.EXPECT().Mail("flyte@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Rcpt("alice@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Rcpt("bob@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Data().Return(nil, errors.New("error creating data writer")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorWriting(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + stringWriter := StringWriter{buffer: "", writeErr: errors.New("error writing"), closeErr: nil} + + smtpClient := notification_mocks.NewSMTPClient(t) + smtpClient.EXPECT().Noop().Return(errors.New("no connection")).Times(1) + smtpClient.EXPECT().Close().Return(nil).Times(1) + smtpClient.EXPECT().Hello("localhost").Return(nil).Times(1) + smtpClient.EXPECT().Extension("STARTTLS").Return(true, "").Times(1) + smtpClient.EXPECT().StartTLS(&tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.EXPECT().Extension("AUTH").Return(true, "").Times(1) + smtpClient.EXPECT().Auth(auth).Return(nil).Times(1) + smtpClient.EXPECT().Mail("flyte@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Rcpt("alice@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Rcpt("bob@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Data().Return(&stringWriter, nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorClose(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + stringWriter := StringWriter{buffer: "", writeErr: nil, closeErr: errors.New("error writing")} + + smtpClient := notification_mocks.NewSMTPClient(t) + smtpClient.EXPECT().Noop().Return(errors.New("no connection")).Times(1) + smtpClient.EXPECT().Close().Return(nil).Times(1) + smtpClient.EXPECT().Hello("localhost").Return(nil).Times(1) + smtpClient.EXPECT().Extension("STARTTLS").Return(true, "").Times(1) + smtpClient.EXPECT().StartTLS(&tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.EXPECT().Extension("AUTH").Return(true, "").Times(1) + smtpClient.EXPECT().Auth(auth).Return(nil).Times(1) + smtpClient.EXPECT().Mail("flyte@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Rcpt("alice@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Rcpt("bob@flyte.org").Return(nil).Times(1) + smtpClient.EXPECT().Data().Return(&stringWriter, nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.True(t, strings.Contains(stringWriter.buffer, "From: sender")) + assert.True(t, strings.Contains(stringWriter.buffer, "To: alice@flyte.org,bob@flyte.org")) + assert.True(t, strings.Contains(stringWriter.buffer, "Subject: subject")) + assert.True(t, strings.Contains(stringWriter.buffer, "This is an email.")) + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func createSMTPEmailer(smtpClient notification_interfaces.SMTPClient, tlsConf *tls.Config, auth *smtp.Auth, creationErr error) *SMTPEmailer { + secretManagerMock := mocks.SecretManager{} + secretManagerMock.On("Get", mock.Anything, "smtp_password").Return("password", nil) + + notificationsConfig := getNotificationsEmailerConfig() + + return &SMTPEmailer{ + config: ¬ificationsConfig.NotificationsEmailerConfig, + systemMetrics: newEmailMetrics(promutils.NewTestScope()), + tlsConf: tlsConf, + auth: auth, + CreateSMTPClientFunc: func(connectString string) (notification_interfaces.SMTPClient, error) { + return smtpClient, creationErr + }, + smtpClient: smtpClient, + } +} diff --git a/flyteadmin/pkg/async/notifications/interfaces/smtp_client.go b/flyteadmin/pkg/async/notifications/interfaces/smtp_client.go new file mode 100644 index 0000000000..9d22cdc345 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/interfaces/smtp_client.go @@ -0,0 +1,22 @@ +package interfaces + +import ( + "crypto/tls" + "io" + "net/smtp" +) + +// This interface is introduced to allow for mocking of the smtp.Client object. + +//go:generate mockery --name=SMTPClient --output=../mocks --case=underscore --with-expecter +type SMTPClient interface { + Hello(localName string) error + Extension(ext string) (bool, string) + Auth(a smtp.Auth) error + StartTLS(config *tls.Config) error + Noop() error + Close() error + Mail(from string) error + Rcpt(to string) error + Data() (io.WriteCloser, error) +} diff --git a/flyteadmin/pkg/async/notifications/mocks/smtp_client.go b/flyteadmin/pkg/async/notifications/mocks/smtp_client.go new file mode 100644 index 0000000000..39c3a63caa --- /dev/null +++ b/flyteadmin/pkg/async/notifications/mocks/smtp_client.go @@ -0,0 +1,472 @@ +// Code generated by mockery v2.45.1. DO NOT EDIT. + +package mocks + +import ( + io "io" + smtp "net/smtp" + + mock "github.com/stretchr/testify/mock" + + tls "crypto/tls" +) + +// SMTPClient is an autogenerated mock type for the SMTPClient type +type SMTPClient struct { + mock.Mock +} + +type SMTPClient_Expecter struct { + mock *mock.Mock +} + +func (_m *SMTPClient) EXPECT() *SMTPClient_Expecter { + return &SMTPClient_Expecter{mock: &_m.Mock} +} + +// Auth provides a mock function with given fields: a +func (_m *SMTPClient) Auth(a smtp.Auth) error { + ret := _m.Called(a) + + if len(ret) == 0 { + panic("no return value specified for Auth") + } + + var r0 error + if rf, ok := ret.Get(0).(func(smtp.Auth) error); ok { + r0 = rf(a) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SMTPClient_Auth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Auth' +type SMTPClient_Auth_Call struct { + *mock.Call +} + +// Auth is a helper method to define mock.On call +// - a smtp.Auth +func (_e *SMTPClient_Expecter) Auth(a interface{}) *SMTPClient_Auth_Call { + return &SMTPClient_Auth_Call{Call: _e.mock.On("Auth", a)} +} + +func (_c *SMTPClient_Auth_Call) Run(run func(a smtp.Auth)) *SMTPClient_Auth_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(smtp.Auth)) + }) + return _c +} + +func (_c *SMTPClient_Auth_Call) Return(_a0 error) *SMTPClient_Auth_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *SMTPClient_Auth_Call) RunAndReturn(run func(smtp.Auth) error) *SMTPClient_Auth_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with given fields: +func (_m *SMTPClient) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SMTPClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type SMTPClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *SMTPClient_Expecter) Close() *SMTPClient_Close_Call { + return &SMTPClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *SMTPClient_Close_Call) Run(run func()) *SMTPClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *SMTPClient_Close_Call) Return(_a0 error) *SMTPClient_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *SMTPClient_Close_Call) RunAndReturn(run func() error) *SMTPClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// Data provides a mock function with given fields: +func (_m *SMTPClient) Data() (io.WriteCloser, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Data") + } + + var r0 io.WriteCloser + var r1 error + if rf, ok := ret.Get(0).(func() (io.WriteCloser, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() io.WriteCloser); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.WriteCloser) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SMTPClient_Data_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Data' +type SMTPClient_Data_Call struct { + *mock.Call +} + +// Data is a helper method to define mock.On call +func (_e *SMTPClient_Expecter) Data() *SMTPClient_Data_Call { + return &SMTPClient_Data_Call{Call: _e.mock.On("Data")} +} + +func (_c *SMTPClient_Data_Call) Run(run func()) *SMTPClient_Data_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *SMTPClient_Data_Call) Return(_a0 io.WriteCloser, _a1 error) *SMTPClient_Data_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *SMTPClient_Data_Call) RunAndReturn(run func() (io.WriteCloser, error)) *SMTPClient_Data_Call { + _c.Call.Return(run) + return _c +} + +// Extension provides a mock function with given fields: ext +func (_m *SMTPClient) Extension(ext string) (bool, string) { + ret := _m.Called(ext) + + if len(ret) == 0 { + panic("no return value specified for Extension") + } + + var r0 bool + var r1 string + if rf, ok := ret.Get(0).(func(string) (bool, string)); ok { + return rf(ext) + } + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(ext) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string) string); ok { + r1 = rf(ext) + } else { + r1 = ret.Get(1).(string) + } + + return r0, r1 +} + +// SMTPClient_Extension_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Extension' +type SMTPClient_Extension_Call struct { + *mock.Call +} + +// Extension is a helper method to define mock.On call +// - ext string +func (_e *SMTPClient_Expecter) Extension(ext interface{}) *SMTPClient_Extension_Call { + return &SMTPClient_Extension_Call{Call: _e.mock.On("Extension", ext)} +} + +func (_c *SMTPClient_Extension_Call) Run(run func(ext string)) *SMTPClient_Extension_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *SMTPClient_Extension_Call) Return(_a0 bool, _a1 string) *SMTPClient_Extension_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *SMTPClient_Extension_Call) RunAndReturn(run func(string) (bool, string)) *SMTPClient_Extension_Call { + _c.Call.Return(run) + return _c +} + +// Hello provides a mock function with given fields: localName +func (_m *SMTPClient) Hello(localName string) error { + ret := _m.Called(localName) + + if len(ret) == 0 { + panic("no return value specified for Hello") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(localName) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SMTPClient_Hello_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Hello' +type SMTPClient_Hello_Call struct { + *mock.Call +} + +// Hello is a helper method to define mock.On call +// - localName string +func (_e *SMTPClient_Expecter) Hello(localName interface{}) *SMTPClient_Hello_Call { + return &SMTPClient_Hello_Call{Call: _e.mock.On("Hello", localName)} +} + +func (_c *SMTPClient_Hello_Call) Run(run func(localName string)) *SMTPClient_Hello_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *SMTPClient_Hello_Call) Return(_a0 error) *SMTPClient_Hello_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *SMTPClient_Hello_Call) RunAndReturn(run func(string) error) *SMTPClient_Hello_Call { + _c.Call.Return(run) + return _c +} + +// Mail provides a mock function with given fields: from +func (_m *SMTPClient) Mail(from string) error { + ret := _m.Called(from) + + if len(ret) == 0 { + panic("no return value specified for Mail") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(from) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SMTPClient_Mail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Mail' +type SMTPClient_Mail_Call struct { + *mock.Call +} + +// Mail is a helper method to define mock.On call +// - from string +func (_e *SMTPClient_Expecter) Mail(from interface{}) *SMTPClient_Mail_Call { + return &SMTPClient_Mail_Call{Call: _e.mock.On("Mail", from)} +} + +func (_c *SMTPClient_Mail_Call) Run(run func(from string)) *SMTPClient_Mail_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *SMTPClient_Mail_Call) Return(_a0 error) *SMTPClient_Mail_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *SMTPClient_Mail_Call) RunAndReturn(run func(string) error) *SMTPClient_Mail_Call { + _c.Call.Return(run) + return _c +} + +// Noop provides a mock function with given fields: +func (_m *SMTPClient) Noop() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Noop") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SMTPClient_Noop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Noop' +type SMTPClient_Noop_Call struct { + *mock.Call +} + +// Noop is a helper method to define mock.On call +func (_e *SMTPClient_Expecter) Noop() *SMTPClient_Noop_Call { + return &SMTPClient_Noop_Call{Call: _e.mock.On("Noop")} +} + +func (_c *SMTPClient_Noop_Call) Run(run func()) *SMTPClient_Noop_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *SMTPClient_Noop_Call) Return(_a0 error) *SMTPClient_Noop_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *SMTPClient_Noop_Call) RunAndReturn(run func() error) *SMTPClient_Noop_Call { + _c.Call.Return(run) + return _c +} + +// Rcpt provides a mock function with given fields: to +func (_m *SMTPClient) Rcpt(to string) error { + ret := _m.Called(to) + + if len(ret) == 0 { + panic("no return value specified for Rcpt") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(to) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SMTPClient_Rcpt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Rcpt' +type SMTPClient_Rcpt_Call struct { + *mock.Call +} + +// Rcpt is a helper method to define mock.On call +// - to string +func (_e *SMTPClient_Expecter) Rcpt(to interface{}) *SMTPClient_Rcpt_Call { + return &SMTPClient_Rcpt_Call{Call: _e.mock.On("Rcpt", to)} +} + +func (_c *SMTPClient_Rcpt_Call) Run(run func(to string)) *SMTPClient_Rcpt_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *SMTPClient_Rcpt_Call) Return(_a0 error) *SMTPClient_Rcpt_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *SMTPClient_Rcpt_Call) RunAndReturn(run func(string) error) *SMTPClient_Rcpt_Call { + _c.Call.Return(run) + return _c +} + +// StartTLS provides a mock function with given fields: config +func (_m *SMTPClient) StartTLS(config *tls.Config) error { + ret := _m.Called(config) + + if len(ret) == 0 { + panic("no return value specified for StartTLS") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*tls.Config) error); ok { + r0 = rf(config) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SMTPClient_StartTLS_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StartTLS' +type SMTPClient_StartTLS_Call struct { + *mock.Call +} + +// StartTLS is a helper method to define mock.On call +// - config *tls.Config +func (_e *SMTPClient_Expecter) StartTLS(config interface{}) *SMTPClient_StartTLS_Call { + return &SMTPClient_StartTLS_Call{Call: _e.mock.On("StartTLS", config)} +} + +func (_c *SMTPClient_StartTLS_Call) Run(run func(config *tls.Config)) *SMTPClient_StartTLS_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*tls.Config)) + }) + return _c +} + +func (_c *SMTPClient_StartTLS_Call) Return(_a0 error) *SMTPClient_StartTLS_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *SMTPClient_StartTLS_Call) RunAndReturn(run func(*tls.Config) error) *SMTPClient_StartTLS_Call { + _c.Call.Return(run) + return _c +} + +// NewSMTPClient creates a new instance of SMTPClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewSMTPClient(t interface { + mock.TestingT + Cleanup(func()) +}) *SMTPClient { + mock := &SMTPClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}