diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index 5ed061119e7b..9a883d284924 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -43,6 +43,12 @@ var jsonnetCache, _ = ristretto.NewCache(&ristretto.Config[[]byte, []byte]{ type MetadataType string +type OIDCProviderData struct { + Provider string `json:"provider"` + Tokens *identity.CredentialsOIDCEncryptedTokens `json:"tokens"` + Claims Claims `json:"claims"` +} + type VerifiedAddress struct { Value string `json:"value"` Via identity.VerifiableAddressType `json:"via"` @@ -53,6 +59,8 @@ const ( PublicMetadata MetadataType = "identity.metadata_public" AdminMetadata MetadataType = "identity.metadata_admin" + + InternalContextKeyProviderData = "provider_data" ) func (s *Strategy) RegisterRegistrationRoutes(r *x.RouterPublic) { @@ -216,6 +224,27 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat return errors.WithStack(flow.ErrCompletedByStrategy) } + providerDataKey := flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData) + if oidcProviderData := gjson.GetBytes(f.InternalContext, providerDataKey); oidcProviderData.IsObject() { + var providerData OIDCProviderData + if err = json.Unmarshal([]byte(oidcProviderData.Raw), &providerData); err != nil { + return s.handleError(ctx, w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to be an object but got: %s", err))) + } + if pid != providerData.Provider { + return s.handleError(ctx, w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to have matching provider but got: %s", providerData.Provider))) + } + container := &AuthCodeContainer{ + FlowID: f.ID.String(), + Traits: p.Traits, + TransientPayload: f.TransientPayload, + } + _, err = s.processRegistration(ctx, w, r, f, providerData.Tokens, &providerData.Claims, provider, container) + if err != nil { + return s.handleError(ctx, w, r, f, pid, container.Traits, err) + } + return errors.WithStack(flow.ErrCompletedByStrategy) + } + state, pkce, err := s.GenerateState(ctx, provider, f.ID) if err != nil { return s.handleError(ctx, w, r, f, pid, nil, err) @@ -313,6 +342,13 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite return nil, nil } + providerDataKey := flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData) + if hasOIDCProviderData := gjson.GetBytes(rf.InternalContext, providerDataKey).IsObject(); !hasOIDCProviderData { + if internalContext, err := sjson.SetBytes(rf.InternalContext, providerDataKey, &OIDCProviderData{Provider: provider.Config().ID, Tokens: token, Claims: *claims}); err == nil { + rf.InternalContext = internalContext + } + } + fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(ctx)), fetcher.WithCache(jsonnetCache, 60*time.Minute)) jsonnetMapperSnippet, err := fetch.FetchContext(ctx, provider.Config().Mapper) if err != nil { @@ -351,6 +387,10 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite return nil, s.handleError(ctx, w, r, rf, provider.Config().ID, i.Traits, err) } + if internalContext, err := sjson.DeleteBytes(rf.InternalContext, providerDataKey); err == nil { + rf.InternalContext = internalContext + } + return nil, nil } diff --git a/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts b/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts index 132845623f31..ee8a63b1bb68 100644 --- a/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts @@ -94,7 +94,56 @@ context("Social Sign Up Successes", () => { cy.triggerOidc(app) cy.location("pathname").should((loc) => { - expect(loc).to.be.oneOf(["/welcome", "/", "/sessions"]) + expect(loc).to.be.oneOf([ + "/welcome", + "/", + "/sessions", + "/verification", + ]) + }) + + cy.getSession().should((session) => { + shouldSession(email)(session) + expect(session.identity.traits.consent).to.equal(true) + }) + }) + + it("should redirect to oidc provider only once", () => { + const email = gen.email() + + cy.registerOidc({ + app, + email, + expectSession: false, + route: registration, + }) + + cy.get(appPrefix(app) + '[name="traits.email"]').should( + "have.value", + email, + ) + + cy.get('[name="traits.consent"][type="checkbox"]') + .siblings("label") + .click() + cy.get('[name="traits.newsletter"][type="checkbox"]') + .siblings("label") + .click() + cy.get('[name="traits.website"]').type(website) + + cy.intercept("GET", "http://*/oauth2/auth*").as("additionalRedirect") + + cy.triggerOidc(app) + + cy.get("@additionalRedirect").should("not.exist") + + cy.location("pathname").should((loc) => { + expect(loc).to.be.oneOf([ + "/welcome", + "/", + "/sessions", + "/verification", + ]) }) cy.getSession().should((session) => { diff --git a/test/e2e/cypress/integration/profiles/oidc/settings/success.spec.ts b/test/e2e/cypress/integration/profiles/oidc/settings/success.spec.ts index 8eaec262b303..5ff4d257532d 100644 --- a/test/e2e/cypress/integration/profiles/oidc/settings/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/oidc/settings/success.spec.ts @@ -40,12 +40,26 @@ context("Social Sign In Settings Success", () => { cy.get("#accept").click() cy.get('input[name="traits.website"]').clear().type(website) + + cy.intercept({ + method: "POST", + url: "http://localhost:4433/self-service/registration*", + query: { flow: "*" }, + }).as("registrationCall") cy.triggerOidc(app, "hydra") - cy.get('[data-testid="ui/message/1010016"]').should( - "contain.text", - "as another way to sign in.", - ) + if (app === "react") { + cy.wait("@registrationCall").should((intercept) => { + expect(intercept.response.body.ui.messages[0].text).contain( + "as another way to sign in.", + ) + }) + } else { + cy.get('[data-testid="ui/message/1010016"]').should( + "contain.text", + "as another way to sign in.", + ) + } cy.noSession() }