Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into hperl/link-credenti…
Browse files Browse the repository at this point in the history
…als-when-login
  • Loading branch information
hperl committed Nov 7, 2023
2 parents 9f32943 + 3b75f37 commit 5b9bda5
Show file tree
Hide file tree
Showing 15 changed files with 140 additions and 241 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

**Table of Contents**

- [ (2023-10-30)](#2023-10-30)
- [ (2023-11-06)](#2023-11-06)
- [Breaking Changes](#breaking-changes)
- [Bug Fixes](#bug-fixes)
- [Documentation](#documentation)
Expand Down Expand Up @@ -313,7 +313,7 @@

<!-- END doctoc generated TOC please keep comment here to allow auto update -->

# [](https://github.com/ory/kratos/compare/v1.0.0...v) (2023-10-30)
# [](https://github.com/ory/kratos/compare/v1.0.0...v) (2023-11-06)

## Breaking Changes

Expand Down
52 changes: 25 additions & 27 deletions courier/courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,17 @@ func NewCourier(ctx context.Context, deps Dependencies) (Courier, error) {
if err != nil {
return nil, err
}

expBackoff := backoff.NewExponentialBackOff()
// never stop retrying
expBackoff.MaxElapsedTime = 0

return &courier{
smsClient: newSMS(ctx, deps),
smtpClient: smtp,
httpClient: newHTTP(ctx, deps),
deps: deps,
backoff: backoff.NewExponentialBackOff(),
backoff: expBackoff,
}, nil
}

Expand All @@ -79,36 +84,29 @@ func (c *courier) FailOnDispatchError() {
}

func (c *courier) Work(ctx context.Context) error {
errChan := make(chan error)
defer close(errChan)

go c.watchMessages(ctx, errChan)

select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return nil
wait := c.deps.CourierConfig().CourierWorkerPullWait(ctx)
for {
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return nil
}
return ctx.Err()
case <-time.After(wait):
if err := backoff.Retry(func() error {
if err := c.DispatchQueue(ctx); err != nil {
return err
}
// when we succeed, we want to reset the backoff
c.backoff.Reset()
return nil
}, c.backoff); err != nil {
return err
}
}
return ctx.Err()
case err := <-errChan:
return err
}
}

func (c *courier) UseBackoff(b backoff.BackOff) {
c.backoff = b
}

func (c *courier) watchMessages(ctx context.Context, errChan chan error) {
wait := c.deps.CourierConfig().CourierWorkerPullWait(ctx)
c.backoff.Reset()
for {
if err := backoff.Retry(func() error {
return c.DispatchQueue(ctx)
}, c.backoff); err != nil {
errChan <- err
return
}
time.Sleep(wait)
}
}
35 changes: 12 additions & 23 deletions courier/courier_dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,10 @@ func (c *courier) DispatchQueue(ctx context.Context) error {

messages, err := c.deps.CourierPersister().NextMessages(ctx, uint8(pullCount))
if err != nil {
if errors.Is(err, ErrQueueEmpty) {
return nil
}
return err
}

for k, msg := range messages {
for _, msg := range messages {
if msg.SendCount > maxRetries {
if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusAbandoned); err != nil {
c.deps.Logger().
Expand All @@ -80,41 +77,33 @@ func (c *courier) DispatchQueue(ctx context.Context) error {
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Warnf(`Message was abandoned because it did not deliver after %d attempts`, msg.SendCount)
} else if err := c.DispatchMessage(ctx, msg); err != nil {
continue
}

if err := c.DispatchMessage(ctx, msg); err != nil {
if err := c.deps.CourierPersister().RecordDispatch(ctx, msg.ID, CourierMessageDispatchStatusFailed, err); err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Error(`Unable to record failure log entry.`)
if c.failOnDispatchError {
return err
}
}

for _, replace := range messages[k:] {
if err := c.deps.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", replace.ID).
WithField("message_nid", replace.NID).
Error(`Unable to reset the failed message's status to "queued".`)
if c.failOnDispatchError {
return err
}
}
return err
}

if c.failOnDispatchError {
return err
}
} else if err := c.deps.CourierPersister().RecordDispatch(ctx, msg.ID, CourierMessageDispatchStatusSuccess, nil); err != nil {
// an error happened, but we want to ignore it
continue
}

if err := c.deps.CourierPersister().RecordDispatch(ctx, msg.ID, CourierMessageDispatchStatusSuccess, nil); err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Error(`Unable to record success log entry.`)
// continue with execution, as the message was successfully dispatched
return err
}
}

Expand Down
9 changes: 5 additions & 4 deletions courier/courier_dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,16 @@ func TestDispatchMessageWithInvalidSMTP(t *testing.T) {

t.Run("case=failed sending", func(t *testing.T) {
id := queueNewMessage(t, ctx, c, reg)
message, err := reg.CourierPersister().LatestQueuedMessage(ctx)
messages, err := reg.CourierPersister().NextMessages(ctx, 10)
require.NoError(t, err)
require.Equal(t, id, message.ID)
require.Len(t, messages, 1)
require.Equal(t, id, messages[0].ID)

err = c.DispatchMessage(ctx, *message)
err = c.DispatchMessage(ctx, messages[0])
// sending the email fails, because there is no SMTP server at foo.url
require.Error(t, err)

messages, err := reg.CourierPersister().NextMessages(ctx, 10)
messages, err = reg.CourierPersister().NextMessages(ctx, 10)
require.NoError(t, err)
require.Len(t, messages, 1)
})
Expand Down
16 changes: 8 additions & 8 deletions courier/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func TestHandler(t *testing.T) {
t.Run("case=list messages", func(t *testing.T) {
// Arrange test data
const msgCount = 10 // total message count
const procCount = 5 // how many messages' status should be equal to `processing`
const sentCount = 5 // how many messages' status should be equal to `processing`
const rcptOryCount = 2 // how many messages' recipient should be equal to `[email protected]`
messages := make([]courier.Message, msgCount)

Expand All @@ -109,8 +109,8 @@ func TestHandler(t *testing.T) {
}
require.NoError(t, reg.CourierPersister().AddMessage(context.Background(), &messages[i]))
}
for i := 0; i < procCount; i++ {
require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[i].ID, courier.MessageStatusProcessing))
for i := 0; i < sentCount; i++ {
require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[i].ID, courier.MessageStatusSent))
}

t.Run("paging", func(t *testing.T) {
Expand Down Expand Up @@ -146,24 +146,24 @@ func TestHandler(t *testing.T) {
for _, tc := range tss {
t.Run("endpoint="+tc.name, func(t *testing.T) {
parsed := getList(t, tc.name, qs)
assert.Len(t, parsed.Array(), msgCount-procCount)
assert.Len(t, parsed.Array(), msgCount-sentCount)

for _, item := range parsed.Array() {
assert.Equal(t, "queued", item.Get("status").String())
}
})
}
})
t.Run("case=should return all processing messages", func(t *testing.T) {
qs := fmt.Sprintf(`?page_token=%s&page_size=250&status=processing`, defaultPageToken)
t.Run("case=should return all sent messages", func(t *testing.T) {
qs := fmt.Sprintf(`?page_token=%s&page_size=250&status=sent`, defaultPageToken)

for _, tc := range tss {
t.Run("endpoint="+tc.name, func(t *testing.T) {
parsed := getList(t, tc.name, qs)
assert.Len(t, parsed.Array(), procCount)
assert.Len(t, parsed.Array(), sentCount)

for _, item := range parsed.Array() {
assert.Equal(t, "processing", item.Get("status").String())
assert.Equal(t, "sent", item.Get("status").String())
}
})
}
Expand Down
15 changes: 5 additions & 10 deletions courier/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@ type MessageStatus int
const (
MessageStatusQueued MessageStatus = iota + 1
MessageStatusSent
MessageStatusProcessing
_
MessageStatusAbandoned
)

const (
messageStatusQueuedText = "queued"
messageStatusSentText = "sent"
messageStatusProcessingText = "processing"
messageStatusAbandonedText = "abandoned"
messageStatusQueuedText = "queued"
messageStatusSentText = "sent"
messageStatusAbandonedText = "abandoned"
)

func ToMessageStatus(str string) (MessageStatus, error) {
Expand All @@ -41,8 +40,6 @@ func ToMessageStatus(str string) (MessageStatus, error) {
return MessageStatusQueued, nil
case s.AddCase(MessageStatusSent.String()):
return MessageStatusSent, nil
case s.AddCase(MessageStatusProcessing.String()):
return MessageStatusProcessing, nil
case s.AddCase(MessageStatusAbandoned.String()):
return MessageStatusAbandoned, nil
default:
Expand All @@ -56,8 +53,6 @@ func (ms MessageStatus) String() string {
return messageStatusQueuedText
case MessageStatusSent:
return messageStatusSentText
case MessageStatusProcessing:
return messageStatusProcessingText
case MessageStatusAbandoned:
return messageStatusAbandonedText
default:
Expand All @@ -67,7 +62,7 @@ func (ms MessageStatus) String() string {

func (ms MessageStatus) IsValid() error {
switch ms {
case MessageStatusQueued, MessageStatusSent, MessageStatusProcessing, MessageStatusAbandoned:
case MessageStatusQueued, MessageStatusSent, MessageStatusAbandoned:
return nil
default:
return errors.WithStack(herodot.ErrBadRequest.WithReason("Message status is not valid"))
Expand Down
7 changes: 3 additions & 4 deletions courier/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ func TestMessageStatusValidity(t *testing.T) {
func TestToMessageStatus(t *testing.T) {
t.Run("case=should return corresponding MessageStatus for given str", func(t *testing.T) {
for str, exp := range map[string]courier.MessageStatus{
"queued": courier.MessageStatusQueued,
"sent": courier.MessageStatusSent,
"processing": courier.MessageStatusProcessing,
"abandoned": courier.MessageStatusAbandoned,
"queued": courier.MessageStatusQueued,
"sent": courier.MessageStatusSent,
"abandoned": courier.MessageStatusAbandoned,
} {
result, err := courier.ToMessageStatus(str)
require.NoError(t, err)
Expand Down
5 changes: 0 additions & 5 deletions courier/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@ import (
"context"

"github.com/gofrs/uuid"
"github.com/pkg/errors"

"github.com/ory/x/pagination/keysetpagination"
)

var ErrQueueEmpty = errors.New("queue is empty")

type (
Persister interface {
AddMessage(context.Context, *Message) error
Expand All @@ -22,8 +19,6 @@ type (

SetMessageStatus(context.Context, uuid.UUID, MessageStatus) error

LatestQueuedMessage(ctx context.Context) (*Message, error)

IncrementMessageSendCount(context.Context, uuid.UUID) error

// ListMessages lists all messages in the store given the page, itemsPerPage, status and recipient.
Expand Down
Loading

0 comments on commit 5b9bda5

Please sign in to comment.