diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index 2e536c8a35cf..8cf7b5069866 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -21,6 +21,7 @@ import ( "golang.org/x/oauth2" "github.com/ory/kratos/cipher" + "github.com/ory/kratos/selfservice/flowhelpers" "github.com/ory/kratos/selfservice/sessiontokenexchange" "github.com/ory/x/jsonnetsecure" "github.com/ory/x/otelx" @@ -484,15 +485,44 @@ func (s *Strategy) ExchangeCode(ctx context.Context, provider Provider, code str return token, err } -func (s *Strategy) populateMethod(r *http.Request, c *container.Container, message func(provider string) *text.Message) error { +func (s *Strategy) populateMethod(r *http.Request, f flow.Flow, message func(provider string) *text.Message) error { conf, err := s.Config(r.Context()) if err != nil { return err } + providers := conf.Providers + + if lf, ok := f.(*login.Flow); ok && lf.IsForced() { + if _, id, c := flowhelpers.GuessForcedLoginIdentifier(r, s.d, lf, s.ID()); id != nil { + if c == nil { + // no OIDC credentials, don't add any providers + providers = nil + } else { + var credentials identity.CredentialsOIDC + if err := json.Unmarshal(c.Config, &credentials); err != nil { + // failed to read OIDC credentials, don't add any providers + providers = nil + } else { + // add only providers that can actually be used to log in as this identity + providers = make([]Configuration, 0, len(conf.Providers)) + for i := range conf.Providers { + for j := range credentials.Providers { + if conf.Providers[i].ID == credentials.Providers[j].Provider { + providers = append(providers, conf.Providers[i]) + break + } + } + } + } + } + } + } + // does not need sorting because there is only one field + c := f.GetUI() c.SetCSRF(s.d.GenerateCSRFToken(r)) - AddProviders(c, conf.Providers, message) + AddProviders(c, providers, message) return nil } diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index c81537015bc9..a3048df1e15b 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -48,7 +48,7 @@ func (s *Strategy) PopulateLoginMethod(r *http.Request, requestedAAL identity.Au return nil } - return s.populateMethod(r, l.UI, text.NewInfoLoginWith) + return s.populateMethod(r, l, text.NewInfoLoginWith) } // Update Login Flow with OpenID Connect Method diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index d07e6c6fa8ed..d3f3b217f760 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -62,7 +62,7 @@ func (s *Strategy) RegisterRegistrationRoutes(r *x.RouterPublic) { } func (s *Strategy) PopulateRegistrationMethod(r *http.Request, f *registration.Flow) error { - return s.populateMethod(r, f.UI, text.NewInfoRegistrationWith) + return s.populateMethod(r, f, text.NewInfoRegistrationWith) } // Update Registration Flow with OpenID Connect Method diff --git a/selfservice/strategy/oidc/strategy_settings_test.go b/selfservice/strategy/oidc/strategy_settings_test.go index 69f2bc03a560..a6b819c94202 100644 --- a/selfservice/strategy/oidc/strategy_settings_test.go +++ b/selfservice/strategy/oidc/strategy_settings_test.go @@ -327,7 +327,17 @@ func TestSettingsStrategy(t *testing.T) { _, res, req := unlink(t, agent, provider) assert.Contains(t, res.Request.URL.String(), uiTS.URL+"/login") - rs, _, err := testhelpers.NewSDKCustomClient(publicTS, agents[agent]).FrontendApi.GetSettingsFlow(context.Background()).Id(req.Id).Execute() + fa := testhelpers.NewSDKCustomClient(publicTS, agents[agent]).FrontendApi + lf, _, err := fa.GetLoginFlow(context.Background()).Id(res.Request.URL.Query()["flow"][0]).Execute() + require.NoError(t, err) + + for _, node := range lf.Ui.Nodes { + if node.Group == "oidc" && node.Attributes.UiNodeInputAttributes.Name == "provider" { + assert.Contains(t, []string{"ory", "github"}, node.Attributes.UiNodeInputAttributes.Value) + } + } + + rs, _, err := fa.GetSettingsFlow(context.Background()).Id(req.Id).Execute() require.NoError(t, err) require.EqualValues(t, flow.StateShowForm, rs.State) @@ -554,7 +564,17 @@ func TestSettingsStrategy(t *testing.T) { _, res, req := link(t, agent, provider) assert.Contains(t, res.Request.URL.String(), uiTS.URL+"/login") - rs, _, err := testhelpers.NewSDKCustomClient(publicTS, agents[agent]).FrontendApi.GetSettingsFlow(context.Background()).Id(req.Id).Execute() + fa := testhelpers.NewSDKCustomClient(publicTS, agents[agent]).FrontendApi + lf, _, err := fa.GetLoginFlow(context.Background()).Id(res.Request.URL.Query()["flow"][0]).Execute() + require.NoError(t, err) + + for _, node := range lf.Ui.Nodes { + if node.Group == "oidc" && node.Attributes.UiNodeInputAttributes.Name == "provider" { + assert.Contains(t, []string{"ory", "github"}, node.Attributes.UiNodeInputAttributes.Value) + } + } + + rs, _, err := fa.GetSettingsFlow(context.Background()).Id(req.Id).Execute() require.NoError(t, err) require.EqualValues(t, flow.StateShowForm, rs.State) diff --git a/test/e2e/cypress/integration/profiles/oidc/settings/success.spec.ts b/test/e2e/cypress/integration/profiles/oidc/settings/success.spec.ts index dadbc06d9453..674e24e0d668 100644 --- a/test/e2e/cypress/integration/profiles/oidc/settings/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/oidc/settings/success.spec.ts @@ -196,6 +196,19 @@ context("Social Sign In Settings Success", () => { hydraReauthFails() }) + it("should show only linked providers during reauth", () => { + cy.shortPrivilegedSessionTime() + + cy.get('input[name="password"]').type(gen.password()) + cy.get('[value="password"]').click() + + cy.location("pathname").should("equal", "/login") + + cy.get('[value="hydra"]').should("exist") + cy.get('[value="google"]').should("not.exist") + cy.get('[value="github"]').should("not.exist") + }) + it("settings screen stays intact when the original sign up method gets removed", () => { const expectSettingsOk = () => { cy.get('[value="google"]', { timeout: 1000 })