diff --git a/.schemastore/config.schema.json b/.schemastore/config.schema.json index 51e1783e9757..4f2a95f9cdcc 100644 --- a/.schemastore/config.schema.json +++ b/.schemastore/config.schema.json @@ -567,6 +567,12 @@ "enum": ["id_token", "userinfo"], "default": "id_token", "examples": ["id_token", "userinfo"] + }, + "pkcs_method": { + "title": "PKCS Method", + "description": "PKCSMethod is a config to enable PKCS (Proof Key for Code Exchange) using the generic provider. Can be either `S256` (sends code_challenge and code_challenge_method=S256) to authorization endpoint) and `code_verifier` to token endpoint. Can be `plain` if its impossible to support S256. (sends code verifier == code_challenge and code_challenge_method=plain to authorization endpoint). Can be empty, in which case PKCS is disabled.", + "type": "string", + "enum": ["S256", "plain"] } }, "additionalProperties": false, diff --git a/embedx/config.schema.json b/embedx/config.schema.json index a2802ed0a0b7..c8263ab586f0 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -567,6 +567,12 @@ "enum": ["id_token", "userinfo"], "default": "id_token", "examples": ["id_token", "userinfo"] + }, + "pkcs_method": { + "title": "PKCS Method", + "description": "PKCSMethod is a config to enable PKCS (Proof Key for Code Exchange) using the generic provider. Can be either `S256` (sends code_challenge and code_challenge_method=S256) to authorization endpoint) and `code_verifier` to token endpoint. Can be `plain` if its impossible to support S256. (sends code verifier == code_challenge and code_challenge_method=plain to authorization endpoint). Can be empty, in which case PKCS is disabled.", + "type": "string", + "enum": ["S256", "plain"] } }, "additionalProperties": false, diff --git a/selfservice/strategy/oidc/.snapshots/TestStrategy-method=TestPopulateSignUpMethod.json b/selfservice/strategy/oidc/.snapshots/TestStrategy-method=TestPopulateSignUpMethod.json index 6b0a9ba98cf8..ee8eda3fb22b 100644 --- a/selfservice/strategy/oidc/.snapshots/TestStrategy-method=TestPopulateSignUpMethod.json +++ b/selfservice/strategy/oidc/.snapshots/TestStrategy-method=TestPopulateSignUpMethod.json @@ -80,6 +80,28 @@ } } }, + { + "type": "input", + "group": "oidc", + "attributes": { + "name": "provider", + "type": "submit", + "value": "providerWithPKCS", + "disabled": false, + "node_type": "input" + }, + "messages": [], + "meta": { + "label": { + "id": 1040002, + "text": "Sign up with providerWithPKCS", + "type": "info", + "context": { + "provider": "providerWithPKCS" + } + } + } + }, { "type": "input", "group": "oidc", diff --git a/selfservice/strategy/oidc/provider_config.go b/selfservice/strategy/oidc/provider_config.go index f3db2e120e01..39c1397158a9 100644 --- a/selfservice/strategy/oidc/provider_config.go +++ b/selfservice/strategy/oidc/provider_config.go @@ -119,6 +119,13 @@ type Configuration struct { // endpoint to get the claims) or `id_token` (takes the claims from the id // token). It defaults to `id_token`. ClaimsSource string `json:"claims_source"` + + // PKCSMethod is a config to enable PKCS (Proof Key for Code Exchange) + // using the generic provider. Can be either `S256` (sends code_challenge and code_challenge_method=S256) + // to authorization endpoint) and `code_verifier` to token endpoint. + // Can be `plain` if its impossible to support S256. (sends code verifier == code_challenge and code_challenge_method=plain to authorization endpoint) + // Can be empty, in which case PKCS is disabled. + PKCSMethod string `json:"pkcs_method"` } func (p Configuration) Redir(public *url.URL) string { diff --git a/selfservice/strategy/oidc/provider_generic_oidc.go b/selfservice/strategy/oidc/provider_generic_oidc.go index 146505165807..5128b4e2c3a7 100644 --- a/selfservice/strategy/oidc/provider_generic_oidc.go +++ b/selfservice/strategy/oidc/provider_generic_oidc.go @@ -85,7 +85,9 @@ func (g *ProviderGenericOIDC) OAuth2(ctx context.Context) (*oauth2.Config, error func (g *ProviderGenericOIDC) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { var options []oauth2.AuthCodeOption - + if g.config.PKCSMethod != "" { + options = g.addPKCSURLOptions(r, options) + } if isForced(r) { options = append(options, oauth2.SetAuthURLParam("prompt", "login")) } @@ -96,6 +98,25 @@ func (g *ProviderGenericOIDC) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption return options } +func (g *ProviderGenericOIDC) addPKCSURLOptions(r ider, options []oauth2.AuthCodeOption) []oauth2.AuthCodeOption { + flow, err := g.reg.LoginFlowPersister().GetLoginFlow(context.Background(), r.GetID()) + if err != nil { + return options + } + pkcsContext, err := GetPKCSContext(flow) + if err != nil { + return options + } + if pkcsContext.Verifier != "" && pkcsContext.Method == "S256" { + options = append(options, oauth2.S256ChallengeOption(pkcsContext.Verifier)) + } + if pkcsContext.Verifier != "" && pkcsContext.Method == "plain" { + options = append(options, oauth2.SetAuthURLParam("code_challenge", string(pkcsContext.Verifier))) + options = append(options, oauth2.SetAuthURLParam("code_challenge_method", string(pkcsContext.Method))) + } + return options +} + func (g *ProviderGenericOIDC) verifyAndDecodeClaimsWithProvider(ctx context.Context, provider *gooidc.Provider, raw string) (*Claims, error) { token, err := provider.VerifierContext(g.withHTTPClientContext(ctx), &gooidc.Config{ClientID: g.config.ClientID}).Verify(ctx, raw) if err != nil { diff --git a/selfservice/strategy/oidc/provider_generic_test.go b/selfservice/strategy/oidc/provider_generic_test.go index 7c90da7e3ec5..2c6700097ab0 100644 --- a/selfservice/strategy/oidc/provider_generic_test.go +++ b/selfservice/strategy/oidc/provider_generic_test.go @@ -9,6 +9,8 @@ import ( "net/url" "testing" + "golang.org/x/oauth2" + "github.com/ory/kratos/driver" "github.com/ory/kratos/driver/config" "github.com/ory/kratos/internal" @@ -35,7 +37,13 @@ func makeOIDCClaims() json.RawMessage { return claims } -func makeAuthCodeURL(t *testing.T, r *login.Flow, reg *driver.RegistryDefault) string { +func makeAuthCodeURL(t *testing.T, r *login.Flow, reg *driver.RegistryDefault, pkcsMethods ...string) string { + var pkcsMethod string + if len(pkcsMethods) > 0 { + pkcsMethod = pkcsMethods[0] + } else { + pkcsMethod = "" + } p := oidc.NewProviderGenericOIDC(&oidc.Configuration{ Provider: "generic", ID: "valid", @@ -43,6 +51,7 @@ func makeAuthCodeURL(t *testing.T, r *login.Flow, reg *driver.RegistryDefault) s ClientSecret: "secret", IssuerURL: "https://accounts.google.com", Mapper: "file://./stub/hydra.schema.json", + PKCSMethod: pkcsMethod, RequestedClaims: makeOIDCClaims(), }, reg) c, err := p.(oidc.OAuth2Provider).OAuth2(context.Background()) @@ -94,3 +103,55 @@ func TestProviderGenericOIDC_AddAuthCodeURLOptions(t *testing.T) { assert.Contains(t, makeAuthCodeURL(t, r, reg), "claims="+url.QueryEscape(string(makeOIDCClaims()))) }) } + +func TestProviderGenericOIDC_PKCS(t *testing.T) { + ctx := context.Background() + conf, reg := internal.NewFastRegistryWithMocks(t) + conf.MustSet(ctx, config.ViperKeyPublicBaseURL, "https://ory.sh") + + t.Run("case=PKCSMethod is set to S256", func(t *testing.T) { + r := &login.Flow{ID: x.NewUUID(), Refresh: true} + reg.LoginFlowPersister().CreateLoginFlow(ctx, r) + err := oidc.SetPKCSContext(r, oidc.PkcsContext{ + Method: "S256", + Verifier: oauth2.GenerateVerifier(), + }) + require.NoError(t, err) + err = reg.LoginFlowPersister().UpdateLoginFlow(ctx, r) + require.NoError(t, err) + actual, err := url.ParseRequestURI(makeAuthCodeURL(t, r, reg, "S256")) + require.NoError(t, err) + assert.Contains(t, actual.Query(), "code_challenge") + t.Logf("code_challenge: %s", actual.Query().Get("code_challenge")) + assert.Contains(t, actual.Query().Get("code_challenge_method"), "S256") + t.Logf("code_challenge_method: %s", actual.Query().Get("code_challenge_method")) + }) + t.Run("case=PKCSMethod is set to plain", func(t *testing.T) { + r := &login.Flow{ID: x.NewUUID(), Refresh: true} + reg.LoginFlowPersister().CreateLoginFlow(ctx, r) + verifier := oauth2.GenerateVerifier() + err := oidc.SetPKCSContext(r, oidc.PkcsContext{ + Method: "plain", + Verifier: verifier, + }) + require.NoError(t, err) + err = reg.LoginFlowPersister().UpdateLoginFlow(ctx, r) + require.NoError(t, err) + actual, err := url.ParseRequestURI(makeAuthCodeURL(t, r, reg, "plain")) + require.NoError(t, err) + assert.Contains(t, actual.Query(), "code_challenge") + t.Logf("code_challenge: %s", actual.Query().Get("code_challenge")) + assert.Contains(t, actual.Query().Get("code_challenge_method"), "plain") + t.Logf("code_challenge_method: %s", actual.Query().Get("code_challenge_method")) + assert.Equal(t, actual.Query().Get("code_challenge"), verifier) + }) + t.Run("case=PKCSMethod is empty", func(t *testing.T) { + r := &login.Flow{ID: x.NewUUID(), Refresh: true} + actual, err := url.ParseRequestURI(makeAuthCodeURL(t, r, reg)) + require.NoError(t, err) + assert.NotContains(t, actual.Query(), "code_challenge") + t.Logf("code_challenge: %s", actual.Query().Get("code_challenge")) + assert.NotContains(t, actual.Query(), "code_challenge_method") + t.Logf("code_challenge_method: %s", actual.Query().Get("code_challenge_method")) + }) +} diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index 603ab5a5ea62..221323d9795a 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -16,10 +16,10 @@ import ( "strings" "time" - "github.com/ory/x/sqlxx" - "golang.org/x/exp/maps" + "github.com/ory/x/sqlxx" + "github.com/ory/x/urlx" "go.opentelemetry.io/otel/attribute" @@ -426,7 +426,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt var et *identity.CredentialsOIDCEncryptedTokens switch p := provider.(type) { case OAuth2Provider: - token, err := s.ExchangeCode(r.Context(), provider, code) + token, err := s.ExchangeCode(r.Context(), provider, code, req) if err != nil { s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err)) return @@ -510,7 +510,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt } } -func (s *Strategy) ExchangeCode(ctx context.Context, provider Provider, code string) (token *oauth2.Token, err error) { +func (s *Strategy) ExchangeCode(ctx context.Context, provider Provider, code string, flow flow.Flow) (token *oauth2.Token, err error) { ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "strategy.oidc.ExchangeCode") defer otelx.End(span, &err) span.SetAttributes(attribute.String("provider_id", provider.Config().ID)) @@ -525,11 +525,23 @@ func (s *Strategy) ExchangeCode(ctx context.Context, provider Provider, code str return nil, err } } - client := s.d.HTTPClient(ctx) ctx = context.WithValue(ctx, oauth2.HTTPClient, client.HTTPClient) - token, err = te.Exchange(ctx, code) - return token, err + switch loginFlow := flow.(type) { + case *login.Flow: + if provider.Config().PKCSMethod != "" { + pkcsContext, err := GetPKCSContext(loginFlow) + if err != nil { + return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to decode PKCS context: %s", err)) + } + if pkcsContext.Verifier != "" && (pkcsContext.Method == "S256" || pkcsContext.Method == "plain") { + return te.Exchange(ctx, code, oauth2.VerifierOption(pkcsContext.Verifier)) + } else { + return nil, errors.Errorf("Invalid PKCS method: %s or empty verifier: %s", pkcsContext.Method, pkcsContext.Verifier) + } + } + } + return te.Exchange(ctx, code) default: return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The chosen provider is not capable of exchanging an OAuth 2.0 code for an access token.")) } diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index d290276492a7..d5edb77d25d5 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -10,22 +10,22 @@ import ( "strings" "time" - "github.com/ory/kratos/selfservice/strategy/idfirst" - "github.com/ory/x/stringsx" - - "github.com/ory/kratos/selfservice/flowhelpers" - "github.com/julienschmidt/httprouter" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "golang.org/x/oauth2" "github.com/ory/kratos/session" + "github.com/ory/kratos/text" "github.com/ory/kratos/ui/node" "github.com/ory/x/otelx" "github.com/ory/x/sqlcon" + "github.com/ory/x/stringsx" "github.com/ory/kratos/selfservice/flow/registration" - - "github.com/ory/kratos/text" + "github.com/ory/kratos/selfservice/flowhelpers" + "github.com/ory/kratos/selfservice/strategy/idfirst" "github.com/ory/kratos/continuity" @@ -48,6 +48,13 @@ func (s *Strategy) RegisterLoginRoutes(r *x.RouterPublic) { s.setRoutes(r) } +const internalContextPKCSPath = "pkcs" + +type PkcsContext struct { + Method string `json:"method"` + Verifier string `json:"verifier"` +} + // Update Login Flow with OpenID Connect Method // // swagger:model updateLoginFlowWithOidcMethod @@ -255,6 +262,17 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } state := generateState(f.ID.String()) + + if provider.Config().PKCSMethod != "" { + err := SetPKCSContext(f, PkcsContext{ + Method: provider.Config().PKCSMethod, + Verifier: oauth2.GenerateVerifier(), + }) + if err != nil { + return nil, s.handleError(w, r, f, pid, nil, err) + } + } + if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, f.ID); hasCode { state.setCode(code.InitCode) } @@ -386,3 +404,34 @@ func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request func (s *Strategy) PopulateLoginMethodIdentifierFirstIdentification(r *http.Request, f *login.Flow) error { return s.populateMethod(r, f, text.NewInfoLoginWith) } + +func SetPKCSContext(flow flow.InternalContexter, context PkcsContext) error { + if flow.GetInternalContext() == nil { + flow.EnsureInternalContext() + } + bytes, err := sjson.SetBytes( + flow.GetInternalContext(), + internalContextPKCSPath, + context, + ) + if err != nil { + return err + } + flow.SetInternalContext(bytes) + + return nil +} + +func GetPKCSContext(flow flow.InternalContexter) (*PkcsContext, error) { + if flow.GetInternalContext() == nil { + flow.EnsureInternalContext() + } + raw := gjson.GetBytes(flow.GetInternalContext(), internalContextPKCSPath) + if !raw.IsObject() { + return nil, nil + } + var context PkcsContext + err := json.Unmarshal([]byte(raw.Raw), &context) + + return &context, err +} diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 25d45486391b..f8cb7ed37171 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -86,6 +86,9 @@ func TestStrategy(t *testing.T) { newOIDCProvider(t, ts, remotePublic, remoteAdmin, "claimsViaUserInfo", func(c *oidc.Configuration) { c.ClaimsSource = oidc.ClaimsSourceUserInfo }), + newOIDCProvider(t, ts, remotePublic, remoteAdmin, "providerWithPKCS", func(c *oidc.Configuration) { + c.PKCSMethod = "S256" + }), oidc.Configuration{ Provider: "generic", ID: "invalid-issuer", @@ -1072,7 +1075,6 @@ func TestStrategy(t *testing.T) { }) }) } - }) t.Run("case=should fail to register and return fresh login flow if email is already being used by password credentials", func(t *testing.T) {