From 655dcf41328ecff8026bcb9b0834293b73b45834 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bl=C3=BCcher?= Date: Thu, 21 Nov 2024 09:11:24 +0100 Subject: [PATCH] fix: gracefully handle malformed jwk headers --- parse.go | 49 +++++++++-- parse_test.go | 238 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 276 insertions(+), 11 deletions(-) diff --git a/parse.go b/parse.go index 92a8ffd..a2c9a22 100644 --- a/parse.go +++ b/parse.go @@ -181,24 +181,43 @@ func keyFunc(t *jwt.Token) (interface{}, error) { // Parses a JWK and inherently strips it of optional fields func parseJwk(jwkMap map[string]interface{}) (interface{}, error) { - switch jwkMap["kty"].(string) { + // Ensure that JWK kty is present and is a string. + kty, ok := jwkMap["kty"].(string) + if !ok { + return nil, ErrInvalidProof + } + switch kty { case "EC": + // Ensure that the required fields are present and are strings. + x, ok := jwkMap["x"].(string) + if !ok { + return nil, ErrInvalidProof + } + y, ok := jwkMap["y"].(string) + if !ok { + return nil, ErrInvalidProof + } + crv, ok := jwkMap["crv"].(string) + if !ok { + return nil, ErrInvalidProof + } + // Decode the coordinates from Base64. // // According to RFC 7518, they are Base64 URL unsigned integers. // https://tools.ietf.org/html/rfc7518#section-6.3 - xCoordinate, err := base64urlTrailingPadding(jwkMap["x"].(string)) + xCoordinate, err := base64urlTrailingPadding(x) if err != nil { return nil, err } - yCoordinate, err := base64urlTrailingPadding(jwkMap["y"].(string)) + yCoordinate, err := base64urlTrailingPadding(y) if err != nil { return nil, err } // Read the specified curve of the key. var curve elliptic.Curve - switch jwkMap["crv"].(string) { + switch crv { case "P-256": curve = elliptic.P256() case "P-384": @@ -215,15 +234,25 @@ func parseJwk(jwkMap map[string]interface{}) (interface{}, error) { Curve: curve, }, nil case "RSA": + // Ensure that the required fields are present and are strings. + e, ok := jwkMap["e"].(string) + if !ok { + return nil, ErrInvalidProof + } + n, ok := jwkMap["n"].(string) + if !ok { + return nil, ErrInvalidProof + } + // Decode the exponent and modulus from Base64. // // According to RFC 7518, they are Base64 URL unsigned integers. // https://tools.ietf.org/html/rfc7518#section-6.3 - exponent, err := base64urlTrailingPadding(jwkMap["e"].(string)) + exponent, err := base64urlTrailingPadding(e) if err != nil { return nil, err } - modulus, err := base64urlTrailingPadding(jwkMap["n"].(string)) + modulus, err := base64urlTrailingPadding(n) if err != nil { return nil, err } @@ -232,7 +261,13 @@ func parseJwk(jwkMap map[string]interface{}) (interface{}, error) { E: int(big.NewInt(0).SetBytes(exponent).Uint64()), }, nil case "OKP": - publicKey, err := base64urlTrailingPadding(jwkMap["x"].(string)) + // Ensure that the required fields are present and are strings. + x, ok := jwkMap["x"].(string) + if !ok { + return nil, ErrInvalidProof + } + + publicKey, err := base64urlTrailingPadding(x) if err != nil { return nil, err } diff --git a/parse_test.go b/parse_test.go index f7edca2..72bba7b 100644 --- a/parse_test.go +++ b/parse_test.go @@ -633,7 +633,7 @@ func TestParse_ProofWithExtraKeyMembersEC(t *testing.T) { Path: "/token", } - // Set an optional member in the key used in the proof, the member should be disregarded in the thubprint + // Set an optional member in the key used in the proof, the member should be disregarded in the thumbprint jwkWithOptionalParameters := map[string]interface{}{ "x": base64.RawURLEncoding.EncodeToString(privateKey.X.Bytes()), "y": base64.RawURLEncoding.EncodeToString(privateKey.Y.Bytes()), @@ -685,6 +685,91 @@ func TestParse_ProofWithExtraKeyMembersEC(t *testing.T) { } +func TestParse_ProofWithMalformedJwkEC(t *testing.T) { + // Arrange + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + + tokenClaims := dpop.ProofTokenClaims{ + RegisteredClaims: &jwt.RegisteredClaims{ + Issuer: "client", + Subject: "user", + Audience: jwt.ClaimStrings{"https://server.example.com/token"}, + ID: "random_id", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + Method: dpop.POST, + URL: "https://server.example.com/token", + } + httpUrl := url.URL{ + Scheme: "https", + Host: "server.example.com", + Path: "/token", + } + + testCases := []map[string]interface{}{ + { + "x": 1, + "y": base64.RawURLEncoding.EncodeToString(privateKey.Y.Bytes()), + "crv": privateKey.Curve.Params().Name, + "kty": "EC", + }, + { + "x": base64.RawURLEncoding.EncodeToString(privateKey.X.Bytes()), + "y": 1, + "crv": privateKey.Curve.Params().Name, + "kty": "EC", + }, + { + "x": base64.RawURLEncoding.EncodeToString(privateKey.X.Bytes()), + "y": base64.RawURLEncoding.EncodeToString(privateKey.Y.Bytes()), + "crv": 1, + "kty": "EC", + }, + { + "x": base64.RawURLEncoding.EncodeToString(privateKey.X.Bytes()), + "y": base64.RawURLEncoding.EncodeToString(privateKey.Y.Bytes()), + "crv": privateKey.Curve.Params().Name, + "kty": 1, + }, + } + + for _, testCase := range testCases { + t.Run("", func(t *testing.T) { + // Act + token := &jwt.Token{ + Header: map[string]interface{}{ + "typ": "dpop+jwt", + "alg": jwt.SigningMethodES256.Alg(), + "jwk": testCase, + }, + Claims: tokenClaims, + Method: jwt.SigningMethodES256, + } + tokenString, err := token.SignedString(privateKey) + if err != nil { + t.Error(err) + } + + parsedProof, err := dpop.Parse(tokenString, dpop.POST, &httpUrl, dpop.ParseOptions{}) + + // Assert + if err == nil { + t.Error("Expected error") + } + if err != nil { + AssertJoinedError(t, err, dpop.ErrInvalidProof) + } + if parsedProof != nil { + t.Errorf("Expected nil token") + } + }) + } +} + func TestParse_ProofWithExtraKeyMembersRSA(t *testing.T) { // Arrange rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) @@ -710,7 +795,7 @@ func TestParse_ProofWithExtraKeyMembersRSA(t *testing.T) { Path: "/token", } - // Set an optional member in the key used in the proof, the member should be disregarded in the thubprint + // Set an optional member in the key used in the proof, the member should be disregarded in the thumbprint jwkWithOptionalParameters := map[string]interface{}{ "n": base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()), "e": base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()), @@ -757,10 +842,85 @@ func TestParse_ProofWithExtraKeyMembersRSA(t *testing.T) { if parsedProof == nil { t.Error("Expected proof to be parsed") } +} + +func TestParse_ProofWithMalformedJwkRSA(t *testing.T) { + // Arrange + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Errorf("Error when generating RSA key: %v", err) + } + + tokenClaims := dpop.ProofTokenClaims{ + RegisteredClaims: &jwt.RegisteredClaims{ + Issuer: "client", + Subject: "user", + Audience: jwt.ClaimStrings{"https://server.example.com/token"}, + ID: "random_id", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + Method: dpop.POST, + URL: "https://server.example.com/token", + } + httpUrl := url.URL{ + Scheme: "https", + Host: "server.example.com", + Path: "/token", + } + + testCases := []map[string]interface{}{ + { + "n": 1, + "e": base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()), + "kty": "RSA", + }, + { + "n": base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()), + "e": 1, + "kty": "RSA", + }, + { + "n": base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()), + "e": base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()), + "kty": 1, + }, + } + for _, testCase := range testCases { + t.Run("", func(t *testing.T) { + // Act + token := &jwt.Token{ + Header: map[string]interface{}{ + "typ": "dpop+jwt", + "alg": jwt.SigningMethodRS512.Alg(), + "jwk": testCase, + }, + Claims: tokenClaims, + Method: jwt.SigningMethodRS512, + } + tokenString, err := token.SignedString(rsaKey) + if err != nil { + t.Error(err) + } + + parsedProof, err := dpop.Parse(tokenString, dpop.POST, &httpUrl, dpop.ParseOptions{}) + + // Assert + if err == nil { + t.Error("Expected error") + } + if err != nil { + AssertJoinedError(t, err, dpop.ErrInvalidProof) + } + if parsedProof != nil { + t.Errorf("Expected nil token") + } + }) + } } -func TestParse_ProofWithExtraKeyMembersOKT(t *testing.T) { +func TestParse_ProofWithExtraKeyMembersOKP(t *testing.T) { // Arrange public, private, err := ed25519.GenerateKey(rand.Reader) if err != nil { @@ -785,7 +945,7 @@ func TestParse_ProofWithExtraKeyMembersOKT(t *testing.T) { Path: "/token", } - // Set an optional member in the key used in the proof, the member should be disregarded in the thubprint + // Set an optional member in the key used in the proof, the member should be disregarded in the thumbprint jwkWithOptionalParameters := map[string]interface{}{ "ext": true, "crv": "Ed25519", @@ -832,7 +992,77 @@ func TestParse_ProofWithExtraKeyMembersOKT(t *testing.T) { if parsedProof == nil { t.Error("Expected proof to be parsed") } +} + +func TestParse_ProofWithMalformedJwkOKP(t *testing.T) { + // Arrange + public, private, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Errorf("Error when generating RSA key: %v", err) + } + tokenClaims := dpop.ProofTokenClaims{ + RegisteredClaims: &jwt.RegisteredClaims{ + Issuer: "client", + Subject: "user", + Audience: jwt.ClaimStrings{"https://server.example.com/token"}, + ID: "random_id", + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + Method: dpop.POST, + URL: "https://server.example.com/token", + } + httpUrl := url.URL{ + Scheme: "https", + Host: "server.example.com", + Path: "/token", + } + + testCases := []map[string]interface{}{ + { + "crv": "Ed25519", + "x": 1, + "kty": "OKP", + }, + { + "crv": "Ed25519", + "x": base64.RawURLEncoding.EncodeToString(public), + "kty": 1, + }, + } + + for _, testCase := range testCases { + t.Run("", func(t *testing.T) { + // Act + token := &jwt.Token{ + Header: map[string]interface{}{ + "typ": "dpop+jwt", + "alg": jwt.SigningMethodEdDSA.Alg(), + "jwk": testCase, + }, + Claims: tokenClaims, + Method: jwt.SigningMethodEdDSA, + } + tokenString, err := token.SignedString(private) + if err != nil { + t.Error(err) + } + + parsedProof, err := dpop.Parse(tokenString, dpop.POST, &httpUrl, dpop.ParseOptions{}) + + // Assert + if err == nil { + t.Error("Expected error") + } + if err != nil { + AssertJoinedError(t, err, dpop.ErrInvalidProof) + } + if parsedProof != nil { + t.Errorf("Expected nil token") + } + }) + } } func TestParse_ProofWithLeadingZeroesEC(t *testing.T) {