Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: order sessions by created_at #3696

Merged
merged 1 commit into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions courier/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/ory/herodot"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/x"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/sqlxx"
"github.com/ory/x/stringsx"
Expand Down Expand Up @@ -115,9 +116,6 @@ const (
messageTypeSMSText = "sms"
)

// The format we need to use in the Page tokens, as it's the only format that is understood by all DBs
const dbFormat = "2006-01-02 15:04:05.99999"

func ToMessageType(str string) (MessageType, error) {
switch s := stringsx.SwitchExact(str); {
case s.AddCase(messageTypeEmailText):
Expand Down Expand Up @@ -211,14 +209,14 @@ type Message struct {
func (m Message) PageToken() keysetpagination.PageToken {
return keysetpagination.MapPageToken{
"id": m.ID.String(),
"created_at": m.CreatedAt.Format(dbFormat),
"created_at": m.CreatedAt.Format(x.MapPaginationDateFormat),
}
}

func (m Message) DefaultPageToken() keysetpagination.PageToken {
return keysetpagination.MapPageToken{
"id": uuid.Nil.String(),
"created_at": time.Date(2200, 12, 31, 23, 59, 59, 0, time.UTC).Format(dbFormat),
"created_at": time.Date(2200, 12, 31, 23, 59, 59, 0, time.UTC).Format(x.MapPaginationDateFormat),
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP INDEX sessions_nid_created_at_id_idx;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP INDEX sessions_nid_created_at_id_idx ON sessions;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE INDEX sessions_nid_created_at_id_idx ON sessions (nid, created_at DESC, id ASC);
11 changes: 7 additions & 4 deletions persistence/sql/persister_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ import (

var _ session.Persister = new(Persister)

const SessionDeviceUserAgentMaxLength = 512
const SessionDeviceLocationMaxLength = 512
const paginationMaxItemsSize = 1000
const paginationDefaultItemsSize = 250
const (
SessionDeviceUserAgentMaxLength = 512
SessionDeviceLocationMaxLength = 512
paginationMaxItemsSize = 1000
paginationDefaultItemsSize = 250
)

func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables session.Expandables) (_ *session.Session, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetSession")
Expand Down Expand Up @@ -73,6 +75,7 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt
paginatorOpts = append(paginatorOpts, keysetpagination.WithDefaultSize(paginationDefaultItemsSize))
paginatorOpts = append(paginatorOpts, keysetpagination.WithMaxSize(paginationMaxItemsSize))
paginatorOpts = append(paginatorOpts, keysetpagination.WithDefaultToken(new(session.Session).DefaultPageToken()))
paginatorOpts = append(paginatorOpts, keysetpagination.WithColumn("created_at", "DESC"))
paginator := keysetpagination.GetPaginator(paginatorOpts...)

if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
Expand Down
2 changes: 1 addition & 1 deletion session/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ func (h *Handler) adminListSessions(w http.ResponseWriter, r *http.Request, ps h
}

// Parse request pagination parameters
opts, err := keysetpagination.Parse(r.URL.Query(), keysetpagination.NewStringPageToken)
opts, err := keysetpagination.Parse(r.URL.Query(), keysetpagination.NewMapPageToken)
if err != nil {
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithError("could not parse parameter page_size"))
return
Expand Down
21 changes: 19 additions & 2 deletions session/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"sort"
"strconv"
"strings"
"testing"
"time"

"github.com/go-faker/faker/v4"
"github.com/peterhellberg/link"
"github.com/tidwall/gjson"

"github.com/ory/kratos/identity"
Expand All @@ -26,6 +28,7 @@ import (
"github.com/pkg/errors"

"github.com/ory/kratos/corpx"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/sqlcon"

"github.com/julienschmidt/httprouter"
Expand Down Expand Up @@ -557,13 +560,27 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
require.Equal(t, ts.URL+"/sessions/whoami", res.Header.Get("Location"))
})

assertPageToken := func(t *testing.T, id, linkHeader string) {
t.Helper()

g := link.Parse(linkHeader)
require.Len(t, g, 1)
u, err := url.Parse(g["first"].URI)
require.NoError(t, err)
pt, err := keysetpagination.NewMapPageToken(u.Query().Get("page_token"))
require.NoError(t, err)
mpt := pt.(keysetpagination.MapPageToken)
assert.Equal(t, id, mpt["id"])
}

t.Run("list sessions", func(t *testing.T) {
req, _ := http.NewRequest("GET", ts.URL+"/admin/sessions/", nil)
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "1", res.Header.Get("X-Total-Count"))
assert.Equal(t, "</admin/sessions?page_size=250&page_token=00000000-0000-0000-0000-000000000000>; rel=\"first\"", res.Header.Get("Link"))

assertPageToken(t, uuid.Nil.String(), res.Header.Get("Link"))

var sessions []Session
require.NoError(t, json.NewDecoder(res.Body).Decode(&sessions))
Expand Down Expand Up @@ -611,7 +628,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "1", res.Header.Get("X-Total-Count"))
assert.Equal(t, "</admin/sessions?"+tc.expand+"page_size=250&page_token=00000000-0000-0000-0000-000000000000>; rel=\"first\"", res.Header.Get("Link"))
assertPageToken(t, uuid.Nil.String(), res.Header.Get("Link"))

body := ioutilx.MustReadAll(res.Body)
assert.Equal(t, s.ID.String(), gjson.GetBytes(body, "0.id").String())
Expand Down
12 changes: 9 additions & 3 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,17 @@ type Session struct {
}

func (s Session) PageToken() keysetpagination.PageToken {
return keysetpagination.StringPageToken(s.ID.String())
return keysetpagination.MapPageToken{
"id": s.ID.String(),
"created_at": s.CreatedAt.Format(x.MapPaginationDateFormat),
}
}

func (s Session) DefaultPageToken() keysetpagination.PageToken {
return keysetpagination.StringPageToken(uuid.Nil.String())
func (m Session) DefaultPageToken() keysetpagination.PageToken {
return keysetpagination.MapPageToken{
"id": uuid.Nil.String(),
"created_at": time.Date(2200, 12, 31, 23, 59, 59, 0, time.UTC).Format(x.MapPaginationDateFormat),
}
}

func (s Session) TableName(ctx context.Context) string {
Expand Down
29 changes: 22 additions & 7 deletions session/test/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"testing"
"time"

"github.com/gobuffalo/pop/v6"

"github.com/ory/x/pagination/keysetpagination"

"github.com/ory/x/pointerx"
Expand All @@ -30,7 +32,8 @@ import (

func TestPersister(ctx context.Context, conf *config.Config, p interface {
persistence.Persister
}) func(t *testing.T) {
},
) func(t *testing.T) {
return func(t *testing.T) {
_, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p)

Expand Down Expand Up @@ -149,6 +152,7 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {

seedSessionIDs := make([]uuid.UUID, 5)
seedSessionsList := make([]session.Session, 5)
start := time.Now()
for j := range seedSessionsList {
require.NoError(t, faker.FakeData(&seedSessionsList[j]))
seedSessionsList[j].Identity = &identity1
Expand All @@ -165,9 +169,13 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
seedSessionsList[j].Devices = []session.Device{
device,
}
pop.SetNowFunc(func() time.Time {
return start.Add(time.Duration(j) * time.Minute)
})
require.NoError(t, l.UpsertSession(ctx, &seedSessionsList[j]))
seedSessionIDs[j] = seedSessionsList[j].ID
}
pop.SetNowFunc(time.Now)

identity2Session.Identity = &identity2
identity2Session.Active = true
Expand Down Expand Up @@ -288,7 +296,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
require.Equal(t, len(tc.expected), len(actual))
require.Equal(t, int64(len(tc.expected)), total)
assert.Equal(t, true, nextPage.IsLast())
assert.Equal(t, uuid.Nil.String(), nextPage.Token().Encode())

mapPageToken := nextPage.Token().Parse("")
assert.Equal(t, uuid.Nil.String(), mapPageToken["id"])

assert.Equal(t, 250, nextPage.Size())
for _, es := range tc.expected {
found := false
Expand All @@ -312,7 +323,8 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
require.Equal(t, 6, len(actual))
require.Equal(t, int64(6), total)
assert.Equal(t, true, page.IsLast())
assert.Equal(t, uuid.Nil.String(), page.Token().Encode())
mapPageToken := page.Token().Parse("")
assert.Equal(t, uuid.Nil.String(), mapPageToken["id"])
assert.Equal(t, 250, page.Size())
})

Expand All @@ -325,21 +337,24 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
assert.Len(t, firstPageItems, 3)

assert.Equal(t, false, page1.IsLast())
assert.Equal(t, firstPageItems[len(firstPageItems)-1].ID.String(), page1.Token().Encode())
mapPageToken := page1.Token().Parse("")
assert.Equal(t, firstPageItems[len(firstPageItems)-1].ID.String(), mapPageToken["id"])
assert.Equal(t, 3, page1.Size())

// Validate secondPageItems page
secondPageItems, total, page2, err := l.ListSessions(ctx, nil, page1.ToOptions(), session.ExpandEverything)
require.NoError(t, err)
require.Equal(t, int64(6), total)
assert.Len(t, secondPageItems, 3)

acutalIDs := make([]uuid.UUID, 0)
for _, s := range append(firstPageItems, secondPageItems...) {
acutalIDs = append(acutalIDs, s.ID)
}
assert.ElementsMatch(t, append(seedSessionIDs, identity2Session.ID), acutalIDs)
expect := append(seedSessionIDs, identity2Session.ID)
require.Len(t, acutalIDs, len(expect))
assert.ElementsMatch(t, expect, acutalIDs)

require.Equal(t, int64(6), total)
assert.Len(t, secondPageItems, 3)
assert.True(t, page2.IsLast())
assert.Equal(t, 3, page2.Size())
})
Expand Down
3 changes: 3 additions & 0 deletions x/pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import (
"github.com/ory/x/pagination/pagepagination"
)

// The format we need to use in the Page tokens, as it's the only format that is understood by all DBs
const MapPaginationDateFormat = "2006-01-02 15:04:05.99999"

// ParsePagination parses limit and page from *http.Request with given limits and defaults.
func ParsePagination(r *http.Request) (page, itemsPerPage int) {
return migrationpagination.NewDefaultPaginator().ParsePagination(r)
Expand Down
Loading