diff --git a/flyteadmin/pkg/async/notifications/implementations/sendgrid_emailer.go b/flyteadmin/pkg/async/notifications/implementations/sendgrid_emailer.go index a325cbee75..0c69784e98 100644 --- a/flyteadmin/pkg/async/notifications/implementations/sendgrid_emailer.go +++ b/flyteadmin/pkg/async/notifications/implementations/sendgrid_emailer.go @@ -5,10 +5,13 @@ import ( "io/ioutil" "os" "strings" + "time" + "github.com/sendgrid/rest" "github.com/sendgrid/sendgrid-go" "github.com/sendgrid/sendgrid-go/helpers/mail" + "github.com/flyteorg/flyte/flyteadmin/pkg/async" "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/interfaces" runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" @@ -16,9 +19,16 @@ import ( "github.com/flyteorg/flyte/flytestdlib/promutils" ) +//go:generate mockery -all -case=underscore -output=../mocks -case=underscore + +type SendgridClient interface { + Send(email *mail.SGMailV3) (*rest.Response, error) +} + type SendgridEmailer struct { - client *sendgrid.Client + client SendgridClient systemMetrics emailMetrics + cfg *runtimeInterfaces.NotificationsConfig } func getEmailAddresses(addresses []string) []*mail.Email { @@ -63,9 +73,18 @@ func getAPIKey(config runtimeInterfaces.EmailServerConfig) string { func (s SendgridEmailer) SendEmail(ctx context.Context, email *admin.EmailMessage) error { m := getSendgridEmail(email) s.systemMetrics.SendTotal.Inc() - response, err := s.client.Send(m) + var response *rest.Response + var err error + err = async.Retry(s.cfg.ReconnectAttempts, time.Duration(s.cfg.ReconnectDelaySeconds)*time.Second, func() error { + response, err = s.client.Send(m) + if err != nil { + logger.Errorf(ctx, "Sendgrid error sending email: %+v with: %+v", email, err) + return err + } + return nil + }) if err != nil { - logger.Errorf(ctx, "Sendgrid error sending %s", err) + logger.Errorf(ctx, "all attempts to send email %+v via sendgrid failed: %+v", email, err) s.systemMetrics.SendError.Inc() return err } @@ -79,5 +98,6 @@ func NewSendGridEmailer(config runtimeInterfaces.NotificationsConfig, scope prom return &SendgridEmailer{ client: sendgrid.NewSendClient(getAPIKey(config.NotificationsEmailerConfig.EmailerConfig)), systemMetrics: newEmailMetrics(scope.NewSubScope("sendgrid")), + cfg: &config, } } diff --git a/flyteadmin/pkg/async/notifications/implementations/sendgrid_emailer_test.go b/flyteadmin/pkg/async/notifications/implementations/sendgrid_emailer_test.go index eafad84b2c..dc9760b33c 100644 --- a/flyteadmin/pkg/async/notifications/implementations/sendgrid_emailer_test.go +++ b/flyteadmin/pkg/async/notifications/implementations/sendgrid_emailer_test.go @@ -1,27 +1,26 @@ package implementations import ( + "context" + "errors" "io/ioutil" "os" "path" + "strings" "testing" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/sendgrid/rest" "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/mocks" runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flytestdlib/promutils" ) -func TestAddresses(t *testing.T) { - addresses := []string{"alice@example.com", "bob@example.com"} - sgAddresses := getEmailAddresses(addresses) - assert.Equal(t, sgAddresses[0].Address, "alice@example.com") - assert.Equal(t, sgAddresses[1].Address, "bob@example.com") -} - -func TestGetEmail(t *testing.T) { - emailNotification := &admin.EmailMessage{ +var ( + emailNotification = &admin.EmailMessage{ SubjectLine: "Notice: Execution \"name\" has succeeded in \"domain\".", SenderEmail: "no-reply@example.com", RecipientsEmail: []string{ @@ -32,7 +31,16 @@ func TestGetEmail(t *testing.T) { "" + "https://example.com/executions/T/B/D.", } +) +func TestAddresses(t *testing.T) { + addresses := []string{"alice@example.com", "bob@example.com"} + sgAddresses := getEmailAddresses(addresses) + assert.Equal(t, sgAddresses[0].Address, "alice@example.com") + assert.Equal(t, sgAddresses[1].Address, "bob@example.com") +} + +func TestGetEmail(t *testing.T) { sgEmail := getSendgridEmail(emailNotification) assert.Equal(t, `Notice: Execution "name" has succeeded in "domain".`, sgEmail.Personalizations[0].Subject) assert.Equal(t, "john@example.com", sgEmail.Personalizations[0].To[1].Address) @@ -98,3 +106,63 @@ func TestNoFile(t *testing.T) { // shouldn't reach here t.Errorf("did not panic") } + +func TestSendEmail(t *testing.T) { + ctx := context.TODO() + expectedErr := errors.New("expected") + t.Run("exhaust all retry attempts", func(t *testing.T) { + sendgridClient := &mocks.SendgridClient{} + expectedEmail := getSendgridEmail(emailNotification) + sendgridClient.OnSendMatch(expectedEmail). + Return(nil, expectedErr).Times(3) + sendgridClient.OnSendMatch(expectedEmail). + Return(&rest.Response{Body: "email body"}, nil).Once() + scope := promutils.NewScope("bademailer") + emailerMetrics := newEmailMetrics(scope) + + emailer := SendgridEmailer{ + client: sendgridClient, + systemMetrics: emailerMetrics, + cfg: &runtimeInterfaces.NotificationsConfig{ + ReconnectAttempts: 1, + }, + } + + err := emailer.SendEmail(ctx, emailNotification) + assert.EqualError(t, err, expectedErr.Error()) + + assert.NoError(t, testutil.CollectAndCompare(emailerMetrics.SendError, strings.NewReader(` + # HELP bademailer:send_error Number of errors when sending email via Emailer + # TYPE bademailer:send_error counter + bademailer:send_error 1 + `))) + }) + t.Run("exhaust all retry attempts", func(t *testing.T) { + ctx := context.TODO() + sendgridClient := &mocks.SendgridClient{} + expectedEmail := getSendgridEmail(emailNotification) + sendgridClient.OnSendMatch(expectedEmail). + Return(nil, expectedErr).Once() + sendgridClient.OnSendMatch(expectedEmail). + Return(&rest.Response{Body: "email body"}, nil).Once() + scope := promutils.NewScope("goodemailer") + emailerMetrics := newEmailMetrics(scope) + + emailer := SendgridEmailer{ + client: sendgridClient, + systemMetrics: emailerMetrics, + cfg: &runtimeInterfaces.NotificationsConfig{ + ReconnectAttempts: 1, + }, + } + + err := emailer.SendEmail(ctx, emailNotification) + assert.NoError(t, err) + + assert.NoError(t, testutil.CollectAndCompare(emailerMetrics.SendError, strings.NewReader(` + # HELP goodemailer:send_error Number of errors when sending email via Emailer + # TYPE goodemailer:send_error counter + goodemailer:send_error 0 + `))) + }) +} diff --git a/flyteadmin/pkg/async/notifications/mocks/sendgrid_client.go b/flyteadmin/pkg/async/notifications/mocks/sendgrid_client.go new file mode 100644 index 0000000000..4d9e260908 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/mocks/sendgrid_client.go @@ -0,0 +1,56 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + mail "github.com/sendgrid/sendgrid-go/helpers/mail" + mock "github.com/stretchr/testify/mock" + + rest "github.com/sendgrid/rest" +) + +// SendgridClient is an autogenerated mock type for the SendgridClient type +type SendgridClient struct { + mock.Mock +} + +type SendgridClient_Send struct { + *mock.Call +} + +func (_m SendgridClient_Send) Return(_a0 *rest.Response, _a1 error) *SendgridClient_Send { + return &SendgridClient_Send{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SendgridClient) OnSend(email *mail.SGMailV3) *SendgridClient_Send { + c_call := _m.On("Send", email) + return &SendgridClient_Send{Call: c_call} +} + +func (_m *SendgridClient) OnSendMatch(matchers ...interface{}) *SendgridClient_Send { + c_call := _m.On("Send", matchers...) + return &SendgridClient_Send{Call: c_call} +} + +// Send provides a mock function with given fields: email +func (_m *SendgridClient) Send(email *mail.SGMailV3) (*rest.Response, error) { + ret := _m.Called(email) + + var r0 *rest.Response + if rf, ok := ret.Get(0).(func(*mail.SGMailV3) *rest.Response); ok { + r0 = rf(email) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rest.Response) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*mail.SGMailV3) error); ok { + r1 = rf(email) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +}