Skip to content

Commit

Permalink
support not only user_id for custom claim
Browse files Browse the repository at this point in the history
  • Loading branch information
FZambia committed Sep 1, 2024
1 parent e94895d commit efb0563
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 26 deletions.
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ require (
github.com/rs/zerolog v1.33.0
github.com/spf13/cobra v1.8.1
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.17.3
github.com/tidwall/sjson v1.2.5
github.com/twmb/franz-go v1.17.1
github.com/twmb/franz-go/pkg/kadm v1.13.0
github.com/twmb/franz-go/pkg/kmsg v1.8.0
Expand Down Expand Up @@ -56,6 +58,8 @@ require (
github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/planetscale/vtprotobuf v0.6.0 // indirect
github.com/shadowspore/fossil-delta v0.0.0-20240102155221-e3a8590b820b // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
go.uber.org/mock v0.4.0 // indirect
)

Expand Down
9 changes: 9 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,15 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twmb/franz-go v1.17.1 h1:0LwPsbbJeJ9R91DPUHSEd4su82WJWcTY1Zzbgbg4CeQ=
github.com/twmb/franz-go v1.17.1/go.mod h1:NreRdJ2F7dziDY/m6VyspWd6sNxHKXdMZI42UfQ3GXM=
github.com/twmb/franz-go/pkg/kadm v1.13.0 h1:bJq4C2ZikUE2jh/wl9MtMTQ/kpmnBgVFh8XMQBEC+60=
Expand Down
41 changes: 29 additions & 12 deletions internal/cli/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/centrifugal/centrifugo/v5/internal/jwtverify"
"github.com/centrifugal/centrifugo/v5/internal/rule"
"github.com/cristalhq/jwt/v5"
"github.com/tidwall/sjson"
)

// GenerateToken generates sample JWT for user.
Expand All @@ -24,17 +25,25 @@ func GenerateToken(config jwtverify.VerifierConfig, user string, ttlSeconds int6
claims := jwtverify.ConnectTokenClaims{
RegisteredClaims: jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now()),
Subject: user,
},
}
if config.UserIDClaim != "" {
claims.UserID = user
} else {
claims.Subject = user
}
if ttlSeconds > 0 {
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Duration(ttlSeconds) * time.Second))
}
token, err := builder.Build(claims)

encodedClaims, err := json.Marshal(claims)
if err != nil {
return "", err
}
if config.UserIDClaim != "" {
encodedClaims, err = sjson.SetBytes(encodedClaims, config.UserIDClaim, user)
if err != nil {
return "", err
}
}

token, err := builder.Build(encodedClaims)
if err != nil {
return "", err
}
Expand All @@ -54,18 +63,26 @@ func GenerateSubToken(config jwtverify.VerifierConfig, user string, channel stri
claims := jwtverify.SubscribeTokenClaims{
RegisteredClaims: jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now()),
Subject: user,
},
Channel: channel,
}
if config.UserIDClaim != "" {
claims.UserID = user
} else {
claims.Subject = user
}
if ttlSeconds > 0 {
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Duration(ttlSeconds) * time.Second))
}
token, err := builder.Build(claims)

encodedClaims, err := json.Marshal(claims)
if err != nil {
return "", err
}
if config.UserIDClaim != "" {
encodedClaims, err = sjson.SetBytes(encodedClaims, config.UserIDClaim, user)
if err != nil {
return "", err
}
}

token, err := builder.Build(encodedClaims)
if err != nil {
return "", err
}
Expand Down
11 changes: 5 additions & 6 deletions internal/jwtverify/token_verifier_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/cristalhq/jwt/v5"
"github.com/rakutentech/jwk-go/okp"
"github.com/rs/zerolog/log"
"github.com/tidwall/gjson"
)

type VerifierConfig struct {
Expand Down Expand Up @@ -187,8 +188,6 @@ type ConnectTokenClaims struct {
Channels []string `json:"channels,omitempty"`
Subs map[string]SubscribeOptions `json:"subs,omitempty"`
Meta json.RawMessage `json:"meta,omitempty"`
// UserID is only used instead of jwt.RegisteredClaims.Subject when explicitly configured.
UserID string `json:"user_id,omitempty"`
// Channel must never be set in connection tokens. We check this on verifying.
Channel string `json:"channel,omitempty"`
jwt.RegisteredClaims
Expand All @@ -200,8 +199,6 @@ type SubscribeTokenClaims struct {
Channel string `json:"channel,omitempty"`
Client string `json:"client,omitempty"`
ExpireAt *int64 `json:"expire_at,omitempty"`
// UserID is only used instead of jwt.RegisteredClaims.Subject when explicitly configured.
UserID string `json:"user_id,omitempty"`
}

type jwksManager struct{ *jwks.Manager }
Expand Down Expand Up @@ -599,7 +596,8 @@ func (verifier *VerifierJWT) VerifyConnectToken(t string, skipVerify bool) (Conn
Meta: claims.Meta,
}
if verifier.userIDClaim != "" {
ct.UserID = claims.UserID
value := gjson.GetBytes(token.Claims(), verifier.userIDClaim)
ct.UserID = value.String()
} else {
ct.UserID = claims.RegisteredClaims.Subject
}
Expand Down Expand Up @@ -757,7 +755,8 @@ func (verifier *VerifierJWT) VerifySubscribeToken(t string, skipVerify bool) (Su
},
}
if verifier.userIDClaim != "" {
st.UserID = claims.UserID
value := gjson.GetBytes(token.Claims(), verifier.userIDClaim)
st.UserID = value.String()
} else {
st.UserID = claims.RegisteredClaims.Subject
}
Expand Down
29 changes: 21 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -1921,6 +1921,11 @@ func consumersFromConfig(v *viper.Viper) []consuming.ConsumerConfig {
return consumers
}

// Now Centrifugo uses https://github.com/tidwall/gjson to extract custom claims from JWT. So technically
// we could support extracting from nested objects using dot syntax, like "centrifugo.user". But for now
// not using this feature to keep things simple until necessary.
var customClaimRe = regexp.MustCompile("^[a-zA-Z_]+$")

func jwtVerifierConfig() (jwtverify.VerifierConfig, error) {
v := viper.GetViper()
cfg := jwtverify.VerifierConfig{}
Expand Down Expand Up @@ -1950,11 +1955,15 @@ func jwtVerifierConfig() (jwtverify.VerifierConfig, error) {
cfg.AudienceRegex = v.GetString("token_audience_regex")
cfg.Issuer = v.GetString("token_issuer")
cfg.IssuerRegex = v.GetString("token_issuer_regex")
var err error
cfg.UserIDClaim, err = tools.OptionalStringChoice(v, "token_user_id_claim", []string{"user_id"})
if err != nil {
return jwtverify.VerifierConfig{}, err

if v.GetString("token_user_id_claim") != "" {
customUserIDClaim := v.GetString("token_user_id_claim")
if !customClaimRe.MatchString(customUserIDClaim) {
return jwtverify.VerifierConfig{}, fmt.Errorf("invalid user ID claim: %s, must match %s regular expression", customUserIDClaim, customClaimRe.String())
}
cfg.UserIDClaim = customUserIDClaim
}

return cfg, nil
}

Expand Down Expand Up @@ -1987,11 +1996,15 @@ func subJWTVerifierConfig() (jwtverify.VerifierConfig, error) {
cfg.AudienceRegex = v.GetString("subscription_token_audience_regex")
cfg.Issuer = v.GetString("subscription_token_issuer")
cfg.IssuerRegex = v.GetString("subscription_token_issuer_regex")
var err error
cfg.UserIDClaim, err = tools.OptionalStringChoice(v, "subscription_token_user_id_claim", []string{"user_id"})
if err != nil {
return jwtverify.VerifierConfig{}, err

if v.GetString("subscription_token_user_id_claim") != "" {
customUserIDClaim := v.GetString("subscription_token_user_id_claim")
if !customClaimRe.MatchString(customUserIDClaim) {
return jwtverify.VerifierConfig{}, fmt.Errorf("invalid user ID claim: %s, must match %s regular expression", customUserIDClaim, customClaimRe.String())
}
cfg.UserIDClaim = customUserIDClaim
}

return cfg, nil
}

Expand Down

0 comments on commit efb0563

Please sign in to comment.