Skip to content

Commit

Permalink
Revert "chore: simplify courier code (#3603)"
Browse files Browse the repository at this point in the history
This reverts commit 316cd4a.
  • Loading branch information
hperl committed Nov 8, 2023
1 parent 8cc83bc commit 7c54c9f
Show file tree
Hide file tree
Showing 14 changed files with 239 additions and 138 deletions.
52 changes: 27 additions & 25 deletions courier/courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,12 @@ func NewCourier(ctx context.Context, deps Dependencies) (Courier, error) {
if err != nil {
return nil, err
}

expBackoff := backoff.NewExponentialBackOff()
// never stop retrying
expBackoff.MaxElapsedTime = 0

return &courier{
smsClient: newSMS(ctx, deps),
smtpClient: smtp,
httpClient: newHTTP(ctx, deps),
deps: deps,
backoff: expBackoff,
backoff: backoff.NewExponentialBackOff(),
}, nil
}

Expand All @@ -84,29 +79,36 @@ func (c *courier) FailOnDispatchError() {
}

func (c *courier) Work(ctx context.Context) error {
wait := c.deps.CourierConfig().CourierWorkerPullWait(ctx)
for {
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return nil
}
return ctx.Err()
case <-time.After(wait):
if err := backoff.Retry(func() error {
if err := c.DispatchQueue(ctx); err != nil {
return err
}
// when we succeed, we want to reset the backoff
c.backoff.Reset()
return nil
}, c.backoff); err != nil {
return err
}
errChan := make(chan error)
defer close(errChan)

go c.watchMessages(ctx, errChan)

select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return nil
}
return ctx.Err()
case err := <-errChan:
return err
}
}

func (c *courier) UseBackoff(b backoff.BackOff) {
c.backoff = b
}

func (c *courier) watchMessages(ctx context.Context, errChan chan error) {
wait := c.deps.CourierConfig().CourierWorkerPullWait(ctx)
c.backoff.Reset()
for {
if err := backoff.Retry(func() error {
return c.DispatchQueue(ctx)
}, c.backoff); err != nil {
errChan <- err
return
}
time.Sleep(wait)
}
}
35 changes: 23 additions & 12 deletions courier/courier_dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,13 @@ func (c *courier) DispatchQueue(ctx context.Context) error {

messages, err := c.deps.CourierPersister().NextMessages(ctx, uint8(pullCount))
if err != nil {
if errors.Is(err, ErrQueueEmpty) {
return nil
}
return err
}

for _, msg := range messages {
for k, msg := range messages {
if msg.SendCount > maxRetries {
if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusAbandoned); err != nil {
c.deps.Logger().
Expand All @@ -77,33 +80,41 @@ func (c *courier) DispatchQueue(ctx context.Context) error {
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Warnf(`Message was abandoned because it did not deliver after %d attempts`, msg.SendCount)
continue
}

if err := c.DispatchMessage(ctx, msg); err != nil {
} else if err := c.DispatchMessage(ctx, msg); err != nil {
if err := c.deps.CourierPersister().RecordDispatch(ctx, msg.ID, CourierMessageDispatchStatusFailed, err); err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Error(`Unable to record failure log entry.`)
return err
if c.failOnDispatchError {
return err
}
}

for _, replace := range messages[k:] {
if err := c.deps.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", replace.ID).
WithField("message_nid", replace.NID).
Error(`Unable to reset the failed message's status to "queued".`)
if c.failOnDispatchError {
return err
}
}
}

if c.failOnDispatchError {
return err
}
// an error happened, but we want to ignore it
continue
}

if err := c.deps.CourierPersister().RecordDispatch(ctx, msg.ID, CourierMessageDispatchStatusSuccess, nil); err != nil {
} else if err := c.deps.CourierPersister().RecordDispatch(ctx, msg.ID, CourierMessageDispatchStatusSuccess, nil); err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Error(`Unable to record success log entry.`)
return err
// continue with execution, as the message was successfully dispatched
}
}

Expand Down
9 changes: 4 additions & 5 deletions courier/courier_dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,15 @@ func TestDispatchMessageWithInvalidSMTP(t *testing.T) {

t.Run("case=failed sending", func(t *testing.T) {
id := queueNewMessage(t, ctx, c, reg)
messages, err := reg.CourierPersister().NextMessages(ctx, 10)
message, err := reg.CourierPersister().LatestQueuedMessage(ctx)
require.NoError(t, err)
require.Len(t, messages, 1)
require.Equal(t, id, messages[0].ID)
require.Equal(t, id, message.ID)

err = c.DispatchMessage(ctx, messages[0])
err = c.DispatchMessage(ctx, *message)
// sending the email fails, because there is no SMTP server at foo.url
require.Error(t, err)

messages, err = reg.CourierPersister().NextMessages(ctx, 10)
messages, err := reg.CourierPersister().NextMessages(ctx, 10)
require.NoError(t, err)
require.Len(t, messages, 1)
})
Expand Down
16 changes: 8 additions & 8 deletions courier/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func TestHandler(t *testing.T) {
t.Run("case=list messages", func(t *testing.T) {
// Arrange test data
const msgCount = 10 // total message count
const sentCount = 5 // how many messages' status should be equal to `processing`
const procCount = 5 // how many messages' status should be equal to `processing`
const rcptOryCount = 2 // how many messages' recipient should be equal to `[email protected]`
messages := make([]courier.Message, msgCount)

Expand All @@ -109,8 +109,8 @@ func TestHandler(t *testing.T) {
}
require.NoError(t, reg.CourierPersister().AddMessage(context.Background(), &messages[i]))
}
for i := 0; i < sentCount; i++ {
require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[i].ID, courier.MessageStatusSent))
for i := 0; i < procCount; i++ {
require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[i].ID, courier.MessageStatusProcessing))
}

t.Run("paging", func(t *testing.T) {
Expand Down Expand Up @@ -146,24 +146,24 @@ func TestHandler(t *testing.T) {
for _, tc := range tss {
t.Run("endpoint="+tc.name, func(t *testing.T) {
parsed := getList(t, tc.name, qs)
assert.Len(t, parsed.Array(), msgCount-sentCount)
assert.Len(t, parsed.Array(), msgCount-procCount)

for _, item := range parsed.Array() {
assert.Equal(t, "queued", item.Get("status").String())
}
})
}
})
t.Run("case=should return all sent messages", func(t *testing.T) {
qs := fmt.Sprintf(`?page_token=%s&page_size=250&status=sent`, defaultPageToken)
t.Run("case=should return all processing messages", func(t *testing.T) {
qs := fmt.Sprintf(`?page_token=%s&page_size=250&status=processing`, defaultPageToken)

for _, tc := range tss {
t.Run("endpoint="+tc.name, func(t *testing.T) {
parsed := getList(t, tc.name, qs)
assert.Len(t, parsed.Array(), sentCount)
assert.Len(t, parsed.Array(), procCount)

for _, item := range parsed.Array() {
assert.Equal(t, "sent", item.Get("status").String())
assert.Equal(t, "processing", item.Get("status").String())
}
})
}
Expand Down
15 changes: 10 additions & 5 deletions courier/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ type MessageStatus int
const (
MessageStatusQueued MessageStatus = iota + 1
MessageStatusSent
_
MessageStatusProcessing
MessageStatusAbandoned
)

const (
messageStatusQueuedText = "queued"
messageStatusSentText = "sent"
messageStatusAbandonedText = "abandoned"
messageStatusQueuedText = "queued"
messageStatusSentText = "sent"
messageStatusProcessingText = "processing"
messageStatusAbandonedText = "abandoned"
)

func ToMessageStatus(str string) (MessageStatus, error) {
Expand All @@ -40,6 +41,8 @@ func ToMessageStatus(str string) (MessageStatus, error) {
return MessageStatusQueued, nil
case s.AddCase(MessageStatusSent.String()):
return MessageStatusSent, nil
case s.AddCase(MessageStatusProcessing.String()):
return MessageStatusProcessing, nil
case s.AddCase(MessageStatusAbandoned.String()):
return MessageStatusAbandoned, nil
default:
Expand All @@ -53,6 +56,8 @@ func (ms MessageStatus) String() string {
return messageStatusQueuedText
case MessageStatusSent:
return messageStatusSentText
case MessageStatusProcessing:
return messageStatusProcessingText
case MessageStatusAbandoned:
return messageStatusAbandonedText
default:
Expand All @@ -62,7 +67,7 @@ func (ms MessageStatus) String() string {

func (ms MessageStatus) IsValid() error {
switch ms {
case MessageStatusQueued, MessageStatusSent, MessageStatusAbandoned:
case MessageStatusQueued, MessageStatusSent, MessageStatusProcessing, MessageStatusAbandoned:
return nil
default:
return errors.WithStack(herodot.ErrBadRequest.WithReason("Message status is not valid"))
Expand Down
7 changes: 4 additions & 3 deletions courier/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ func TestMessageStatusValidity(t *testing.T) {
func TestToMessageStatus(t *testing.T) {
t.Run("case=should return corresponding MessageStatus for given str", func(t *testing.T) {
for str, exp := range map[string]courier.MessageStatus{
"queued": courier.MessageStatusQueued,
"sent": courier.MessageStatusSent,
"abandoned": courier.MessageStatusAbandoned,
"queued": courier.MessageStatusQueued,
"sent": courier.MessageStatusSent,
"processing": courier.MessageStatusProcessing,
"abandoned": courier.MessageStatusAbandoned,
} {
result, err := courier.ToMessageStatus(str)
require.NoError(t, err)
Expand Down
5 changes: 5 additions & 0 deletions courier/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ import (
"context"

"github.com/gofrs/uuid"
"github.com/pkg/errors"

"github.com/ory/x/pagination/keysetpagination"
)

var ErrQueueEmpty = errors.New("queue is empty")

type (
Persister interface {
AddMessage(context.Context, *Message) error
Expand All @@ -19,6 +22,8 @@ type (

SetMessageStatus(context.Context, uuid.UUID, MessageStatus) error

LatestQueuedMessage(ctx context.Context) (*Message, error)

IncrementMessageSendCount(context.Context, uuid.UUID) error

// ListMessages lists all messages in the store given the page, itemsPerPage, status and recipient.
Expand Down
Loading

0 comments on commit 7c54c9f

Please sign in to comment.