From ca2d7c18ca876a0f2ac6fb900da1b047c6e7dbe7 Mon Sep 17 00:00:00 2001 From: Pedro Parra Ortega Date: Mon, 16 Dec 2024 11:50:11 +0100 Subject: [PATCH 1/2] feat: Support for tenantID in azuread provider Signed-off-by: Pedro Parra Ortega fix: revert non desired changes Signed-off-by: Pedro Parra Ortega --- gothic/gothic.go | 2 +- providers/azuread/azuread.go | 65 +++++++++++++++++++++---------- providers/azuread/azuread_test.go | 2 +- 3 files changed, 47 insertions(+), 22 deletions(-) diff --git a/gothic/gothic.go b/gothic/gothic.go index 0c32caff1..a36ec582b 100644 --- a/gothic/gothic.go +++ b/gothic/gothic.go @@ -15,6 +15,7 @@ import ( "encoding/base64" "errors" "fmt" + "github.com/go-chi/chi/v5" "io" "io/ioutil" "net/http" @@ -22,7 +23,6 @@ import ( "os" "strings" - "github.com/go-chi/chi/v5" "github.com/gorilla/mux" "github.com/gorilla/sessions" "github.com/markbates/goth" diff --git a/providers/azuread/azuread.go b/providers/azuread/azuread.go index 8717ddf37..12111423b 100644 --- a/providers/azuread/azuread.go +++ b/providers/azuread/azuread.go @@ -16,16 +16,17 @@ import ( ) const ( - authURL string = "https://login.microsoftonline.com/common/oauth2/authorize" - tokenURL string = "https://login.microsoftonline.com/common/oauth2/token" + authURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/authorize" + tokenURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/token" endpointProfile string = "https://graph.windows.net/me?api-version=1.6" graphAPIResource string = "https://graph.windows.net/" + commonTenant string = "common" ) // New creates a new AzureAD provider, and sets up important connection details. // You should always call `AzureAD.New` to get a new Provider. Never try to create // one manually. -func New(clientKey, secret, callbackURL string, resources []string, scopes ...string) *Provider { +func New(clientKey, secret, callbackURL string, opts ProviderOpts) *Provider { p := &Provider{ ClientKey: clientKey, Secret: secret, @@ -33,24 +34,32 @@ func New(clientKey, secret, callbackURL string, resources []string, scopes ...st providerName: "azuread", } - p.resources = make([]string, 0, 1+len(resources)) + p.resources = make([]string, 0, 1+len(opts.Resources)) p.resources = append(p.resources, graphAPIResource) - p.resources = append(p.resources, resources...) + p.resources = append(p.resources, opts.Resources...) - p.config = newConfig(p, scopes) + p.config = newConfig(p, opts) return p } // Provider is the implementation of `goth.Provider` for accessing AzureAD. -type Provider struct { - ClientKey string - Secret string - CallbackURL string - HTTPClient *http.Client - config *oauth2.Config - providerName string - resources []string -} +type ( + Provider struct { + ClientKey string + Secret string + CallbackURL string + HTTPClient *http.Client + config *oauth2.Config + providerName string + resources []string + } + + ProviderOpts struct { + Resources []string + Scopes []string + TenantID string + } +) // Name is the name used to retrieve this provider later. func (p *Provider) Name() string { @@ -132,20 +141,20 @@ func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { return newToken, err } -func newConfig(provider *Provider, scopes []string) *oauth2.Config { +func newConfig(provider *Provider, opts ProviderOpts) *oauth2.Config { c := &oauth2.Config{ ClientID: provider.ClientKey, ClientSecret: provider.Secret, RedirectURL: provider.CallbackURL, Endpoint: oauth2.Endpoint{ - AuthURL: authURL, - TokenURL: tokenURL, + AuthURL: authURL(opts.TenantID), + TokenURL: tokenURL(opts.TenantID), }, Scopes: []string{}, } - if len(scopes) > 0 { - for _, scope := range scopes { + if len(opts.Scopes) > 0 { + for _, scope := range opts.Scopes { c.Scopes = append(c.Scopes, scope) } } else { @@ -185,3 +194,19 @@ func userFromReader(r io.Reader, user *goth.User) error { func authorizationHeader(session *Session) (string, string) { return "Authorization", fmt.Sprintf("Bearer %s", session.AccessToken) } + +func authURL(tenantID string) string { + if tenantID != "" { + return fmt.Sprintf(authURLTemplate, tenantID) + } else { + return fmt.Sprintf(authURLTemplate, commonTenant) + } +} + +func tokenURL(tenantID string) string { + if tenantID != "" { + return fmt.Sprintf(tokenURLTemplate, tenantID) + } else { + return fmt.Sprintf(tokenURLTemplate, commonTenant) + } +} diff --git a/providers/azuread/azuread_test.go b/providers/azuread/azuread_test.go index 5608a4756..bf148af4c 100644 --- a/providers/azuread/azuread_test.go +++ b/providers/azuread/azuread_test.go @@ -51,5 +51,5 @@ func Test_SessionFromJSON(t *testing.T) { } func azureadProvider() *azuread.Provider { - return azuread.New(os.Getenv("AZUREAD_KEY"), os.Getenv("AZUREAD_SECRET"), "/foo", nil) + return azuread.New(os.Getenv("AZUREAD_KEY"), os.Getenv("AZUREAD_SECRET"), "/foo", azuread.ProviderOpts{}) } From f561850a9a49be772ad38b669f986629409170d7 Mon Sep 17 00:00:00 2001 From: Pedro Parra Ortega Date: Mon, 16 Dec 2024 22:06:19 +0100 Subject: [PATCH 2/2] fix: correct example Signed-off-by: Pedro Parra Ortega --- examples/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/main.go b/examples/main.go index 8bb9ddf97..1977f189b 100644 --- a/examples/main.go +++ b/examples/main.go @@ -107,7 +107,7 @@ func main() { amazon.New(os.Getenv("AMAZON_KEY"), os.Getenv("AMAZON_SECRET"), "http://localhost:3000/auth/amazon/callback"), yammer.New(os.Getenv("YAMMER_KEY"), os.Getenv("YAMMER_SECRET"), "http://localhost:3000/auth/yammer/callback"), onedrive.New(os.Getenv("ONEDRIVE_KEY"), os.Getenv("ONEDRIVE_SECRET"), "http://localhost:3000/auth/onedrive/callback"), - azuread.New(os.Getenv("AZUREAD_KEY"), os.Getenv("AZUREAD_SECRET"), "http://localhost:3000/auth/azuread/callback", nil), + azuread.New(os.Getenv("AZUREAD_KEY"), os.Getenv("AZUREAD_SECRET"), "http://localhost:3000/auth/azuread/callback", azuread.ProviderOpts{}), microsoftonline.New(os.Getenv("MICROSOFTONLINE_KEY"), os.Getenv("MICROSOFTONLINE_SECRET"), "http://localhost:3000/auth/microsoftonline/callback"), battlenet.New(os.Getenv("BATTLENET_KEY"), os.Getenv("BATTLENET_SECRET"), "http://localhost:3000/auth/battlenet/callback"), eveonline.New(os.Getenv("EVEONLINE_KEY"), os.Getenv("EVEONLINE_SECRET"), "http://localhost:3000/auth/eveonline/callback"),