diff --git a/.gitignore b/.gitignore index 7f352b8a4c..57f3044462 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ c.out _obj _test .idea/ +.vscode/ # Architecture specific extensions/prefixes *.[568vq] diff --git a/CHANGELOG.md b/CHANGELOG.md index 452d644b2a..6db90df988 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ## Release Highlights +- [#1361](https://github.com/oauth2-proxy/oauth2-proxy/pull/1541) PKCE Code Challenge Support - RFC-7636 (@braunsonm) + - At this time the `--code-challenge-method` flag can be used to enable it with the method of your choice. +- Parital support for OAuth2 Authorization Server Metadata for detecting code challenge methods (@braunsonm) + - A warning will be displayed when your provider advertises support for PKCE but you have not enabled it. + ## Important Notes ## Breaking Changes @@ -23,6 +28,7 @@ - [#1474](https://github.com/oauth2-proxy/oauth2-proxy/pull/1474) Support configuration of minimal acceptable TLS version (@polarctos) - [#1545](https://github.com/oauth2-proxy/oauth2-proxy/pull/1545) Fix issue with query string allowed group panic on skip methods (@andytson) - [#1286](https://github.com/oauth2-proxy/oauth2-proxy/pull/1286) Add the `allowed_email_domains` and the `allowed_groups` on the `auth_request` + support standard wildcard char for validation with sub-domain and email-domain. (@w3st3ry @armandpicard) +- [#1361](https://github.com/oauth2-proxy/oauth2-proxy/pull/1541) PKCE Code Challenge Support - RFC-7636 (@braunsonm) # V7.2.1 diff --git a/contrib/local-environment/docker-compose.yaml b/contrib/local-environment/docker-compose.yaml index be80d23030..6ec367207d 100644 --- a/contrib/local-environment/docker-compose.yaml +++ b/contrib/local-environment/docker-compose.yaml @@ -29,8 +29,8 @@ services: - httpbin dex: container_name: dex - image: quay.io/dexidp/dex:v2.23.0 - command: serve /dex.yaml + image: ghcr.io/dexidp/dex:v2.30.3 + command: dex serve /dex.yaml ports: - 4190:4190/tcp hostname: dex @@ -47,6 +47,8 @@ services: httpbin: container_name: httpbin image: kennethreitz/httpbin + ports: + - 8080:80/tcp networks: httpbin: {} etcd: diff --git a/docs/docs/configuration/alpha_config.md b/docs/docs/configuration/alpha_config.md index 53097c29ee..cc9f5ae43f 100644 --- a/docs/docs/configuration/alpha_config.md +++ b/docs/docs/configuration/alpha_config.md @@ -419,6 +419,7 @@ Provider holds all configuration for a single provider | `validateURL` | _string_ | ValidateURL is the access token validation endpoint | | `scope` | _string_ | Scope is the OAuth scope specification | | `allowedGroups` | _[]string_ | AllowedGroups is a list of restrict logins to members of this group | +| `force_code_challenge_method` | _string_ | The forced code challenge method | ### ProviderType #### (`string` alias) diff --git a/docs/docs/configuration/overview.md b/docs/docs/configuration/overview.md index 35ed4be3d9..0cd01e5859 100644 --- a/docs/docs/configuration/overview.md +++ b/docs/docs/configuration/overview.md @@ -84,6 +84,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/ | `--client-id` | string | the OAuth Client ID, e.g. `"123456.apps.googleusercontent.com"` | | | `--client-secret` | string | the OAuth Client Secret | | | `--client-secret-file` | string | the file with OAuth Client Secret | | +| `--code-challenge-method` | string | use PKCE code challenges with the specified method. Either 'plain' or 'S256' (recommended) | | | `--config` | string | path to config file | | | `--cookie-domain` | string \| list | Optional cookie domains to force cookies to (e.g. `.yourcompany.com`). The longest domain matching the request's host will be used (or the shortest cookie domain if there is no match). | | | `--cookie-expire` | duration | expire timeframe for cookie | 168h0m0s | diff --git a/oauthproxy.go b/oauthproxy.go index 3287dadb96..308f806b44 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -2,6 +2,7 @@ package main import ( "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -25,6 +26,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/redirect" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" @@ -680,7 +682,29 @@ func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, ove extraParams := p.provider.Data().LoginURLParams(overrides) prepareNoCache(rw) - csrf, err := cookies.NewCSRF(p.CookieOptions) + var codeChallenge, codeVerifier, codeChallengeMethod string + if p.provider.Data().CodeChallengeMethod != "" { + codeChallengeMethod = p.provider.Data().CodeChallengeMethod + preEncodedCodeVerifier, err := encryption.Nonce(96) + if err != nil { + logger.Errorf("Unable to build random string: %v", err) + p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) + return + } + codeVerifier = base64.RawURLEncoding.EncodeToString(preEncodedCodeVerifier) + + codeChallenge, err = encryption.GenerateCodeChallenge(p.provider.Data().CodeChallengeMethod, codeVerifier) + if err != nil { + logger.Errorf("Error creating code challenge: %v", err) + p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) + return + } + + extraParams.Add("code_challenge", codeChallenge) + extraParams.Add("code_challenge_method", codeChallengeMethod) + } + + csrf, err := cookies.NewCSRF(p.CookieOptions, codeVerifier) if err != nil { logger.Errorf("Error creating CSRF nonce: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) @@ -732,24 +756,24 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } - session, err := p.redeemCode(req) + csrf, err := cookies.LoadCSRFCookie(req, p.CookieOptions) if err != nil { - logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) - p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) + logger.Println(req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie") + p.ErrorPage(rw, req, http.StatusForbidden, err.Error(), "Login Failed: Unable to find a valid CSRF token. Please try again.") return } - err = p.enrichSessionState(req.Context(), session) + session, err := p.redeemCode(req, csrf.GetCodeVerifier()) if err != nil { - logger.Errorf("Error creating session during OAuth2 callback: %v", err) + logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } - csrf, err := cookies.LoadCSRFCookie(req, p.CookieOptions) + err = p.enrichSessionState(req.Context(), session) if err != nil { - logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie") - p.ErrorPage(rw, req, http.StatusForbidden, err.Error(), "Login Failed: Unable to find a valid CSRF token. Please try again.") + logger.Errorf("Error creating session during OAuth2 callback: %v", err) + p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } @@ -799,14 +823,14 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { } } -func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) { +func (p *OAuthProxy) redeemCode(req *http.Request, codeVerifier string) (*sessionsapi.SessionState, error) { code := req.Form.Get("code") if code == "" { return nil, providers.ErrMissingCode } redirectURI := p.getOAuthRedirectURI(req) - s, err := p.provider.Redeem(req.Context(), redirectURI, code) + s, err := p.provider.Redeem(req.Context(), redirectURI, code, codeVerifier) if err != nil { return nil, err } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 7c72069a39..90b27d592f 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -115,7 +115,7 @@ func Test_redeemCode(t *testing.T) { } req := httptest.NewRequest(http.MethodGet, "/", nil) - _, err = proxy.redeemCode(req) + _, err = proxy.redeemCode(req, "") assert.Equal(t, providers.ErrMissingCode, err) } @@ -405,7 +405,7 @@ func (patTest *PassAccessTokenTest) Close() { func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie string) { rw := httptest.NewRecorder() - csrf, err := cookies.NewCSRF(patTest.proxy.CookieOptions) + csrf, err := cookies.NewCSRF(patTest.proxy.CookieOptions, "") if err != nil { panic(err) } diff --git a/pkg/apis/options/legacy_options.go b/pkg/apis/options/legacy_options.go index 3233c23399..89e3230a8d 100644 --- a/pkg/apis/options/legacy_options.go +++ b/pkg/apis/options/legacy_options.go @@ -519,6 +519,8 @@ type LegacyProvider struct { JWTKey string `flag:"jwt-key" cfg:"jwt_key"` JWTKeyFile string `flag:"jwt-key-file" cfg:"jwt_key_file"` PubJWKURL string `flag:"pubjwk-url" cfg:"pubjwk_url"` + // PKCE Code Challenge method to use (either S256 or plain) + CodeChallengeMethod string `flag:"code-challenge-method" cfg:"force_code_challenge_method"` } func legacyProviderFlagSet() *pflag.FlagSet { @@ -563,6 +565,7 @@ func legacyProviderFlagSet() *pflag.FlagSet { flagSet.String("scope", "", "OAuth scope specification") flagSet.String("prompt", "", "OIDC prompt") flagSet.String("approval-prompt", "force", "OAuth approval_prompt") + flagSet.String("code-challenge-method", "", "use PKCE code challenges with the specified method. Either 'plain' or 'S256'") flagSet.String("acr-values", "", "acr values string: optional") flagSet.String("jwt-key", "", "private key in PEM format used to sign JWT, so that you can say something like -jwt-key=\"${OAUTH2_PROXY_JWT_KEY}\": required by login.gov") @@ -621,18 +624,19 @@ func (l *LegacyProvider) convert() (Providers, error) { providers := Providers{} provider := Provider{ - ClientID: l.ClientID, - ClientSecret: l.ClientSecret, - ClientSecretFile: l.ClientSecretFile, - Type: ProviderType(l.ProviderType), - CAFiles: l.ProviderCAFiles, - LoginURL: l.LoginURL, - RedeemURL: l.RedeemURL, - ProfileURL: l.ProfileURL, - ProtectedResource: l.ProtectedResource, - ValidateURL: l.ValidateURL, - Scope: l.Scope, - AllowedGroups: l.AllowedGroups, + ClientID: l.ClientID, + ClientSecret: l.ClientSecret, + ClientSecretFile: l.ClientSecretFile, + Type: ProviderType(l.ProviderType), + CAFiles: l.ProviderCAFiles, + LoginURL: l.LoginURL, + RedeemURL: l.RedeemURL, + ProfileURL: l.ProfileURL, + ProtectedResource: l.ProtectedResource, + ValidateURL: l.ValidateURL, + Scope: l.Scope, + AllowedGroups: l.AllowedGroups, + CodeChallengeMethod: l.CodeChallengeMethod, } // This part is out of the switch section for all providers that support OIDC diff --git a/pkg/apis/options/providers.go b/pkg/apis/options/providers.go index 5eae85165d..775ce61827 100644 --- a/pkg/apis/options/providers.go +++ b/pkg/apis/options/providers.go @@ -76,6 +76,8 @@ type Provider struct { Scope string `json:"scope,omitempty"` // AllowedGroups is a list of restrict logins to members of this group AllowedGroups []string `json:"allowedGroups,omitempty"` + // The forced code challenge method + CodeChallengeMethod string `json:"force_code_challenge_method,omitempty"` } // ProviderType is used to enumerate the different provider type options diff --git a/pkg/cookies/csrf.go b/pkg/cookies/csrf.go index 0af74173e7..d7c0d9a22b 100644 --- a/pkg/cookies/csrf.go +++ b/pkg/cookies/csrf.go @@ -20,6 +20,7 @@ type CSRF interface { HashOIDCNonce() string CheckOAuthState(string) bool CheckOIDCNonce(string) bool + GetCodeVerifier() string SetSessionNonce(s *sessions.SessionState) @@ -38,24 +39,30 @@ type csrf struct { // is used to mitigate replay attacks. OIDCNonce []byte `msgpack:"n,omitempty"` + // CodeVerifier holds the unobfuscated PKCE code verification string + // which is used to compare the code challenge when exchanging the + // authentication code. + CodeVerifier string `msgpack:"cv,omitempty"` + cookieOpts *options.Cookie time clock.Clock } // NewCSRF creates a CSRF with random nonces -func NewCSRF(opts *options.Cookie) (CSRF, error) { - state, err := encryption.Nonce() +func NewCSRF(opts *options.Cookie, codeVerifier string) (CSRF, error) { + state, err := encryption.Nonce(32) if err != nil { return nil, err } - nonce, err := encryption.Nonce() + nonce, err := encryption.Nonce(32) if err != nil { return nil, err } return &csrf{ - OAuthState: state, - OIDCNonce: nonce, + OAuthState: state, + OIDCNonce: nonce, + CodeVerifier: codeVerifier, cookieOpts: opts, }, nil @@ -71,6 +78,10 @@ func LoadCSRFCookie(req *http.Request, opts *options.Cookie) (CSRF, error) { return decodeCSRFCookie(cookie, opts) } +func (c *csrf) GetCodeVerifier() string { + return c.CodeVerifier +} + // HashOAuthState returns the hash of the OAuth state nonce func (c *csrf) HashOAuthState() string { return encryption.HashNonce(c.OAuthState) diff --git a/pkg/cookies/csrf_test.go b/pkg/cookies/csrf_test.go index 85f3e750ec..69fbfbb394 100644 --- a/pkg/cookies/csrf_test.go +++ b/pkg/cookies/csrf_test.go @@ -33,7 +33,7 @@ var _ = Describe("CSRF Cookie Tests", func() { } var err error - publicCSRF, err = NewCSRF(cookieOpts) + publicCSRF, err = NewCSRF(cookieOpts, "verifier") Expect(err).ToNot(HaveOccurred()) privateCSRF = publicCSRF.(*csrf) @@ -44,14 +44,16 @@ var _ = Describe("CSRF Cookie Tests", func() { Expect(privateCSRF.OAuthState).ToNot(BeEmpty()) Expect(privateCSRF.OIDCNonce).ToNot(BeEmpty()) Expect(privateCSRF.OAuthState).ToNot(Equal(privateCSRF.OIDCNonce)) + Expect(privateCSRF.CodeVerifier).To(Equal("verifier")) }) It("makes unique nonces between multiple CSRFs", func() { - other, err := NewCSRF(cookieOpts) + other, err := NewCSRF(cookieOpts, "verifier") Expect(err).ToNot(HaveOccurred()) Expect(privateCSRF.OAuthState).ToNot(Equal(other.(*csrf).OAuthState)) Expect(privateCSRF.OIDCNonce).ToNot(Equal(other.(*csrf).OIDCNonce)) + Expect(privateCSRF.CodeVerifier).To(Equal("verifier")) }) }) @@ -72,6 +74,7 @@ var _ = Describe("CSRF Cookie Tests", func() { Expect(publicCSRF.CheckOIDCNonce(csrfNonce + csrfState)).To(BeFalse()) Expect(publicCSRF.CheckOAuthState("")).To(BeFalse()) Expect(publicCSRF.CheckOIDCNonce("")).To(BeFalse()) + Expect(publicCSRF.GetCodeVerifier()).To(Equal("verifier")) }) }) diff --git a/pkg/encryption/nonce.go b/pkg/encryption/nonce.go index 39e3b52026..b0ce68d459 100644 --- a/pkg/encryption/nonce.go +++ b/pkg/encryption/nonce.go @@ -8,9 +8,9 @@ import ( "golang.org/x/crypto/blake2b" ) -// Nonce generates a random 32-byte slice to be used as a nonce -func Nonce() ([]byte, error) { - b := make([]byte, 32) +// Nonce generates a random n-byte slice +func Nonce(length int) ([]byte, error) { + b := make([]byte, length) _, err := rand.Read(b) if err != nil { return nil, err diff --git a/pkg/encryption/utils.go b/pkg/encryption/utils.go index c9d19249d2..cd876400f8 100644 --- a/pkg/encryption/utils.go +++ b/pkg/encryption/utils.go @@ -12,6 +12,11 @@ import ( "time" ) +const ( + CodeChallengeMethodPlain = "plain" + CodeChallengeMethodS256 = "S256" +) + // SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary func SecretBytes(secret string) []byte { b, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(secret, "=")) @@ -75,6 +80,18 @@ func SignedValue(seed string, key string, value []byte, now time.Time) (string, return cookieVal, nil } +func GenerateCodeChallenge(method, codeVerifier string) (string, error) { + switch method { + case CodeChallengeMethodPlain: + return codeVerifier, nil + case CodeChallengeMethodS256: + shaSum := sha256.Sum256([]byte(codeVerifier)) + return base64.RawURLEncoding.EncodeToString(shaSum[:]), nil + default: + return "", fmt.Errorf("unknown challenge method: %v", method) + } +} + func cookieSignature(signer func() hash.Hash, args ...string) (string, error) { h := hmac.New(signer, []byte(args[0])) for _, arg := range args[1:] { diff --git a/pkg/validation/sessions.go b/pkg/validation/sessions.go index 5944bf7b3b..96ea6d4fa2 100644 --- a/pkg/validation/sessions.go +++ b/pkg/validation/sessions.go @@ -51,7 +51,7 @@ func validateRedisSessionStore(o *options.Options) []string { return []string{fmt.Sprintf("unable to initialize a redis client: %v", err)} } - n, err := encryption.Nonce() + n, err := encryption.Nonce(32) if err != nil { return []string{fmt.Sprintf("unable to generate a redis initialization test key: %v", err)} } diff --git a/providers/azure.go b/providers/azure.go index 0620d2a8f2..48d55ec2af 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -112,8 +112,8 @@ func (p *AzureProvider) GetLoginURL(redirectURI, state, _ string, extraParams ur } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { - params, err := p.prepareRedeem(redirectURL, code) +func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code, codeVerifier string) (*sessions.SessionState, error) { + params, err := p.prepareRedeem(redirectURL, code, codeVerifier) if err != nil { return nil, err } @@ -187,7 +187,7 @@ func (p *AzureProvider) EnrichSession(ctx context.Context, s *sessions.SessionSt return nil } -func (p *AzureProvider) prepareRedeem(redirectURL, code string) (url.Values, error) { +func (p *AzureProvider) prepareRedeem(redirectURL, code, codeVerifier string) (url.Values, error) { params := url.Values{} if code == "" { return params, ErrMissingCode @@ -202,6 +202,9 @@ func (p *AzureProvider) prepareRedeem(redirectURL, code string) (url.Values, err params.Add("client_secret", clientSecret) params.Add("code", code) params.Add("grant_type", "authorization_code") + if codeVerifier != "" { + params.Add("code_verifier", codeVerifier) + } if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { params.Add("resource", p.ProtectedResource.String()) } diff --git a/providers/azure_test.go b/providers/azure_test.go index 6d96ec00ae..b198d0199b 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -326,7 +326,7 @@ func TestAzureProviderRedeem(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host, options.AzureOptions{}) p.Data().RedeemURL.Path = "/common/oauth2/token" - s, err := p.Redeem(context.Background(), "https://localhost", "1234") + s, err := p.Redeem(context.Background(), "https://localhost", "1234", "123") if testCase.InjectRedeemURLError { assert.NotNil(t, err) } else { diff --git a/providers/google.go b/providers/google.go index 53e4c70548..40501aa6d5 100644 --- a/providers/google.go +++ b/providers/google.go @@ -141,7 +141,7 @@ func claimsFromIDToken(idToken string) (*claims, error) { } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { +func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code, codeVerifier string) (*sessions.SessionState, error) { if code == "" { return nil, ErrMissingCode } @@ -156,6 +156,9 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( params.Add("client_secret", clientSecret) params.Add("code", code) params.Add("grant_type", "authorization_code") + if codeVerifier != "" { + params.Add("code_verifier", codeVerifier) + } var jsonResponse struct { AccessToken string `json:"access_token"` diff --git a/providers/google_test.go b/providers/google_test.go index 3ff3c60a03..780420e4cc 100644 --- a/providers/google_test.go +++ b/providers/google_test.go @@ -112,7 +112,7 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) { p.RedeemURL, server = newRedeemServer(body) defer server.Close() - session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234", "123") assert.Equal(t, nil, err) assert.NotEqual(t, session, nil) assert.Equal(t, "michael.bland@gsa.gov", session.Email) @@ -178,7 +178,7 @@ func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { p.RedeemURL, server = newRedeemServer(body) defer server.Close() - session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234", "123") assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expect nill session %#v", session) @@ -189,7 +189,7 @@ func TestGoogleProviderRedeemFailsNoCLientSecret(t *testing.T) { p := newGoogleProvider(t) p.ProviderData.ClientSecretFile = "srvnoerre" - session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234", "123") assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expect nill session %#v", session) @@ -209,7 +209,7 @@ func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { p.RedeemURL, server = newRedeemServer(body) defer server.Close() - session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234", "123") assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expect nill session %#v", session) @@ -228,7 +228,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { p.RedeemURL, server = newRedeemServer(body) defer server.Close() - session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234", "123") assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expect nill session %#v", session) diff --git a/providers/logingov.go b/providers/logingov.go index 14fed649f2..e321b9e045 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -202,7 +202,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *LoginGovProvider) Redeem(ctx context.Context, _, code string) (*sessions.SessionState, error) { +func (p *LoginGovProvider) Redeem(ctx context.Context, _, code, codeVerifier string) (*sessions.SessionState, error) { if code == "" { return nil, ErrMissingCode } @@ -225,6 +225,9 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, _, code string) (*session params.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") params.Add("code", code) params.Add("grant_type", "authorization_code") + if codeVerifier != "" { + params.Add("code_verifier", codeVerifier) + } // Get the token from the body that we got from the token endpoint. var jsonResponse struct { diff --git a/providers/logingov_test.go b/providers/logingov_test.go index c8f9f162a2..00fe8dccac 100644 --- a/providers/logingov_test.go +++ b/providers/logingov_test.go @@ -235,7 +235,7 @@ func TestLoginGovProviderSessionData(t *testing.T) { p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody) defer pubjwkserver.Close() - session, err := p.Redeem(context.Background(), "http://redirect/", "code1234") + session, err := p.Redeem(context.Background(), "http://redirect/", "code1234", "123") assert.NoError(t, err) assert.NotEqual(t, session, nil) assert.Equal(t, "timothy.spencer@gsa.gov", session.Email) @@ -329,7 +329,7 @@ func TestLoginGovProviderBadNonce(t *testing.T) { p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody) defer pubjwkserver.Close() - _, err = p.Redeem(context.Background(), "http://redirect/", "code1234") + _, err = p.Redeem(context.Background(), "http://redirect/", "code1234", "123") // The "badfakenonce" in the idtoken above should cause this to error out assert.Error(t, err) diff --git a/providers/oidc.go b/providers/oidc.go index 9c0fc63096..d3902afb72 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -43,12 +43,17 @@ func (p *OIDCProvider) GetLoginURL(redirectURI, state, nonce string, extraParams } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { +func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code, codeVerifier string) (*sessions.SessionState, error) { clientSecret, err := p.GetClientSecret() if err != nil { return nil, err } + var opts []oauth2.AuthCodeOption + if codeVerifier != "" { + opts = append(opts, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) + } + c := oauth2.Config{ ClientID: p.ClientID, ClientSecret: clientSecret, @@ -57,7 +62,7 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*s }, RedirectURL: redirectURL, } - token, err := c.Exchange(ctx, code) + token, err := c.Exchange(ctx, code, opts...) if err != nil { return nil, fmt.Errorf("token exchange failed: %v", err) } diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 203b86a930..6a49f8ff8f 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -91,7 +91,7 @@ func TestOIDCProviderGetLoginURL(t *testing.T) { } provider := newOIDCProvider(serverURL, true) - n, err := encryption.Nonce() + n, err := encryption.Nonce(32) assert.NoError(t, err) nonce := base64.RawURLEncoding.EncodeToString(n) @@ -102,6 +102,8 @@ func TestOIDCProviderGetLoginURL(t *testing.T) { provider.SkipNonce = false withNonce := provider.GetLoginURL("http://redirect/", "", nonce, url.Values{}) assert.Contains(t, withNonce, fmt.Sprintf("nonce=%s", nonce)) + assert.NotContains(t, withNonce, "code_challenge") + assert.NotContains(t, withNonce, "code_challenge_method") } func TestOIDCProviderRedeem(t *testing.T) { @@ -117,7 +119,7 @@ func TestOIDCProviderRedeem(t *testing.T) { server, provider := newTestOIDCSetup(body) defer server.Close() - session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234") + session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234", "") assert.Equal(t, nil, err) assert.Equal(t, defaultIDToken.Email, session.Email) assert.Equal(t, accessToken, session.AccessToken) @@ -140,7 +142,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { provider.EmailClaim = "phone_number" defer server.Close() - session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234") + session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234", "") assert.Equal(t, nil, err) assert.Equal(t, defaultIDToken.Phone, session.Email) } diff --git a/providers/provider_data.go b/providers/provider_data.go index c68602827d..647ea172f7 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -37,6 +37,10 @@ type ProviderData struct { ClientSecret string ClientSecretFile string Scope string + // The picked CodeChallenge Method or empty if none. + CodeChallengeMethod string + // Code challenge methods supported by the Provider + SupportedCodeChallengeMethods []string `json:"code_challenge_methods_supported,omitempty"` // Common OIDC options for any OIDC-based providers to consume AllowUnverifiedEmail bool diff --git a/providers/provider_default.go b/providers/provider_default.go index af88f0fe7c..756b5f69e5 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -33,13 +33,16 @@ var ( ) // GetLoginURL with typical oauth parameters +// codeChallenge and codeChallengeMethod are the PKCE challenge and method to append to the URL params. +// they will be empty strings if no code challenge should be presented func (p *ProviderData) GetLoginURL(redirectURI, state, _ string, extraParams url.Values) string { loginURL := makeLoginURL(p, redirectURI, state, extraParams) return loginURL.String() } // Redeem provides a default implementation of the OAuth2 token redemption process -func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { +// The codeVerifier is set if a code_verifier parameter should be sent for PKCE +func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code, codeVerifier string) (*sessions.SessionState, error) { if code == "" { return nil, ErrMissingCode } @@ -54,6 +57,9 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*s params.Add("client_secret", clientSecret) params.Add("code", code) params.Add("grant_type", "authorization_code") + if codeVerifier != "" { + params.Add("code_verifier", codeVerifier) + } if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { params.Add("resource", p.ProtectedResource.String()) } diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index 75b4f5d3a9..80d5b4ce93 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -2,6 +2,7 @@ package providers import ( "context" + "net/url" "testing" "time" @@ -29,6 +30,37 @@ func TestRefresh(t *testing.T) { assert.Equal(t, ErrNotImplemented, err) } +func TestCodeChallengeConfigured(t *testing.T) { + p := &ProviderData{ + LoginURL: &url.URL{ + Scheme: "http", + Host: "my.test.idp", + Path: "/oauth/authorize", + }, + } + + extraValues := url.Values{} + extraValues["code_challenge"] = []string{"challenge"} + extraValues["code_challenge_method"] = []string{"method"} + result := p.GetLoginURL("https://my.test.app/oauth", "", "", extraValues) + assert.Contains(t, result, "code_challenge=challenge") + assert.Contains(t, result, "code_challenge_method=method") +} + +func TestCodeChallengeNotConfigured(t *testing.T) { + p := &ProviderData{ + LoginURL: &url.URL{ + Scheme: "http", + Host: "my.test.idp", + Path: "/oauth/authorize", + }, + } + + result := p.GetLoginURL("https://my.test.app/oauth", "", "", url.Values{}) + assert.NotContains(t, result, "code_challenge") + assert.NotContains(t, result, "code_challenge_method") +} + func TestProviderDataEnrichSession(t *testing.T) { g := NewWithT(t) p := &ProviderData{} diff --git a/providers/providers.go b/providers/providers.go index 58faf0494e..d2570081ba 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -7,15 +7,21 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/oidc" k8serrors "k8s.io/apimachinery/pkg/util/errors" ) +const ( + CodeChallengeMethodPlain = "plain" + CodeChallengeMethodS256 = "S256" +) + // Provider represents an upstream identity provider implementation type Provider interface { Data() *ProviderData - GetLoginURL(redirectURI, finalRedirect string, nonce string, extraParams url.Values) string - Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) + GetLoginURL(redirectURI, finalRedirect, nonce string, extraParams url.Values) string + Redeem(ctx context.Context, redirectURI, code, codeVerifier string) (*sessions.SessionState, error) // Deprecated: Migrate to EnrichSession GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) EnrichSession(ctx context.Context, s *sessions.SessionState) error @@ -95,10 +101,12 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData, if pv.DiscoveryEnabled() { // Use the discovered values rather than any specified values endpoints := pv.Provider().Endpoints() + pkce := pv.Provider().PKCE() providerConfig.LoginURL = endpoints.AuthURL providerConfig.RedeemURL = endpoints.TokenURL providerConfig.ProfileURL = endpoints.UserInfoURL providerConfig.OIDCConfig.JwksURL = endpoints.JWKsURL + p.SupportedCodeChallengeMethods = pkce.CodeChallengeAlgs } } @@ -131,6 +139,12 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData, p.EmailClaim = providerConfig.OIDCConfig.EmailClaim p.GroupsClaim = providerConfig.OIDCConfig.GroupsClaim + // Set PKCE enabled or disabled based on discovery and force options + p.CodeChallengeMethod = parseCodeChallengeMethod(providerConfig) + if len(p.SupportedCodeChallengeMethods) != 0 && p.CodeChallengeMethod == "" { + logger.Printf("Warning: Your provider supports PKCE methods %+q, but you have not enabled one with --code-challenge-method", p.SupportedCodeChallengeMethods) + } + // TODO (@NickMeves) - Remove This // Backwards Compatibility for Deprecated UserIDClaim option if providerConfig.OIDCConfig.EmailClaim == options.OIDCEmailClaim && @@ -154,6 +168,18 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData, return p, nil } +// Pick the most appropriate code challenge method for PKCE +// At this time we do not consider what the server supports to be safe and +// only enable PKCE if the user opts-in +func parseCodeChallengeMethod(providerConfig options.Provider) string { + switch { + case providerConfig.CodeChallengeMethod != "": + return providerConfig.CodeChallengeMethod + default: + return "" + } +} + func providerRequiresOIDCProviderVerifier(providerType options.ProviderType) (bool, error) { switch providerType { case options.BitbucketProvider, options.DigitalOceanProvider, options.FacebookProvider, options.GitHubProvider, diff --git a/providers/providers_test.go b/providers/providers_test.go index 02521ad71c..ed2a10a063 100644 --- a/providers/providers_test.go +++ b/providers/providers_test.go @@ -171,3 +171,38 @@ func TestScope(t *testing.T) { g.Expect(pd.Scope).To(Equal(tc.expectedScope)) } } + +func TestForcedMethodS256(t *testing.T) { + g := NewWithT(t) + options := options.NewOptions() + options.Providers[0].CodeChallengeMethod = CodeChallengeMethodS256 + method := parseCodeChallengeMethod(options.Providers[0]) + + g.Expect(method).To(Equal(CodeChallengeMethodS256)) +} + +func TestForcedMethodPlain(t *testing.T) { + g := NewWithT(t) + options := options.NewOptions() + options.Providers[0].CodeChallengeMethod = CodeChallengeMethodPlain + method := parseCodeChallengeMethod(options.Providers[0]) + + g.Expect(method).To(Equal(CodeChallengeMethodPlain)) +} + +func TestPrefersS256(t *testing.T) { + g := NewWithT(t) + options := options.NewOptions() + method := parseCodeChallengeMethod(options.Providers[0]) + + g.Expect(method).To(Equal("")) +} + +func TestCanOverwriteS256(t *testing.T) { + g := NewWithT(t) + options := options.NewOptions() + options.Providers[0].CodeChallengeMethod = "plain" + method := parseCodeChallengeMethod(options.Providers[0]) + + g.Expect(method).To(Equal(CodeChallengeMethodPlain)) +}