Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Override sub with federated_claims.user_id when dex is used #20683

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/argocd/commands/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"golang.org/x/oauth2"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/headless"
"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
argocdclient "github.com/argoproj/argo-cd/v2/pkg/apiclient"
sessionpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/session"
settingspkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/settings"
Expand Down Expand Up @@ -196,7 +197,7 @@ func userDisplayName(claims jwt.MapClaims) string {
if name := jwtutil.StringField(claims, "name"); name != "" {
return name
}
return jwtutil.StringField(claims, "sub")
return utils.GetUserIdentifier(claims)
}

// oauth2Login opens a browser, runs a temporary HTTP server to delegate OAuth2 login flow and
Expand Down
2 changes: 1 addition & 1 deletion cmd/argocd/commands/project_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ Create token succeeded for proj:test-project:test-role.
issuedAt, _ := jwt.IssuedAt(claims)
expiresAt := int64(jwt.Float64Field(claims, "exp"))
id := jwt.StringField(claims, "jti")
subject := jwt.StringField(claims, "sub")
subject := utils.GetUserIdentifier(claims)

if !outputTokenOnly {
fmt.Printf("Create token succeeded for %s.\n", subject)
Expand Down
22 changes: 22 additions & 0 deletions cmd/argocd/commands/utils/claims.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package utils

import (
"github.com/golang-jwt/jwt/v4"
)

// GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use
func GetUserIdentifier(claims jwt.MapClaims) string {
// Check for federated_claims.user_id if Dex is used
if federatedClaims, ok := claims["federated_claims"].(map[string]interface{}); ok {
if userID, exists := federatedClaims["user_id"].(string); exists {
return userID
}
}

// Fallback to sub
if sub, ok := claims["sub"].(string); ok {
return sub
}
return ""

}

Check failure on line 22 in cmd/argocd/commands/utils/claims.go

View workflow job for this annotation

GitHub Actions / Lint Go code

unnecessary trailing newline (whitespace)
3 changes: 2 additions & 1 deletion server/rbacpolicy/rbacpolicy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/golang-jwt/jwt/v4"
log "github.com/sirupsen/logrus"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1"
applister "github.com/argoproj/argo-cd/v2/pkg/client/listers/application/v1alpha1"
jwtutil "github.com/argoproj/argo-cd/v2/util/jwt"
Expand Down Expand Up @@ -114,7 +115,7 @@ func (p *RBACPolicyEnforcer) EnforceClaims(claims jwt.Claims, rvals ...interface
return false
}

subject := jwtutil.StringField(mapClaims, "sub")
subject := utils.GetUserIdentifier(mapClaims)
// Check if the request is for an application resource. We have special enforcement which takes
// into consideration the project's token and group bindings
var runtimePolicy string
Expand Down
3 changes: 2 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ import (
"k8s.io/client-go/tools/cache"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/common"
"github.com/argoproj/argo-cd/v2/pkg/apiclient"
accountpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/account"
Expand Down Expand Up @@ -1417,7 +1418,7 @@ func (a *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error
log.Errorf("error fetching user info endpoint: %v", err)
return claims, "", status.Errorf(codes.Internal, "invalid userinfo response")
}
if groupClaims["sub"] != userInfo["sub"] {
if utils.GetUserIdentifier(groupClaims) != utils.GetUserIdentifier(userInfo) {
return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo")
}
groupClaims["groups"] = userInfo["groups"]
Expand Down
16 changes: 9 additions & 7 deletions util/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/common"
"github.com/argoproj/argo-cd/v2/server/settings/oidc"
"github.com/argoproj/argo-cd/v2/util/cache"
Expand Down Expand Up @@ -402,9 +403,8 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
log.Errorf("cannot encrypt accessToken: %v (claims=%s)", err, claimsJSON)
return
}
sub := jwtutil.StringField(claims, "sub")
err = a.clientCache.Set(&cache.Item{
Key: formatAccessTokenCacheKey(sub),
Key: formatAccessTokenCacheKey(claims),
Object: encToken,
CacheActionOpts: cache.CacheActionOpts{
Expiration: getTokenExpiration(claims),
Expand Down Expand Up @@ -552,12 +552,12 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc

// GetUserInfo queries the IDP userinfo endpoint for claims
func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
sub := jwtutil.StringField(actualClaims, "sub")
sub := utils.GetUserIdentifier(actualClaims)
var claims jwt.MapClaims
var encClaims []byte

// in case we got it in the cache, we just return the item
clientCacheKey := formatUserInfoResponseCacheKey(sub)
clientCacheKey := formatUserInfoResponseCacheKey(actualClaims)
if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil {
claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey)
if err != nil {
Expand All @@ -575,7 +575,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP

// check if the accessToken for the user is still present
var encAccessToken []byte
err := a.clientCache.Get(formatAccessTokenCacheKey(sub), &encAccessToken)
err := a.clientCache.Get(formatAccessTokenCacheKey(actualClaims), &encAccessToken)
// without an accessToken we can't query the user info endpoint
// thus the user needs to reauthenticate for argocd to get a new accessToken
if errors.Is(err, cache.ErrCacheMiss) {
Expand Down Expand Up @@ -684,11 +684,13 @@ func getTokenExpiration(claims jwt.MapClaims) time.Duration {
}

// formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache
func formatUserInfoResponseCacheKey(sub string) string {
func formatUserInfoResponseCacheKey(claims jwt.MapClaims) string {
sub := utils.GetUserIdentifier(claims)
return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub)
}

// formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache
func formatAccessTokenCacheKey(sub string) string {
func formatAccessTokenCacheKey(claims jwt.MapClaims) string {
sub := utils.GetUserIdentifier(claims)
return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub)
}
18 changes: 9 additions & 9 deletions util/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ func TestGetUserInfo(t *testing.T) {
expectError bool
}{
{
key: formatUserInfoResponseCacheKey("randomUser"),
key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser"}),
expectError: true,
},
},
Expand All @@ -654,7 +654,7 @@ func TestGetUserInfo(t *testing.T) {
encrypt bool
}{
{
key: formatAccessTokenCacheKey("randomUser"),
key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser"}),
value: "FakeAccessToken",
encrypt: true,
},
Expand All @@ -673,7 +673,7 @@ func TestGetUserInfo(t *testing.T) {
expectError bool
}{
{
key: formatUserInfoResponseCacheKey("randomUser"),
key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser"}),
expectError: true,
},
},
Expand All @@ -688,7 +688,7 @@ func TestGetUserInfo(t *testing.T) {
encrypt bool
}{
{
key: formatAccessTokenCacheKey("randomUser"),
key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser"}),
value: "FakeAccessToken",
encrypt: true,
},
Expand All @@ -707,7 +707,7 @@ func TestGetUserInfo(t *testing.T) {
expectError bool
}{
{
key: formatUserInfoResponseCacheKey("randomUser"),
key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser"}),
expectError: true,
},
},
Expand All @@ -730,7 +730,7 @@ func TestGetUserInfo(t *testing.T) {
encrypt bool
}{
{
key: formatAccessTokenCacheKey("randomUser"),
key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser"}),
value: "FakeAccessToken",
encrypt: true,
},
Expand All @@ -749,7 +749,7 @@ func TestGetUserInfo(t *testing.T) {
expectError bool
}{
{
key: formatUserInfoResponseCacheKey("randomUser"),
key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser"}),
expectError: true,
},
},
Expand Down Expand Up @@ -782,7 +782,7 @@ func TestGetUserInfo(t *testing.T) {
expectError bool
}{
{
key: formatUserInfoResponseCacheKey("randomUser"),
key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser"}),
value: "{\"groups\":[\"githubOrg:engineers\"]}",
expectEncrypted: true,
expectError: false,
Expand All @@ -809,7 +809,7 @@ func TestGetUserInfo(t *testing.T) {
encrypt bool
}{
{
key: formatAccessTokenCacheKey("randomUser"),
key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser"}),
value: "FakeAccessToken",
encrypt: true,
},
Expand Down
3 changes: 2 additions & 1 deletion util/rbac/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"
"time"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/util/assets"
"github.com/argoproj/argo-cd/v2/util/glob"
jwtutil "github.com/argoproj/argo-cd/v2/util/jwt"
Expand Down Expand Up @@ -255,7 +256,7 @@ func (e *Enforcer) EnforceErr(rvals ...interface{}) error {
if err != nil {
break
}
if sub := jwtutil.StringField(claims, "sub"); sub != "" {
if sub := utils.GetUserIdentifier(claims); sub != "" {
rvalsStrs = append(rvalsStrs, fmt.Sprintf("sub: %s", sub))
}
if issuedAtTime, err := jwtutil.IssuedAtTime(claims); err == nil {
Expand Down
22 changes: 13 additions & 9 deletions util/session/sessionmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/common"
"github.com/argoproj/argo-cd/v2/pkg/client/listers/application/v1alpha1"
"github.com/argoproj/argo-cd/v2/server/rbacpolicy"
Expand Down Expand Up @@ -226,7 +227,7 @@ func (mgr *SessionManager) Parse(tokenString string) (jwt.Claims, string, error)
return nil, "", err
}

subject := jwtutil.StringField(claims, "sub")
subject := utils.GetUserIdentifier(claims)
id := jwtutil.StringField(claims, "jti")

if projName, role, ok := rbacpolicy.GetProjectRoleFromSubject(subject); ok {
Expand Down Expand Up @@ -502,9 +503,17 @@ func WithAuthMiddleware(disabled bool, authn TokenVerifier, next http.Handler) h
return
}
ctx := r.Context()

// Assert that claims is of type jwt.MapClaims
mapClaims, ok := claims.(jwt.MapClaims)
if !ok {
http.Error(w, "Invalid claims type", http.StatusUnauthorized)
return
}

// Add claims to the context to inspect for RBAC
// nolint:staticcheck
ctx = context.WithValue(ctx, "claims", claims)
ctx = context.WithValue(ctx, "user_id", utils.GetUserIdentifier(mapClaims))
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
Expand Down Expand Up @@ -593,12 +602,7 @@ func Username(ctx context.Context) string {
if !ok {
return ""
}
switch jwtutil.StringField(mapClaims, "iss") {
case SessionManagerClaimsIssuer:
return jwtutil.StringField(mapClaims, "sub")
default:
return jwtutil.StringField(mapClaims, "email")
}
return utils.GetUserIdentifier(mapClaims)
}

func Iss(ctx context.Context) string {
Expand All @@ -622,7 +626,7 @@ func Sub(ctx context.Context) string {
if !ok {
return ""
}
return jwtutil.StringField(mapClaims, "sub")
return utils.GetUserIdentifier(mapClaims)
}

func Groups(ctx context.Context, scopes []string) []string {
Expand Down
5 changes: 3 additions & 2 deletions util/session/sessionmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes/fake"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/common"
appv1 "github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1"
apps "github.com/argoproj/argo-cd/v2/pkg/client/clientset/versioned/fake"
Expand Down Expand Up @@ -99,7 +100,7 @@ func TestSessionManager_AdminToken(t *testing.T) {
assert.Empty(t, newToken)

mapClaims := *(claims.(*jwt.MapClaims))
subject := mapClaims["sub"].(string)
subject := utils.GetUserIdentifier(mapClaims)
if subject != "admin" {
t.Errorf("Token claim subject \"%s\" does not match expected subject \"%s\".", subject, "admin")
}
Expand All @@ -126,7 +127,7 @@ func TestSessionManager_AdminToken_ExpiringSoon(t *testing.T) {
claims, _, err := mgr.Parse(newToken)
require.NoError(t, err)
mapClaims := *(claims.(*jwt.MapClaims))
subject := mapClaims["sub"].(string)
subject := utils.GetUserIdentifier(mapClaims)
assert.Equal(t, "admin", subject)
}

Expand Down
Loading