Skip to content

Commit

Permalink
oauth2: Allow customization of JWT claims
Browse files Browse the repository at this point in the history
Signed-off-by: aeneasr <[email protected]>
  • Loading branch information
aeneasr committed Oct 24, 2018
1 parent 189589c commit f97e451
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 17 deletions.
33 changes: 33 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,39 @@ bumps (`0.1.0` -> `0.2.0`).

<!-- END doctoc generated TOC please keep comment here to allow auto update -->

## 0.26.0

This release makes it easier to define custom JWT Containers for access tokens when using the JWT strategy. To do that,
the following signatures have changed:

// github.com/ory/fosite/handler/oauth2
type JWTSessionContainer interface {
// GetJWTClaims returns the claims.
- GetJWTClaims() *jwt.JWTClaims
+ GetJWTClaims() jwt.JWTClaimsContainer

// GetJWTHeader returns the header.
GetJWTHeader() *jwt.Headers

fosite.Session
}

+ type JWTClaimsContainer interface {
+ // With returns a copy of itself with expiresAt and scope set to the given values.
+ With(expiry time.Time, scope []string) JWTClaimsContainer
+
+ // WithDefaults returns a copy of itself with issuedAt and issuer set to the given default values. If those
+ // values are already set in the claims, they will not be updated.
+ WithDefaults(iat time.Time, issuer string) JWTClaimsContainer
+
+ // ToMapClaims returns the claims as a github.com/dgrijalva/jwt-go.MapClaims type.
+ ToMapClaims() jwt.MapClaims
+ }
```
All default session implementations have been updated to reflect this change. If you define custom session, this patch
will affect you.
## 0.24.0
This release addresses areas where the go context was missing or not propagated down the call path properly.
Expand Down
1 change: 1 addition & 0 deletions handler/oauth2/introspector_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ func TestIntrospectJWT(t *testing.T) {
token, _, err := strat.GenerateAccessToken(nil, jwt)
assert.NoError(t, err)
parts := strings.Split(token, ".")
require.Len(t, parts, 3, "%s - %v", token, parts)
dec, err := base64.RawURLEncoding.DecodeString(parts[1])
assert.NoError(t, err)
s := strings.Replace(string(dec), "peter", "piper", -1)
Expand Down
23 changes: 10 additions & 13 deletions handler/oauth2/strategy_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,22 +156,19 @@ func (h *DefaultJWTStrategy) validate(ctx context.Context, token string) (t *jwt

func (h *DefaultJWTStrategy) generate(ctx context.Context, tokenType fosite.TokenType, requester fosite.Requester) (string, string, error) {
if jwtSession, ok := requester.GetSession().(JWTSessionContainer); !ok {
return "", "", errors.New("Session must be of type JWTSessionContainer")
return "", "", errors.Errorf("Session must be of type JWTSessionContainer but got type: %T", requester.GetSession())
} else if jwtSession.GetJWTClaims() == nil {
return "", "", errors.New("GetTokenClaims() must not be nil")
} else {
claims := jwtSession.GetJWTClaims()
claims.ExpiresAt = jwtSession.GetExpiresAt(tokenType)

if claims.IssuedAt.IsZero() {
claims.IssuedAt = time.Now().UTC()
}

if claims.Issuer == "" {
claims.Issuer = h.Issuer
}

claims.Scope = requester.GetGrantedScopes()
claims := jwtSession.GetJWTClaims().
With(
jwtSession.GetExpiresAt(tokenType),
requester.GetGrantedScopes(),
).
WithDefaults(
time.Now().UTC(),
h.Issuer,
)

return h.JWTStrategy.Generate(ctx, claims.ToMapClaims(), jwtSession.GetJWTHeader())
}
Expand Down
4 changes: 2 additions & 2 deletions handler/oauth2/strategy_jwt_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (

type JWTSessionContainer interface {
// GetJWTClaims returns the claims.
GetJWTClaims() *jwt.JWTClaims
GetJWTClaims() jwt.JWTClaimsContainer

// GetJWTHeader returns the header.
GetJWTHeader() *jwt.Headers
Expand All @@ -48,7 +48,7 @@ type JWTSession struct {
Subject string
}

func (j *JWTSession) GetJWTClaims() *jwt.JWTClaims {
func (j *JWTSession) GetJWTClaims() jwt.JWTClaimsContainer {
if j.JWTClaims == nil {
j.JWTClaims = &jwt.JWTClaims{}
}
Expand Down
8 changes: 6 additions & 2 deletions handler/oauth2/strategy_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
package oauth2

import (
"github.com/stretchr/testify/require"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -55,7 +56,7 @@ var jwtValidCase = func(tokenType fosite.TokenType) *fosite.Request {
Audience: []string{"group0"},
IssuedAt: time.Now().UTC(),
NotBefore: time.Now().UTC(),
Extra: make(map[string]interface{}),
Extra: map[string]interface{}{"foo": "bar"},
},
JWTHeader: &jwt.Headers{
Extra: make(map[string]interface{}),
Expand Down Expand Up @@ -112,7 +113,10 @@ func TestAccessToken(t *testing.T) {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
token, signature, err := j.GenerateAccessToken(nil, c.r)
assert.NoError(t, err)
assert.Equal(t, strings.Split(token, ".")[2], signature)

parts := strings.Split(token, ".")
require.Len(t, parts, 3, "%s - %v", token, parts)
assert.Equal(t, parts[2], signature)

validate := j.signature(token)
err = j.ValidateAccessToken(nil, c.r, token)
Expand Down
37 changes: 37 additions & 0 deletions token/jwt/claims_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,25 @@ import (
"github.com/pborman/uuid"
)

type JWTClaimsDefaults struct {
ExpiresAt time.Time
IssuedAt time.Time
Issuer string
Scope []string
}

type JWTClaimsContainer interface {
// With returns a copy of itself with expiresAt and scope set to the given values.
With(expiry time.Time, scope []string) JWTClaimsContainer

// WithDefaults returns a copy of itself with issuedAt and issuer set to the given default values. If those
// values are already set in the claims, they will not be updated.
WithDefaults(iat time.Time, issuer string) JWTClaimsContainer

// ToMapClaims returns the claims as a github.com/dgrijalva/jwt-go.MapClaims type.
ToMapClaims() jwt.MapClaims
}

// JWTClaims represent a token's claims.
type JWTClaims struct {
Subject string
Expand All @@ -41,6 +60,24 @@ type JWTClaims struct {
Extra map[string]interface{}
}

func (c *JWTClaims) With(expiry time.Time, scope []string) JWTClaimsContainer {
c.ExpiresAt = expiry
c.Scope = scope
return c

}

func (c *JWTClaims) WithDefaults(iat time.Time, issuer string) JWTClaimsContainer {
if c.IssuedAt.IsZero() {
c.IssuedAt = iat
}

if c.Issuer == "" {
c.Issuer = issuer
}
return c
}

// ToMap will transform the headers to a map structure
func (c *JWTClaims) ToMap() map[string]interface{} {
var ret = Copy(c.Extra)
Expand Down

0 comments on commit f97e451

Please sign in to comment.