diff --git a/cmd/ssokenizer/main.go b/cmd/ssokenizer/main.go index 1fef780..73b3394 100644 --- a/cmd/ssokenizer/main.go +++ b/cmd/ssokenizer/main.go @@ -220,6 +220,10 @@ type IdentityProviderConfig struct { TokenURL string `yaml:"token_url"` SecretAuth SecretAuthConfig `yaml:"secret_auth"` + + // source_id parameter to pass to Vanta provider. Only needed for "vanta" profile + // TODO: figure out a way to pass a map[string]string of "extra" stuff to providers. + SourceID string `yaml:"source_id"` } func (ic *IdentityProviderConfig) provider(name string, c *Config) (ssokenizer.Provider, error) { @@ -273,13 +277,18 @@ func (ic *IdentityProviderConfig) provider(name string, c *Config) (ssokenizer.P switch ic.Profile { case "vanta": + if ic.SourceID == "" { + return nil, errors.New("missing source_id") + } + op.OAuthConfig.Endpoint = xoauth2.Endpoint{ AuthURL: "https://app.vanta.com/oauth/authorize", TokenURL: "https://api.vanta.com/oauth/token", AuthStyle: xoauth2.AuthStyleInParams, } - op.ForwardParams = []string{"source_id"} + op.AuthRequestParams = map[string]string{"source_id": ic.SourceID} + op.TokenRequestParams = map[string]string{"source_id": ic.SourceID} return &vanta.Provider{Provider: op}, nil case "oauth": diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index abd7d8d..4c70de4 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -24,6 +24,12 @@ type Provider struct { // ForwardParams are the parameters that should be forwarded from the start // request to the auth URL. ForwardParams []string + + // Params to add to the auth request. + AuthRequestParams map[string]string + + // Params to add to the token request. + TokenRequestParams map[string]string } var _ ssokenizer.Provider = (*Provider)(nil) @@ -86,6 +92,10 @@ func (p *Provider) handleStart(w http.ResponseWriter, r *http.Request) { } } + for key, value := range p.AuthRequestParams { + opts = append(opts, oauth2.SetAuthURLParam(key, value)) + } + if p.OAuthConfig.RedirectURL == "" { opts = append(opts, oauth2.SetAuthURLParam("redirect_uri", p.URL.JoinPath(callbackPath).String())) } @@ -128,7 +138,12 @@ func (p *Provider) handleCallback(w http.ResponseWriter, r *http.Request) { return } - tok, err := p.OAuthConfig.Exchange(r.Context(), code, oauth2.AccessTypeOffline) + opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline} + for key, value := range p.TokenRequestParams { + opts = append(opts, oauth2.SetAuthURLParam(key, value)) + } + + tok, err := p.OAuthConfig.Exchange(r.Context(), code, opts...) if err != nil { r = withError(r, fmt.Errorf("failed exchange: %w", err)) tr.ReturnError(w, r, "bad response")