Skip to content

Commit

Permalink
refactor: remove total count from listSessions and improve secondary …
Browse files Browse the repository at this point in the history
…indices (#4173)

This patch changes sorting to improve performance on list session endpoints. It also removes the `x-total-count` header from list responses.

BREAKING CHANGE: The total count header `x-total-count` will no longer be sent in response to `GET /admin/sessions` requests.

Closes ory-corp/cloud#7177
Closes ory-corp/cloud#7175
Closes ory-corp/cloud#7176
  • Loading branch information
aeneasr authored Oct 29, 2024
1 parent 825aec2 commit e24f993
Show file tree
Hide file tree
Showing 20 changed files with 123 additions and 44 deletions.
11 changes: 6 additions & 5 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,11 +761,11 @@ func QueryForCredentials(con *pop.Connection, where ...Where) (map[uuid.UUID](ma
ici := "identity_credential_identifiers"
switch con.Dialect.Name() {
case "cockroach":
ici += "@identity_credential_identifiers_nid_identity_credential_id_idx"
ici += "@identity_credential_identifiers_identity_credential_id_idx"
case "sqlite3":
ici += " INDEXED BY identity_credential_identifiers_nid_identity_credential_id_idx"
ici += " INDEXED BY identity_credential_identifiers_identity_credential_id_idx"
case "mysql":
ici += " USE INDEX(identity_credential_identifiers_nid_identity_credential_id_idx)"
ici += " USE INDEX(identity_credential_identifiers_identity_credential_id_idx)"
default:
// good luck 🤷‍♂️
}
Expand Down Expand Up @@ -930,8 +930,9 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.
// important to normalize the identifier before querying the database.

joins = params.TransformStatement(`
INNER JOIN identity_credentials ic ON ic.identity_id = identities.id
INNER JOIN identity_credential_identifiers ici ON ici.identity_credential_id = ic.id`)
INNER JOIN identity_credentials ic ON ic.identity_id = identities.id AND ic.nid = identities.nid
INNER JOIN identity_credential_identifiers ici ON ici.identity_credential_id = ic.id AND ici.nid = ic.nid
`)

wheres += fmt.Sprintf(`
AND ic.nid = ? AND ici.nid = ?
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
CREATE INDEX IF NOT EXISTS identity_credentials_id_nid_idx ON identity_credentials (id ASC, nid ASC);
CREATE INDEX IF NOT EXISTS identity_credentials_nid_id_idx ON identity_credentials (nid ASC, id ASC);
CREATE INDEX IF NOT EXISTS identity_credentials_nid_identity_id_idx ON identity_credentials (identity_id ASC, nid ASC);

DROP INDEX IF EXISTS identity_credentials_identity_id_idx;
DROP INDEX IF EXISTS identity_credentials_nid_idx;
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
CREATE INDEX identity_credentials_id_nid_idx ON identity_credentials (id ASC, nid ASC);
CREATE INDEX identity_credentials_nid_id_idx ON identity_credentials (nid ASC, id ASC);

DROP INDEX identity_credentials_nid_idx ON identity_credentials;
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE INDEX identity_credentials_nid_idx ON identity_credentials (nid ASC);

DROP INDEX identity_credentials_id_nid_idx ON identity_credentials;
DROP INDEX identity_credentials_nid_id_idx ON identity_credentials;

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
CREATE INDEX IF NOT EXISTS identity_credentials_identity_id_idx ON identity_credentials (identity_id ASC);
CREATE INDEX IF NOT EXISTS identity_credentials_nid_idx ON identity_credentials (nid ASC);

DROP INDEX IF EXISTS identity_credentials_id_nid_idx;
DROP INDEX IF EXISTS identity_credentials_nid_id_idx;
DROP INDEX IF EXISTS identity_credentials_nid_identity_id_idx;
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE INDEX IF NOT EXISTS sessions_nid_id_identity_id_idx ON sessions(nid ASC, identity_id ASC, id ASC);
CREATE INDEX IF NOT EXISTS sessions_id_nid_idx ON sessions(id ASC, nid ASC);
CREATE INDEX IF NOT EXISTS sessions_token_nid_idx ON sessions(nid ASC, token ASC);
CREATE INDEX IF NOT EXISTS sessions_identity_id_nid_sorted_idx ON sessions(identity_id ASC, nid ASC, authenticated_at DESC);
CREATE INDEX IF NOT EXISTS sessions_nid_created_at_id_idx ON sessions(nid ASC, created_at DESC, id ASC);

DROP INDEX IF EXISTS sessions_list_idx;
DROP INDEX IF EXISTS sessions_list_active_idx;
DROP INDEX IF EXISTS sessions_list_identity_idx;
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE INDEX sessions_nid_id_identity_id_idx ON sessions(nid ASC, identity_id ASC, id ASC);
CREATE INDEX sessions_id_nid_idx ON sessions(id ASC, nid ASC);
CREATE INDEX sessions_token_nid_idx ON sessions(nid ASC, token ASC);
CREATE INDEX sessions_identity_id_nid_sorted_idx ON sessions(identity_id ASC, nid ASC, authenticated_at DESC);
CREATE INDEX sessions_nid_created_at_id_idx ON sessions(nid ASC, created_at DESC, id ASC);

DROP INDEX sessions_list_idx ON sessions;
DROP INDEX sessions_list_active_idx ON sessions;
DROP INDEX sessions_list_identity_idx ON sessions;
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE INDEX sessions_list_idx ON sessions (nid ASC, created_at DESC, id ASC);
CREATE INDEX sessions_list_active_idx ON sessions (nid ASC, expires_at ASC, active ASC, created_at DESC, id ASC);
CREATE INDEX sessions_list_identity_idx ON sessions (identity_id ASC, nid ASC, created_at DESC);

DROP INDEX sessions_nid_id_identity_id_idx ON sessions;
DROP INDEX sessions_id_nid_idx ON sessions;
DROP INDEX sessions_token_nid_idx ON sessions;
DROP INDEX sessions_identity_id_nid_sorted_idx ON sessions;
DROP INDEX sessions_nid_created_at_id_idx ON sessions;
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE INDEX IF NOT EXISTS sessions_list_idx ON sessions (nid ASC, created_at DESC, id ASC);
CREATE INDEX IF NOT EXISTS sessions_list_active_idx ON sessions (nid ASC, expires_at ASC, active ASC, created_at DESC, id ASC);
CREATE INDEX IF NOT EXISTS sessions_list_identity_idx ON sessions (identity_id ASC, nid ASC, created_at DESC);

DROP INDEX IF EXISTS sessions_nid_id_identity_id_idx;
DROP INDEX IF EXISTS sessions_id_nid_idx;
DROP INDEX IF EXISTS sessions_token_nid_idx;
DROP INDEX IF EXISTS sessions_identity_id_nid_sorted_idx;
DROP INDEX IF EXISTS sessions_nid_created_at_id_idx;
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- THIS IS COCKROACH ONLY
ALTER INDEX identity_credential_identifiers_identifier_nid_type_uq_idx RENAME TO identity_credential_identifiers_identifier_nid_type_uq_idx_deleteme;
CREATE UNIQUE INDEX IF NOT EXISTS identity_credential_identifiers_identifier_nid_type_uq_idx ON identity_credential_identifiers(nid ASC, identity_credential_type_id ASC, identifier ASC);
DROP INDEX IF EXISTS identity_credential_identifiers_identifier_nid_type_uq_idx_deleteme;
--

CREATE INDEX IF NOT EXISTS identity_credential_identifiers_nid_id_idx ON identity_credential_identifiers (nid ASC, id ASC);
CREATE INDEX IF NOT EXISTS identity_credential_identifiers_id_nid_idx ON identity_credential_identifiers (id ASC, nid ASC);
CREATE INDEX IF NOT EXISTS identity_credential_identifiers_nid_identity_credential_id_idx ON identity_credential_identifiers (identity_credential_id ASC, nid ASC);

DROP INDEX IF EXISTS identity_credential_identifiers_identity_credential_id_idx;
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- THIS IS COCKROACH ONLY
ALTER INDEX identity_credential_identifiers_identifier_nid_type_uq_idx RENAME TO identity_credential_identifiers_identifier_nid_type_uq_idx_deleteme;
CREATE UNIQUE INDEX IF NOT EXISTS identity_credential_identifiers_identifier_nid_type_uq_idx ON identity_credential_identifiers (nid ASC, identity_credential_type_id ASC, identifier ASC) STORING (identity_credential_id);
DROP INDEX IF EXISTS identity_credential_identifiers_identifier_nid_type_uq_idx_deleteme;
--

CREATE INDEX IF NOT EXISTS identity_credential_identifiers_identity_credential_id_idx ON identity_credential_identifiers (identity_credential_id ASC);

DROP INDEX IF EXISTS identity_credential_identifiers_nid_id_idx;
DROP INDEX IF EXISTS identity_credential_identifiers_id_nid_idx;
DROP INDEX IF EXISTS identity_credential_identifiers_nid_identity_credential_id_idx;
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE INDEX IF NOT EXISTS identity_credential_identifiers_nid_id_idx ON identity_credential_identifiers (nid ASC, id ASC);
CREATE INDEX IF NOT EXISTS identity_credential_identifiers_id_nid_idx ON identity_credential_identifiers (id ASC, nid ASC);
CREATE INDEX IF NOT EXISTS identity_credential_identifiers_nid_identity_credential_id_idx ON identity_credential_identifiers (identity_credential_id ASC, nid ASC);

DROP INDEX IF EXISTS identity_credential_identifiers_identity_credential_id_idx;
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE INDEX identity_credential_identifiers_nid_id_idx ON identity_credential_identifiers (nid ASC, id ASC);
CREATE INDEX identity_credential_identifiers_id_nid_idx ON identity_credential_identifiers (id ASC, nid ASC);
CREATE INDEX identity_credential_identifiers_nid_identity_credential_id_idx ON identity_credential_identifiers (identity_credential_id ASC, nid ASC);

DROP INDEX identity_credential_identifiers_identity_credential_id_idx ON identity_credential_identifiers;
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE INDEX identity_credential_identifiers_identity_credential_id_idx ON identity_credential_identifiers (identity_credential_id ASC);

DROP INDEX identity_credential_identifiers_nid_id_idx ON identity_credential_identifiers;
DROP INDEX identity_credential_identifiers_id_nid_idx ON identity_credential_identifiers;
DROP INDEX identity_credential_identifiers_nid_identity_credential_id_idx ON identity_credential_identifiers;
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE INDEX IF NOT EXISTS identity_credential_identifiers_identity_credential_id_idx ON identity_credential_identifiers (identity_credential_id ASC);

DROP INDEX IF EXISTS identity_credential_identifiers_nid_id_idx;
DROP INDEX IF EXISTS identity_credential_identifiers_id_nid_idx;
DROP INDEX IF EXISTS identity_credential_identifiers_nid_identity_credential_id_idx;
20 changes: 6 additions & 14 deletions persistence/sql/persister_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,11 @@ func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables s
return &s, nil
}

func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpts []keysetpagination.Option, expandables session.Expandables) (_ []session.Session, _ int64, _ *keysetpagination.Paginator, err error) {
func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpts []keysetpagination.Option, expandables session.Expandables) (_ []session.Session, _ *keysetpagination.Paginator, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListSessions")
defer otelx.End(span, &err)

s := make([]session.Session, 0)
t := int64(0)
nid := p.NetworkID(ctx)

paginatorOpts = append(paginatorOpts, keysetpagination.WithDefaultSize(paginationDefaultItemsSize))
Expand All @@ -84,7 +83,7 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt
paginator := keysetpagination.GetPaginator(paginatorOpts...)

if _, err := uuid.FromString(paginator.Token().Parse("id")["id"]); err != nil {
return nil, 0, nil, errors.WithStack(x.PageTokenInvalid)
return nil, nil, errors.WithStack(x.PageTokenInvalid)
}

if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
Expand All @@ -97,13 +96,6 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt
}
}

// Get the total count of matching items
total, err := q.Count(new(session.Session))
if err != nil {
return sqlcon.HandleError(err)
}
t = int64(total)

if len(expandables) > 0 {
q = q.EagerPreload(expandables.ToEager()...)
}
Expand All @@ -115,20 +107,20 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt

return nil
}); err != nil {
return nil, 0, nil, err
return nil, nil, err
}

for k := range s {
if s[k].Identity == nil {
continue
}
if err := p.InjectTraitsSchemaURL(ctx, s[k].Identity); err != nil {
return nil, 0, nil, err
return nil, nil, err
}
}

s, nextPage := keysetpagination.Result(s, paginator)
return s, t, nextPage, nil
return s, nextPage, nil
}

// ListSessionsByIdentity retrieves sessions for an identity from the store.
Expand Down Expand Up @@ -171,7 +163,7 @@ func (p *Persister) ListSessionsByIdentity(
}
t = int64(total)

q.Order("authenticated_at DESC")
q.Order("created_at DESC")

// Get the paginated list of matching items
if err := q.Paginate(page, perPage).All(&s); err != nil {
Expand Down
3 changes: 1 addition & 2 deletions session/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,13 +414,12 @@ func (h *Handler) adminListSessions(w http.ResponseWriter, r *http.Request, ps h
}
}

sess, total, nextPage, err := h.r.SessionPersister().ListSessions(r.Context(), active, opts, expandables)
sess, nextPage, err := h.r.SessionPersister().ListSessions(r.Context(), active, opts, expandables)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
}

w.Header().Set("x-total-count", fmt.Sprint(total))
u := *r.URL
keysetpagination.Header(w, &u, nextPage)
h.r.Writer().Write(w, r, sess)
Expand Down
20 changes: 6 additions & 14 deletions session/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,6 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "1", res.Header.Get("X-Total-Count"))

assertPageToken(t, uuid.Nil.String(), res.Header.Get("Link"))

Expand Down Expand Up @@ -639,7 +638,6 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "1", res.Header.Get("X-Total-Count"))
assertPageToken(t, uuid.Nil.String(), res.Header.Get("Link"))

body := ioutilx.MustReadAll(res.Body)
Expand Down Expand Up @@ -812,7 +810,7 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
t.Run(fmt.Sprintf("active=%#v", tc.activeOnly), func(t *testing.T) {
sessions, _, _ := reg.SessionPersister().ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil, ExpandEverything)
require.Equal(t, 5, len(sessions))
assert.True(t, sort.IsSorted(sort.Reverse(byAuthenticatedAt(sessions))))
assert.True(t, sort.IsSorted(sort.Reverse(byCreatedAt(sessions))))

reqURL := ts.URL + "/admin/identities/" + i.ID.String() + "/sessions"
if tc.activeOnly != "" {
Expand All @@ -830,9 +828,6 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
actualSessionIds = append(actualSessionIds, s.ID)
}

totalCount, err := strconv.Atoi(res.Header.Get("X-Total-Count"))
require.NoError(t, err)
assert.Equal(t, len(tc.expectedSessionIds), totalCount)
assert.NotEqual(t, "", res.Header.Get("Link"))
assert.ElementsMatch(t, tc.expectedSessionIds, actualSessionIds)
})
Expand Down Expand Up @@ -895,9 +890,6 @@ func TestHandlerSelfServiceSessionManagement(t *testing.T) {
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)

totalCount, err := strconv.Atoi(res.Header.Get("X-Total-Count"))
require.NoError(t, err)
require.Equal(t, numSessionsActive, totalCount)
require.NotEqual(t, "", res.Header.Get("Link"))
})

Expand Down Expand Up @@ -1088,10 +1080,10 @@ func TestHandlerRefreshSessionBySessionID(t *testing.T) {
})
}

type byAuthenticatedAt []Session
type byCreatedAt []Session

func (s byAuthenticatedAt) Len() int { return len(s) }
func (s byAuthenticatedAt) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s byAuthenticatedAt) Less(i, j int) bool {
return s[i].AuthenticatedAt.Before(s[j].AuthenticatedAt)
func (s byCreatedAt) Len() int { return len(s) }
func (s byCreatedAt) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s byCreatedAt) Less(i, j int) bool {
return s[i].CreatedAt.Before(s[j].CreatedAt)
}
2 changes: 1 addition & 1 deletion session/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Persister interface {
GetSession(ctx context.Context, sid uuid.UUID, expandables Expandables) (*Session, error)

// ListSessions retrieves all sessions.
ListSessions(ctx context.Context, active *bool, paginatorOpts []keysetpagination.Option, expandables Expandables) ([]Session, int64, *keysetpagination.Paginator, error)
ListSessions(ctx context.Context, active *bool, paginatorOpts []keysetpagination.Option, expandables Expandables) ([]Session, *keysetpagination.Paginator, error)

// ListSessionsByIdentity retrieves sessions for an identity from the store.
ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID, expandables Expandables) ([]Session, int64, error)
Expand Down
12 changes: 4 additions & 8 deletions session/test/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
} {
t.Run("case=all "+tc.desc, func(t *testing.T) {
paginatorOpts := make([]keysetpagination.Option, 0)
actual, total, nextPage, err := l.ListSessions(ctx, tc.active, paginatorOpts, session.ExpandEverything)
actual, nextPage, err := l.ListSessions(ctx, tc.active, paginatorOpts, session.ExpandEverything)
require.NoError(t, err, "%+v", err)

require.Equal(t, len(tc.expected), len(actual))
require.Equal(t, int64(len(tc.expected)), total)
assert.Equal(t, true, nextPage.IsLast())

mapPageToken := nextPage.Token().Parse("")
Expand All @@ -322,11 +321,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {

t.Run("case=all sessions pagination only one page", func(t *testing.T) {
paginatorOpts := make([]keysetpagination.Option, 0)
actual, total, page, err := l.ListSessions(ctx, nil, paginatorOpts, session.ExpandEverything)
actual, page, err := l.ListSessions(ctx, nil, paginatorOpts, session.ExpandEverything)
require.NoError(t, err)

require.Equal(t, 6, len(actual))
require.Equal(t, int64(6), total)
assert.Equal(t, true, page.IsLast())
mapPageToken := page.Token().Parse("")
assert.Equal(t, uuid.Nil.String(), mapPageToken["id"])
Expand All @@ -336,9 +334,8 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
t.Run("case=all sessions pagination multiple pages", func(t *testing.T) {
paginatorOpts := make([]keysetpagination.Option, 0)
paginatorOpts = append(paginatorOpts, keysetpagination.WithSize(3))
firstPageItems, total, page1, err := l.ListSessions(ctx, nil, paginatorOpts, session.ExpandEverything)
firstPageItems, page1, err := l.ListSessions(ctx, nil, paginatorOpts, session.ExpandEverything)
require.NoError(t, err)
require.Equal(t, int64(6), total)
assert.Len(t, firstPageItems, 3)

assert.Equal(t, false, page1.IsLast())
Expand All @@ -347,9 +344,8 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
assert.Equal(t, 3, page1.Size())

// Validate secondPageItems page
secondPageItems, total, page2, err := l.ListSessions(ctx, nil, page1.ToOptions(), session.ExpandEverything)
secondPageItems, page2, err := l.ListSessions(ctx, nil, page1.ToOptions(), session.ExpandEverything)
require.NoError(t, err)
require.Equal(t, int64(6), total)
assert.Len(t, secondPageItems, 3)

acutalIDs := make([]uuid.UUID, 0)
Expand Down

0 comments on commit e24f993

Please sign in to comment.