diff --git a/vmidentity/goclients/oidc/oidcclient.go b/vmidentity/goclients/oidc/oidcclient.go index 960fc0125..5a5433b74 100644 --- a/vmidentity/goclients/oidc/oidcclient.go +++ b/vmidentity/goclients/oidc/oidcclient.go @@ -3,8 +3,9 @@ package oidc import ( "crypto/rsa" "crypto/x509" - "gopkg.in/square/go-jose.v2" "time" + + jose "gopkg.in/square/go-jose.v2" ) // Client represents an oidc client @@ -73,6 +74,12 @@ func ParseTenantInToken(token string) (string, error) { return parseTenantInToken(token) } +// ParseSignedToken parses a token string and returns an unvalidated JWT. +func ParseSignedToken(token string) (JWT, error) { + jwt, err := parseSignedToken(token, nil) + return jwt, err +} + // Tokens represents successful acquire token response type Tokens interface { AccessToken() string diff --git a/vmidentity/goclients/oidc/token_impl.go b/vmidentity/goclients/oidc/token_impl.go index 6eaf4cdf6..1ece3033d 100644 --- a/vmidentity/goclients/oidc/token_impl.go +++ b/vmidentity/goclients/oidc/token_impl.go @@ -5,9 +5,10 @@ import ( "encoding/base64" "encoding/json" "fmt" - "gopkg.in/square/go-jose.v2" "strings" "time" + + jose "gopkg.in/square/go-jose.v2" ) const ( @@ -44,8 +45,8 @@ const ( IDTokenClass = "id_token" ) -// jwt interface represents a parsed/validated jwt -type jwt interface { +// JWT interface represents a parsed/validated JWT +type JWT interface { Issuer() string Nonce() (string, bool) Groups() ([]string, bool) @@ -91,7 +92,7 @@ func parseTenantInToken(token string) (string, error) { } func parseToken( - token string, issuer string, audience string, nonce string, signers IssuerSigners, tokenType string, logger Logger) (jwt, error) { + token string, issuer string, audience string, nonce string, signers IssuerSigners, tokenType string, logger Logger) (JWT, error) { var err error if signers == nil { @@ -128,8 +129,8 @@ func parseToken( } func verifyToken(token string, signers *jose.JSONWebKeySet, issuer string, audience string, nonce string, - clockTolerance int, logger Logger) (jwt, error) { - jwt, err := parseSignedToken(token, issuer, signers, clockTolerance) + clockTolerance int, logger Logger) (JWT, error) { + jwt, err := parseAndValidateSignedToken(token, issuer, signers, clockTolerance) if err != nil { PrintLog(logger, LogLevelError, "verifyToken: Parse signed token failed. Error: '%v'", err) @@ -484,7 +485,22 @@ type jwtImpl struct { } // parseSignedToken parses jwt token from its string representation -func parseSignedToken(token string, issuer string, keySet *jose.JSONWebKeySet, clockToleranceSecs int) (jwt, error) { +func parseAndValidateSignedToken(token string, issuer string, keySet *jose.JSONWebKeySet, clockToleranceSecs int) (JWT, error) { + jwt, err := parseSignedToken(token, keySet) + if err != nil { + return nil, err + } + jwti := jwt.(*jwtImpl) + // validate and normalize expected claims. + err = validateAndNormalizeClaims(&jwti.claims, issuer, clockToleranceSecs) + if err != nil { + return nil, err + } + return jwt, nil +} + +// parseSignedToken parses jwt token from its string representation +func parseSignedToken(token string, keySet *jose.JSONWebKeySet) (JWT, error) { jws, err := jose.ParseSigned(token) if err != nil { return nil, OIDCTokenInvalidError.MakeError("Token format is invalid", err) @@ -522,13 +538,6 @@ func parseSignedToken(token string, issuer string, keySet *jose.JSONWebKeySet, c if err != nil { return nil, err } - - // validate and normalize expected claims. - err = validateAndNormalizeClaims(&tokenBody, issuer, clockToleranceSecs) - if err != nil { - return nil, err - } - return &jwtImpl{claims: tokenBody}, nil } @@ -672,8 +681,8 @@ func decodePayload(payload []byte) (map[string]interface{}, error) { } func parseTokenMulti( - token string, audience string, nonce string, providerInfo []ProviderInfo, tokenType string, logger Logger) (jwt, int, error) { - var tok jwt + token string, audience string, nonce string, providerInfo []ProviderInfo, tokenType string, logger Logger) (JWT, int, error) { + var tok JWT var index int var err error var info ProviderInfo diff --git a/vmidentity/goclients/oidc/token_test.go b/vmidentity/goclients/oidc/token_test.go index 5b63ef33d..3ae735697 100644 --- a/vmidentity/goclients/oidc/token_test.go +++ b/vmidentity/goclients/oidc/token_test.go @@ -1,10 +1,11 @@ package oidc import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestValidateExpiration(t *testing.T) { @@ -61,13 +62,13 @@ func TestValidateAudienceClaim(t *testing.T) { func TestParseSignedToken(t *testing.T) { token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoic3ViSjEiLCJuYmYiOjE1MjE3NzMwOTgsImV4cCI6MTUyMTc3NjY5OCwiaWF0IjoxNTIxNzczMDk4LCJqdGkiOiJpZDEyMzQ1NiJ9.wf8E82CGm_saE8gGnoz7aX1COSzkc5ZbcO2H7xJSgIQ" - jwt, err := parseSignedToken(token, "issuer1", nil, defaultClockToleranceSecs) + jwt, err := parseAndValidateSignedToken(token, "issuer1", nil, defaultClockToleranceSecs) assert.Nil(t, jwt, "Token should be nil") if assert.NotNil(t, err, "Error is expected when using unsupported signature algo") { assert.Contains(t, err.Error(), OIDCTokenInvalidError.Name()) } - jwt, err = parseSignedToken("", "issuer1", nil, defaultClockToleranceSecs) + jwt, err = parseAndValidateSignedToken("", "issuer1", nil, defaultClockToleranceSecs) assert.Nil(t, jwt, "Token should be nil") if assert.NotNil(t, err, "Error is expected when using bad token") { assert.Contains(t, err.Error(), OIDCTokenInvalidError.Name()) @@ -85,21 +86,21 @@ func TestParseSignedToken(t *testing.T) { require.True(t, ok, "Error when getting keyset from IssuerSigners") require.NotNil(t, s, "KeySet is nil") - jwt, err = parseSignedToken(strTok, client.Issuer(), s.signers, defaultClockToleranceSecs) + jwt, err = parseAndValidateSignedToken(strTok, client.Issuer(), s.signers, defaultClockToleranceSecs) assert.Nil(t, err, "No error expected: %+v", err) assert.NotNil(t, jwt, "Token should not be nil") - jwt, err = parseSignedToken(" "+strTok+" ", client.Issuer(), s.signers, defaultClockToleranceSecs) + jwt, err = parseAndValidateSignedToken(" "+strTok+" ", client.Issuer(), s.signers, defaultClockToleranceSecs) assert.Nil(t, err, "No error expected: %+v", err) assert.NotNil(t, jwt, "Token should not be nil") - jwt, err = parseSignedToken(strTok+"a", client.Issuer(), s.signers, defaultClockToleranceSecs) + jwt, err = parseAndValidateSignedToken(strTok+"a", client.Issuer(), s.signers, defaultClockToleranceSecs) if assert.NotNil(t, err, "Error expected when parsing malformed token") { assert.Contains(t, err.Error(), OIDCTokenInvalidSignatureError.Name(), "Wrong Error code: %+v", err) } assert.Nil(t, jwt, "No token expected") - jwt, err = parseSignedToken(strTok, "wrongIssuer", s.signers, defaultClockToleranceSecs) + jwt, err = parseAndValidateSignedToken(strTok, "wrongIssuer", s.signers, defaultClockToleranceSecs) if assert.NotNil(t, err, "Error expected when token from incorrect issuer") { assert.Contains(t, err.Error(), OIDCTokenInvalidError.Name(), "Wrong Error code: %+v", err) } @@ -147,7 +148,7 @@ func TestParseHotkClaim(t *testing.T) { func testInvalidTokens(t *testing.T, token string) { testInvalidTenantInToken(t, token) - jwt, err := parseSignedToken(token, "issuer1", nil, defaultClockToleranceSecs) + jwt, err := parseAndValidateSignedToken(token, "issuer1", nil, defaultClockToleranceSecs) assert.Nil(t, jwt, "Token should be nil") if assert.NotNil(t, err, "Error is expected when using bad token") { assert.Contains(t, err.Error(), OIDCTokenInvalidError.Name())