Skip to content

Commit

Permalink
feat: order sessions by created_at
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-jonas committed Jan 19, 2024
1 parent 55560a1 commit 6460178
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 22 deletions.
8 changes: 3 additions & 5 deletions courier/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/ory/herodot"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/x"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/sqlxx"
"github.com/ory/x/stringsx"
Expand Down Expand Up @@ -115,9 +116,6 @@ const (
messageTypeSMSText = "sms"
)

// The format we need to use in the Page tokens, as it's the only format that is understood by all DBs
const dbFormat = "2006-01-02 15:04:05.99999"

func ToMessageType(str string) (MessageType, error) {
switch s := stringsx.SwitchExact(str); {
case s.AddCase(messageTypeEmailText):
Expand Down Expand Up @@ -211,14 +209,14 @@ type Message struct {
func (m Message) PageToken() keysetpagination.PageToken {
return keysetpagination.MapPageToken{
"id": m.ID.String(),
"created_at": m.CreatedAt.Format(dbFormat),
"created_at": m.CreatedAt.Format(x.MapPaginationDateFormat),
}
}

func (m Message) DefaultPageToken() keysetpagination.PageToken {
return keysetpagination.MapPageToken{
"id": uuid.Nil.String(),
"created_at": time.Date(2200, 12, 31, 23, 59, 59, 0, time.UTC).Format(dbFormat),
"created_at": time.Date(2200, 12, 31, 23, 59, 59, 0, time.UTC).Format(x.MapPaginationDateFormat),
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP INDEX sessions_nid_created_at_id_idx;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP INDEX sessions_nid_created_at_id_idx ON sessions;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE INDEX sessions_nid_created_at_id_idx ON sessions (nid, created_at DESC, id ASC);
11 changes: 7 additions & 4 deletions persistence/sql/persister_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ import (

var _ session.Persister = new(Persister)

const SessionDeviceUserAgentMaxLength = 512
const SessionDeviceLocationMaxLength = 512
const paginationMaxItemsSize = 1000
const paginationDefaultItemsSize = 250
const (
SessionDeviceUserAgentMaxLength = 512
SessionDeviceLocationMaxLength = 512
paginationMaxItemsSize = 1000
paginationDefaultItemsSize = 250
)

func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables session.Expandables) (_ *session.Session, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetSession")
Expand Down Expand Up @@ -73,6 +75,7 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt
paginatorOpts = append(paginatorOpts, keysetpagination.WithDefaultSize(paginationDefaultItemsSize))
paginatorOpts = append(paginatorOpts, keysetpagination.WithMaxSize(paginationMaxItemsSize))
paginatorOpts = append(paginatorOpts, keysetpagination.WithDefaultToken(new(session.Session).DefaultPageToken()))
paginatorOpts = append(paginatorOpts, keysetpagination.WithColumn("created_at", "DESC"))
paginator := keysetpagination.GetPaginator(paginatorOpts...)

if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
Expand Down
2 changes: 1 addition & 1 deletion session/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ func (h *Handler) adminListSessions(w http.ResponseWriter, r *http.Request, ps h
}

// Parse request pagination parameters
opts, err := keysetpagination.Parse(r.URL.Query(), keysetpagination.NewStringPageToken)
opts, err := keysetpagination.Parse(r.URL.Query(), keysetpagination.NewMapPageToken)
if err != nil {
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError("could not parse parameter page_size"))
return
Expand Down
21 changes: 19 additions & 2 deletions session/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"sort"
"strconv"
"strings"
"testing"
"time"

"github.com/go-faker/faker/v4"
"github.com/peterhellberg/link"
"github.com/tidwall/gjson"

"github.com/ory/kratos/identity"
Expand All @@ -26,6 +28,7 @@ import (
"github.com/pkg/errors"

"github.com/ory/kratos/corpx"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/sqlcon"

"github.com/julienschmidt/httprouter"
Expand Down Expand Up @@ -557,13 +560,27 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
require.Equal(t, ts.URL+"/sessions/whoami", res.Header.Get("Location"))
})

assertPageToken := func(t *testing.T, id, linkHeader string) {
t.Helper()

g := link.Parse(linkHeader)
require.Len(t, g, 1)
u, err := url.Parse(g["first"].URI)
require.NoError(t, err)
pt, err := keysetpagination.NewMapPageToken(u.Query().Get("page_token"))
require.NoError(t, err)
mpt := pt.(keysetpagination.MapPageToken)
assert.Equal(t, id, mpt["id"])
}

t.Run("list sessions", func(t *testing.T) {
req, _ := http.NewRequest("GET", ts.URL+"/admin/sessions/", nil)
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "1", res.Header.Get("X-Total-Count"))
assert.Equal(t, "</admin/sessions?page_size=250&page_token=00000000-0000-0000-0000-000000000000>; rel=\"first\"", res.Header.Get("Link"))

assertPageToken(t, uuid.Nil.String(), res.Header.Get("Link"))

var sessions []Session
require.NoError(t, json.NewDecoder(res.Body).Decode(&sessions))
Expand Down Expand Up @@ -611,7 +628,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "1", res.Header.Get("X-Total-Count"))
assert.Equal(t, "</admin/sessions?"+tc.expand+"page_size=250&page_token=00000000-0000-0000-0000-000000000000>; rel=\"first\"", res.Header.Get("Link"))
assertPageToken(t, uuid.Nil.String(), res.Header.Get("Link"))

body := ioutilx.MustReadAll(res.Body)
assert.Equal(t, s.ID.String(), gjson.GetBytes(body, "0.id").String())
Expand Down
12 changes: 9 additions & 3 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,17 @@ type Session struct {
}

func (s Session) PageToken() keysetpagination.PageToken {
return keysetpagination.StringPageToken(s.ID.String())
return keysetpagination.MapPageToken{
"id": s.ID.String(),
"created_at": s.CreatedAt.Format(x.MapPaginationDateFormat),
}
}

func (s Session) DefaultPageToken() keysetpagination.PageToken {
return keysetpagination.StringPageToken(uuid.Nil.String())
func (m Session) DefaultPageToken() keysetpagination.PageToken {
return keysetpagination.MapPageToken{
"id": uuid.Nil.String(),
"created_at": time.Date(2200, 12, 31, 23, 59, 59, 0, time.UTC).Format(x.MapPaginationDateFormat),
}
}

func (s Session) TableName(ctx context.Context) string {
Expand Down
29 changes: 22 additions & 7 deletions session/test/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"testing"
"time"

"github.com/gobuffalo/pop/v6"

"github.com/ory/x/pagination/keysetpagination"

"github.com/ory/x/pointerx"
Expand All @@ -30,7 +32,8 @@ import (

func TestPersister(ctx context.Context, conf *config.Config, p interface {
persistence.Persister
}) func(t *testing.T) {
},
) func(t *testing.T) {
return func(t *testing.T) {
_, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p)

Expand Down Expand Up @@ -149,6 +152,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {

seedSessionIDs := make([]uuid.UUID, 5)
seedSessionsList := make([]session.Session, 5)
start := time.Now()
for j := range seedSessionsList {
require.NoError(t, faker.FakeData(&seedSessionsList[j]))
seedSessionsList[j].Identity = &identity1
Expand All @@ -165,9 +169,13 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
seedSessionsList[j].Devices = []session.Device{
device,
}
pop.SetNowFunc(func() time.Time {
return start.Add(time.Duration(j) * time.Minute)
})
require.NoError(t, l.UpsertSession(ctx, &seedSessionsList[j]))
seedSessionIDs[j] = seedSessionsList[j].ID
}
pop.SetNowFunc(time.Now)

identity2Session.Identity = &identity2
identity2Session.Active = true
Expand Down Expand Up @@ -288,7 +296,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
require.Equal(t, len(tc.expected), len(actual))
require.Equal(t, int64(len(tc.expected)), total)
assert.Equal(t, true, nextPage.IsLast())
assert.Equal(t, uuid.Nil.String(), nextPage.Token().Encode())

mapPageToken := nextPage.Token().Parse("")
assert.Equal(t, uuid.Nil.String(), mapPageToken["id"])

assert.Equal(t, 250, nextPage.Size())
for _, es := range tc.expected {
found := false
Expand All @@ -312,7 +323,8 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
require.Equal(t, 6, len(actual))
require.Equal(t, int64(6), total)
assert.Equal(t, true, page.IsLast())
assert.Equal(t, uuid.Nil.String(), page.Token().Encode())
mapPageToken := page.Token().Parse("")
assert.Equal(t, uuid.Nil.String(), mapPageToken["id"])
assert.Equal(t, 250, page.Size())
})

Expand All @@ -325,21 +337,24 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
assert.Len(t, firstPageItems, 3)

assert.Equal(t, false, page1.IsLast())
assert.Equal(t, firstPageItems[len(firstPageItems)-1].ID.String(), page1.Token().Encode())
mapPageToken := page1.Token().Parse("")
assert.Equal(t, firstPageItems[len(firstPageItems)-1].ID.String(), mapPageToken["id"])
assert.Equal(t, 3, page1.Size())

// Validate secondPageItems page
secondPageItems, total, page2, err := l.ListSessions(ctx, nil, page1.ToOptions(), session.ExpandEverything)
require.NoError(t, err)
require.Equal(t, int64(6), total)
assert.Len(t, secondPageItems, 3)

acutalIDs := make([]uuid.UUID, 0)
for _, s := range append(firstPageItems, secondPageItems...) {
acutalIDs = append(acutalIDs, s.ID)
}
assert.ElementsMatch(t, append(seedSessionIDs, identity2Session.ID), acutalIDs)
expect := append(seedSessionIDs, identity2Session.ID)
require.Len(t, acutalIDs, len(expect))
assert.ElementsMatch(t, expect, acutalIDs)

require.Equal(t, int64(6), total)
assert.Len(t, secondPageItems, 3)
assert.True(t, page2.IsLast())
assert.Equal(t, 3, page2.Size())
})
Expand Down
3 changes: 3 additions & 0 deletions x/pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import (
"github.com/ory/x/pagination/pagepagination"
)

// The format we need to use in the Page tokens, as it's the only format that is understood by all DBs
const MapPaginationDateFormat = "2006-01-02 15:04:05.99999"

// ParsePagination parses limit and page from *http.Request with given limits and defaults.
func ParsePagination(r *http.Request) (page, itemsPerPage int) {
return migrationpagination.NewDefaultPaginator().ParsePagination(r)
Expand Down

0 comments on commit 6460178

Please sign in to comment.