diff --git a/backend/dto/session.go b/backend/dto/session.go index 52859d090..e2b275e34 100644 --- a/backend/dto/session.go +++ b/backend/dto/session.go @@ -1,8 +1,10 @@ package dto import ( + "encoding/json" "fmt" "github.com/gofrs/uuid" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/mileusna/useragent" "github.com/teamhanko/hanko/backend/persistence/models" "time" @@ -44,10 +46,72 @@ func FromSessionModel(model models.Session, current bool) SessionData { return sessionData } +type Claims struct { + Subject uuid.UUID `json:"subject"` + IssuedAt *time.Time `json:"issued_at,omitempty"` + Expiration time.Time `json:"expiration"` + Audience []string `json:"audience,omitempty"` + Issuer *string `json:"issuer,omitempty"` + Email *EmailJwt `json:"email,omitempty"` + SessionID uuid.UUID `json:"session_id"` +} + type ValidateSessionResponse struct { - IsValid bool `json:"is_valid"` + IsValid bool `json:"is_valid"` + Claims *Claims `json:"claims,omitempty"` + // deprecated ExpirationTime *time.Time `json:"expiration_time,omitempty"` - UserID *uuid.UUID `json:"user_id,omitempty"` + // deprecated + UserID *uuid.UUID `json:"user_id,omitempty"` +} + +func GetClaimsFromToken(token jwt.Token) (*Claims, error) { + claims := &Claims{} + + if subject := token.Subject(); len(subject) > 0 { + s, err := uuid.FromString(subject) + if err != nil { + return nil, fmt.Errorf("'subject' is not a uuid: %w", err) + } + claims.Subject = s + } + + if sessionID, valid := token.Get("session_id"); valid { + s, err := uuid.FromString(sessionID.(string)) + if err != nil { + return nil, fmt.Errorf("'session_id' is not a uuid: %w", err) + } + claims.SessionID = s + } + + if issuedAt := token.IssuedAt(); !issuedAt.IsZero() { + claims.IssuedAt = &issuedAt + } + + if audience := token.Audience(); len(audience) > 0 { + claims.Audience = audience + } + + if issuer := token.Issuer(); len(issuer) > 0 { + claims.Issuer = &issuer + } + + if email, valid := token.Get("email"); valid { + if data, ok := email.(map[string]interface{}); ok { + jsonData, err := json.Marshal(data) + if err != nil { + return nil, fmt.Errorf("failed to marshal 'email' claim: %w", err) + } + err = json.Unmarshal(jsonData, &claims.Email) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal 'email' claim: %w", err) + } + } + } + + claims.Expiration = token.Expiration() + + return claims, nil } type ValidateSessionRequest struct { diff --git a/backend/handler/session.go b/backend/handler/session.go index 81390b75f..ee6481c57 100644 --- a/backend/handler/session.go +++ b/backend/handler/session.go @@ -2,10 +2,8 @@ package handler import ( "fmt" - "github.com/gofrs/uuid" echojwt "github.com/labstack/echo-jwt/v4" "github.com/labstack/echo/v4" - "github.com/lestrrat-go/jwx/v2/jwt" "github.com/teamhanko/hanko/backend/config" "github.com/teamhanko/hanko/backend/dto" "github.com/teamhanko/hanko/backend/persistence" @@ -36,29 +34,23 @@ func (h *SessionHandler) ValidateSession(c echo.Context) error { return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) } - var token jwt.Token for _, extractor := range extractors { auths, extractorErr := extractor(c) if extractorErr != nil { continue } for _, auth := range auths { - t, tokenErr := h.sessionManager.Verify(auth) + token, tokenErr := h.sessionManager.Verify(auth) if tokenErr != nil { continue } - // check that the session id is stored in the database - sessionId, ok := t.Get("session_id") - if !ok { - continue - } - sessionID, err := uuid.FromString(sessionId.(string)) + claims, err := dto.GetClaimsFromToken(token) if err != nil { - continue + return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("failed to parse token claims: %w", err)) } - sessionModel, err := h.persister.GetSessionPersister().Get(sessionID) + sessionModel, err := h.persister.GetSessionPersister().Get(claims.SessionID) if err != nil { return fmt.Errorf("failed to get session from database: %w", err) } @@ -73,22 +65,16 @@ func (h *SessionHandler) ValidateSession(c echo.Context) error { return dto.ToHttpError(err) } - token = t - break + return c.JSON(http.StatusOK, dto.ValidateSessionResponse{ + IsValid: true, + Claims: claims, + ExpirationTime: &claims.Expiration, + UserID: &claims.Subject, + }) } } - if token != nil { - expirationTime := token.Expiration() - userID := uuid.FromStringOrNil(token.Subject()) - return c.JSON(http.StatusOK, dto.ValidateSessionResponse{ - IsValid: true, - ExpirationTime: &expirationTime, - UserID: &userID, - }) - } else { - return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) - } + return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) } func (h *SessionHandler) ValidateSessionFromBody(c echo.Context) error { @@ -108,17 +94,12 @@ func (h *SessionHandler) ValidateSessionFromBody(c echo.Context) error { return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) } - // check that the session id is stored in the database - sessionId, ok := token.Get("session_id") - if !ok { - return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) - } - sessionID, err := uuid.FromString(sessionId.(string)) + claims, err := dto.GetClaimsFromToken(token) if err != nil { - return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) + return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("failed to parse token claims: %w", err)) } - sessionModel, err := h.persister.GetSessionPersister().Get(sessionID) + sessionModel, err := h.persister.GetSessionPersister().Get(claims.SessionID) if err != nil { return dto.ToHttpError(err) } @@ -134,11 +115,10 @@ func (h *SessionHandler) ValidateSessionFromBody(c echo.Context) error { return dto.ToHttpError(err) } - expirationTime := token.Expiration() - userID := uuid.FromStringOrNil(token.Subject()) return c.JSON(http.StatusOK, dto.ValidateSessionResponse{ IsValid: true, - ExpirationTime: &expirationTime, - UserID: &userID, + Claims: claims, + ExpirationTime: &claims.Expiration, + UserID: &claims.Subject, }) }