Skip to content

Commit

Permalink
fix: SAML issues (#2041)
Browse files Browse the repository at this point in the history
Rename identities table columns for more clarity. Rename parameters,
arguments etc. to accommodate these changes.

Change that the SAML provider domain is persisted in the identities
table as the provider ID. Use the SAML Entity ID/Issuer ID of the
IdP instead.

Introduce saml identity entity (including migrations and a persister)
as a specialization of an identity to allow for determining the
correct provider name to return to the client/frontend and for assisting
in determining whether an identity is a SAML identity (i.e. SAML
identities should have a corresponding SAML Identity instance while
OAuth/OIDC entities do not).
# Conflicts:
#	backend/config/config_default.go
#	backend/handler/thirdparty_test.go
#	backend/test/fixtures/thirdparty/identities.yaml
#	backend/thirdparty/provider_facebook.go
  • Loading branch information
lfleischmann authored and FreddyDevelop committed Jan 31, 2025
1 parent 120b598 commit 58dc7ef
Show file tree
Hide file tree
Showing 33 changed files with 289 additions and 106 deletions.
12 changes: 6 additions & 6 deletions backend/config/config_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,32 +121,32 @@ func DefaultConfig() *Config {
Apple: ThirdPartyProvider{
DisplayName: "Apple",
AllowLinking: true,
Name: "apple",
ID: "apple",
},
Discord: ThirdPartyProvider{
DisplayName: "Discord",
AllowLinking: true,
Name: "discord",
ID: "discord",
},
LinkedIn: ThirdPartyProvider{
DisplayName: "LinkedIn",
AllowLinking: true,
Name: "linkedin",
ID: "linkedin",
},
Microsoft: ThirdPartyProvider{
DisplayName: "Microsoft",
AllowLinking: true,
Name: "microsoft",
ID: "microsoft",
},
GitHub: ThirdPartyProvider{
DisplayName: "GitHub",
AllowLinking: true,
Name: "github",
ID: "github",
},
Google: ThirdPartyProvider{
DisplayName: "Google",
AllowLinking: true,
Name: "google",
ID: "google",
},
},
},
Expand Down
14 changes: 7 additions & 7 deletions backend/config/config_third_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (t *ThirdParty) PostProcess() error {
for key, provider := range t.CustomProviders {
// add prefix per default to ensure built-in and custom providers can be distinguished
keyLower := strings.ToLower(key)
provider.Name = "custom_" + keyLower
provider.ID = "custom_" + keyLower
providers[keyLower] = provider
}
t.CustomProviders = providers
Expand Down Expand Up @@ -210,7 +210,7 @@ func (p *CustomThirdPartyProviders) Validate() error {
if err != nil {
return fmt.Errorf(
"failed to validate third party provider %s: %w",
strings.TrimPrefix(v.Name, "custom_"),
strings.TrimPrefix(v.ID, "custom_"),
err,
)
}
Expand Down Expand Up @@ -249,9 +249,9 @@ type CustomThirdPartyProvider struct {
//
// Required if `use_discovery` is false or omitted.
AuthorizationEndpoint string `yaml:"authorization_endpoint" json:"authorization_endpoint,omitempty" koanf:"authorization_endpoint"`
// `name` is a unique identifier for the provider, derived from the key in the `custom_providers` map, by
// `ID` is a unique identifier for the provider, derived from the key in the `custom_providers` map, by
// concatenating the prefix "custom_". This allows distinguishing between built-in and custom providers at runtime.
Name string `jsonschema:"-" yaml:"-" json:"-" koanf:"-"`
ID string `jsonschema:"-" yaml:"-" json:"-" koanf:"-"`
// `issuer` is the provider's issuer identifier. It should be a URL that uses the "https"
// scheme and has no query or fragment components.
//
Expand Down Expand Up @@ -443,9 +443,9 @@ type ThirdPartyProvider struct {
//
// Required if the provider is `enabled`.
Secret string `yaml:"secret" json:"secret,omitempty" koanf:"secret"`
// `name` is a unique name/slug/identifier for the provider. It is the lowercased key of the corresponding field
// in ThirdPartyProviders. See also: CustomThirdPartyProvider.Name.
Name string `jsonschema:"-" yaml:"-" json:"-" koanf:"-"`
// `ID` is a unique name/slug/identifier for the provider. It is the lowercased key of the corresponding field
// in ThirdPartyProviders. See also: CustomThirdPartyProvider.ID.
ID string `jsonschema:"-" yaml:"-" json:"-" koanf:"-"`
}

func (ThirdPartyProvider) JSONSchemaExtend(schema *jsonschema.Schema) {
Expand Down
4 changes: 2 additions & 2 deletions backend/dto/admin/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ type Identity struct {
func FromIdentityModel(model models.Identity) Identity {
return Identity{
ID: model.ID,
ProviderID: model.ProviderID,
ProviderName: model.ProviderName,
ProviderID: model.ProviderUserID,
ProviderName: model.ProviderID,
EmailID: model.EmailID,
CreatedAt: model.CreatedAt,
UpdatedAt: model.UpdatedAt,
Expand Down
16 changes: 11 additions & 5 deletions backend/dto/thirdparty.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,29 @@ func FromIdentityModel(identity *models.Identity, cfg *config.Config) *Identity
}

return &Identity{
ID: identity.ProviderID,
ID: identity.ProviderUserID,
Provider: getProviderDisplayName(identity, cfg),
}
}

func getProviderDisplayName(identity *models.Identity, cfg *config.Config) string {
if strings.HasPrefix(identity.ProviderName, "custom_") {
providerNameWithoutPrefix := strings.TrimPrefix(identity.ProviderName, "custom_")
if identity.SamlIdentity != nil {
for _, ip := range cfg.Saml.IdentityProviders {
if ip.Enabled && ip.Domain == identity.SamlIdentity.Domain {
return ip.Name
}
}
} else if strings.HasPrefix(identity.ProviderID, "custom_") {
providerNameWithoutPrefix := strings.TrimPrefix(identity.ProviderID, "custom_")
return cfg.ThirdParty.CustomProviders[providerNameWithoutPrefix].DisplayName
} else {
s := structs.New(config.ThirdPartyProviders{})
for _, field := range s.Fields() {
if strings.ToLower(field.Name()) == strings.ToLower(identity.ProviderName) {
if strings.ToLower(field.Name()) == strings.ToLower(identity.ProviderID) {
return field.Name()
}
}
}

return strings.TrimSpace(identity.ProviderName)
return strings.TrimSpace(identity.ProviderID)
}
17 changes: 14 additions & 3 deletions backend/ee/saml/config/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,21 @@ func (s *Saml) Validate() error {
return errors.New("at least one SAML provider is needed")
}

configuredDomains := make(map[string]int)
for _, provider := range s.IdentityProviders {
validationErrors = provider.Validate()
if validationErrors != nil {
return validationErrors
if provider.Enabled {
validationErrors = provider.Validate()
if validationErrors != nil {
return validationErrors
}

configuredDomains[provider.Domain] += 1
}
}

for configuredDomain, configuredDomainCount := range configuredDomains {
if configuredDomainCount > 1 {
return fmt.Errorf("provider domains must be unique, found domain %s configured %d times", configuredDomain, configuredDomainCount)
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions backend/ee/saml/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,17 @@ func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state *
var samlError error
samlError = handler.samlService.Persister().Transaction(func(tx *pop.Connection) error {
userdata := provider.GetUserData(assertionInfo)

linkResult, samlErrorTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, state.Provider, true, state.IsFlow)
identityProviderIssuer := assertionInfo.Assertions[0].Issuer
samlDomain := provider.GetDomain()
linkResult, samlErrorTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, identityProviderIssuer.Value, true, &samlDomain, state.IsFlow)
if samlErrorTx != nil {
return samlErrorTx
}

accountLinkingResult = linkResult

emailModel := linkResult.User.Emails.GetEmailByAddress(userdata.Metadata.Email)
identityModel := emailModel.Identities.GetIdentity(provider.GetDomain(), userdata.Metadata.Subject)
identityModel := emailModel.Identities.GetIdentity(identityProviderIssuer.Value, userdata.Metadata.Subject)

token, tokenError := models.NewToken(
linkResult.User.ID,
Expand Down
2 changes: 1 addition & 1 deletion backend/ee/saml/provider/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewBaseSamlProvider(cfg *config.Config, idpConfig samlConfig.IdentityProvid
IDPCertificateStore: &idpMetadata.certs,

AssertionConsumerServiceURL: fmt.Sprintf("%s/saml/callback", cfg.Saml.Endpoint),
ServiceProviderIssuer: fmt.Sprintf("%s/saml/metadata", cfg.Saml.Endpoint),
ServiceProviderIssuer: cfg.Saml.Endpoint,
ServiceProviderSLOURL: fmt.Sprintf("%s/saml/logout", cfg.Saml.Endpoint),
SPKeyStore: serviceProviderCertStore,

Expand Down
2 changes: 1 addition & 1 deletion backend/flow_api/flow/shared/action_exchange_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (a ExchangeToken) Execute(c flowpilot.ExecutionContext) error {
return fmt.Errorf("failed to set login_method to the stash: %w", err)
}

if err := c.Stash().Set(StashPathThirdPartyProvider, identity.ProviderName); err != nil {
if err := c.Stash().Set(StashPathThirdPartyProvider, identity.ProviderID); err != nil {
return fmt.Errorf("failed to set third_party_provider to the stash: %w", err)
}

Expand Down
4 changes: 2 additions & 2 deletions backend/flow_api/flow/shared/action_thirdparty_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ func (a ThirdPartyOAuth) Initialize(c flowpilot.InitializationContext) {
Required(true)

for _, provider := range enabledThirdPartyProviders {
providerInput.AllowedValue(provider.DisplayName, provider.Name)
providerInput.AllowedValue(provider.DisplayName, provider.ID)
}

slices.SortFunc(enabledCustomThirdPartyProviders, func(a, b config.CustomThirdPartyProvider) bool {
return a.DisplayName < b.DisplayName
})

for _, provider := range enabledCustomThirdPartyProviders {
providerInput.AllowedValue(provider.DisplayName, provider.Name)
providerInput.AllowedValue(provider.DisplayName, provider.ID)
}

c.AddInputs(flowpilot.StringInput("redirect_to").Hidden(true).Required(true), providerInput)
Expand Down
4 changes: 2 additions & 2 deletions backend/flow_api/services/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

func UserCanDoThirdParty(cfg config.Config, identities models.Identities) bool {
for _, identity := range identities {
if provider := cfg.ThirdParty.Providers.Get(identity.ProviderName); provider != nil {
if provider := cfg.ThirdParty.Providers.Get(identity.ProviderID); provider != nil {
return provider.Enabled
}
}
Expand All @@ -18,7 +18,7 @@ func UserCanDoThirdParty(cfg config.Config, identities models.Identities) bool {

func UserCanDoSaml(cfg config.Config, identities models.Identities) bool {
for _, identity := range identities {
if provider := cfg.Saml.GetProviderByDomain(identity.ProviderName); provider != nil {
if provider := cfg.Saml.GetProviderByDomain(identity.ProviderID); provider != nil {
return cfg.Saml.Enabled && provider.Enabled
}
}
Expand Down
6 changes: 3 additions & 3 deletions backend/handler/thirdparty.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (h *ThirdPartyHandler) Auth(c echo.Context) error {
return h.redirectError(c, thirdparty.ErrorInvalidRequest(err.Error()).WithCause(err), errorRedirectTo)
}

state, err := thirdparty.GenerateState(h.cfg, provider.Name(), request.RedirectTo)
state, err := thirdparty.GenerateState(h.cfg, provider.ID(), request.RedirectTo)
if err != nil {
return h.redirectError(c, thirdparty.ErrorServer("could not generate state").WithCause(err), errorRedirectTo)
}
Expand Down Expand Up @@ -143,14 +143,14 @@ func (h *ThirdPartyHandler) Callback(c echo.Context) error {
return thirdparty.ErrorInvalidRequest("could not retrieve user data from provider").WithCause(terr)
}

linkingResult, terr := thirdparty.LinkAccount(tx, h.cfg, h.persister, userData, provider.Name(), false, state.IsFlow)
linkingResult, terr := thirdparty.LinkAccount(tx, h.cfg, h.persister, userData, provider.ID(), false, nil, state.IsFlow)
if terr != nil {
return terr
}
accountLinkingResult = linkingResult

emailModel := linkingResult.User.Emails.GetEmailByAddress(userData.Metadata.Email)
identityModel := emailModel.Identities.GetIdentity(provider.Name(), userData.Metadata.Subject)
identityModel := emailModel.Identities.GetIdentity(provider.ID(), userData.Metadata.Subject)

token, terr := models.NewToken(
linkingResult.User.ID,
Expand Down
10 changes: 5 additions & 5 deletions backend/handler/thirdparty_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,35 +60,35 @@ func (s *thirdPartySuite) setUpConfig(enabledProviders []string, allowedRedirect
cfg.ThirdParty = config.ThirdParty{
Providers: config.ThirdPartyProviders{
Apple: config.ThirdPartyProvider{
Name: "apple",
ID: "apple",
Enabled: false,
ClientID: "fakeClientID",
Secret: "fakeClientSecret",
AllowLinking: true,
},
Google: config.ThirdPartyProvider{
Name: "google",
ID: "google",
Enabled: false,
ClientID: "fakeClientID",
Secret: "fakeClientSecret",
AllowLinking: true,
},
GitHub: config.ThirdPartyProvider{
Name: "github",
ID: "github",
Enabled: false,
ClientID: "fakeClientID",
Secret: "fakeClientSecret",
AllowLinking: true,
},
Discord: config.ThirdPartyProvider{
Name: "discord",
ID: "discord",
Enabled: false,
ClientID: "fakeClientID",
Secret: "fakeClientSecret",
AllowLinking: true,
},
Microsoft: config.ThirdPartyProvider{
Name: "microsoft",
ID: "microsoft",
Enabled: false,
ClientID: "fakeClientID",
Secret: "fakeClientSecret",
Expand Down
6 changes: 3 additions & 3 deletions backend/persistence/identity_persister.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

type IdentityPersister interface {
Get(userProviderID string, providerID string) (*models.Identity, error)
Get(providerUserID string, providerID string) (*models.Identity, error)
GetByID(identityID uuid.UUID) (*models.Identity, error)
Create(identity models.Identity) error
Update(identity models.Identity) error
Expand All @@ -32,9 +32,9 @@ func (p identityPersister) GetByID(identityID uuid.UUID) (*models.Identity, erro
return identity, nil
}

func (p identityPersister) Get(userProviderID string, providerID string) (*models.Identity, error) {
func (p identityPersister) Get(providerUserID string, providerID string) (*models.Identity, error) {
identity := &models.Identity{}
if err := p.db.EagerPreload().Where("provider_id = ? AND provider_name = ?", userProviderID, providerID).First(identity); err != nil {
if err := p.db.EagerPreload().Where("provider_user_id = ? AND provider_id = ?", providerUserID, providerID).First(identity); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
drop_index("identities", "identities_provider_user_id_provider_id_idx")

rename_column("identities", "provider_id", "provider_name")
rename_column("identities", "provider_user_id", "provider_id")

add_index("identities", ["provider_id", "provider_name"], {unique: true})
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
drop_index("identities", "identities_provider_id_provider_name_idx")

rename_column("identities", "provider_id", "provider_user_id")
rename_column("identities", "provider_name", "provider_id")

add_index("identities", ["provider_user_id", "provider_id"], {unique: true})
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
drop_table("saml_identities")
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
create_table("saml_identities") {
t.Column("id", "uuid", {primary: true})
t.Column("identity_id", "uuid", { "null": false })
t.Column("domain", "string", { "null": false })
t.Timestamps()
t.ForeignKey("identity_id", {"identities": ["id"]}, {"on_delete": "cascade", "on_update": "cascade"})
t.Index(["identity_id", "domain"], {"unique": true})
}
9 changes: 9 additions & 0 deletions backend/persistence/models/email.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ func (email *Email) IsPrimary() bool {
return false
}

func (email *Email) GetSamlIdentityForDomain(domain string) *SamlIdentity {
for _, identity := range email.Identities {
if identity.SamlIdentity != nil && identity.SamlIdentity.Domain == domain {
return identity.SamlIdentity
}
}
return nil
}

func (emails *Emails) GetVerified() Emails {
var list Emails
for _, email := range *emails {
Expand Down
Loading

0 comments on commit 58dc7ef

Please sign in to comment.