From bf14daded12f507df6ba365dbfd352f94a92e332 Mon Sep 17 00:00:00 2001 From: Nikos Date: Wed, 27 Nov 2024 12:38:22 +0100 Subject: [PATCH] fix: use client authn on device auth request According to https://datatracker.ietf.org/doc/html/rfc8628#section-3.1, the device auth request must include client authentication. Fixes https://github.com/golang/oauth2/issues/685 --- deviceauth.go | 60 ++++++++++------------- internal/deviceauth.go | 95 +++++++++++++++++++++++++++++++++++++ internal/deviceauth_test.go | 88 ++++++++++++++++++++++++++++++++++ internal/oauth2.go | 34 +++++++++++++ internal/token.go | 15 +----- 5 files changed, 243 insertions(+), 49 deletions(-) create mode 100644 internal/deviceauth.go create mode 100644 internal/deviceauth_test.go diff --git a/deviceauth.go b/deviceauth.go index e99c92f39..863e507e5 100644 --- a/deviceauth.go +++ b/deviceauth.go @@ -4,9 +4,6 @@ import ( "context" "encoding/json" "errors" - "fmt" - "io" - "net/http" "net/url" "strings" "time" @@ -93,47 +90,40 @@ func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*Devic return retrieveDeviceAuth(ctx, c, v) } -func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) { - if c.Endpoint.DeviceAuthURL == "" { - return nil, errors.New("endpoint missing DeviceAuthURL") +// deviceAuthFromInternal maps an *internal.DeviceAuthResponse struct into +// a *DeviceAuthResponse struct. +func deviceAuthFromInternal(da *internal.DeviceAuthResponse) *DeviceAuthResponse { + if da == nil { + return nil } - - req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode())) - if err != nil { - return nil, err + return &DeviceAuthResponse{ + DeviceCode: da.DeviceCode, + UserCode: da.UserCode, + VerificationURI: da.VerificationURI, + VerificationURIComplete: da.VerificationURIComplete, + Expiry: time.Now().UTC().Add(time.Second * time.Duration(da.Expiry)), + Interval: da.Interval, } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") +} - t := time.Now() - r, err := internal.ContextClient(ctx).Do(req) - if err != nil { - return nil, err +// retrieveDeviceAuth takes a *Config and uses that to retrieve an *internal.DeviceAuthResponse. +// This response is then mapped from *internal.DeviceAuthResponse into an *oauth2.DeviceAuthResponse which is returned along +// with an error. +func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) { + if c.Endpoint.DeviceAuthURL == "" { + return nil, errors.New("endpoint missing DeviceAuthURL") } - body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) + da, err := internal.RetrieveDeviceAuth(ctx, c.ClientID, c.ClientSecret, c.Endpoint.DeviceAuthURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get()) if err != nil { - return nil, fmt.Errorf("oauth2: cannot auth device: %v", err) - } - if code := r.StatusCode; code < 200 || code > 299 { - return nil, &RetrieveError{ - Response: r, - Body: body, + if rErr, ok := err.(*internal.RetrieveError); ok { + return nil, (*RetrieveError)(rErr) } + return nil, err } + dar := deviceAuthFromInternal(da) - da := &DeviceAuthResponse{} - err = json.Unmarshal(body, &da) - if err != nil { - return nil, fmt.Errorf("unmarshal %s", err) - } - - if !da.Expiry.IsZero() { - // Make a small adjustment to account for time taken by the request - da.Expiry = da.Expiry.Add(-time.Since(t)) - } - - return da, nil + return dar, err } // DeviceAccessToken polls the server to exchange a device code for a token. diff --git a/internal/deviceauth.go b/internal/deviceauth.go new file mode 100644 index 000000000..b973ae33b --- /dev/null +++ b/internal/deviceauth.go @@ -0,0 +1,95 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/url" + "time" +) + +// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response +// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2 +// +// This type is a mirror of oauth2.DeviceAuthResponse and exists to break +// an otherwise-circular dependency. Other internal packages +// should convert this DeviceAuthResponse into an oauth2.DeviceAuthResponse before use. +type DeviceAuthResponse struct { + // DeviceCode + DeviceCode string `json:"device_code"` + // UserCode is the code the user should enter at the verification uri + UserCode string `json:"user_code"` + // VerificationURI is where user should enter the user code + VerificationURI string `json:"verification_uri"` + // VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code. + VerificationURIComplete string `json:"verification_uri_complete,omitempty"` + // Expiry is when the device code and user code expire + Expiry int64 `json:"expires_in,omitempty"` + // Interval is the duration in seconds that Poll should wait between requests + Interval int64 `json:"interval,omitempty"` +} + +func RetrieveDeviceAuth(ctx context.Context, clientID, clientSecret, deviceAuthURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*DeviceAuthResponse, error) { + needsAuthStyleProbe := authStyle == AuthStyleUnknown + if needsAuthStyleProbe { + if style, ok := styleCache.lookupAuthStyle(deviceAuthURL); ok { + authStyle = style + needsAuthStyleProbe = false + } else { + authStyle = AuthStyleInHeader // the first way we'll try + } + } + + req, err := NewRequestWithClientAuthn("POST", deviceAuthURL, clientID, clientSecret, v, authStyle) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + t := time.Now() + r, err := ContextClient(ctx).Do(req) + + if err != nil && needsAuthStyleProbe { + // If we get an error, assume the server wants the + // clientID & clientSecret in a different form. + authStyle = AuthStyleInParams // the second way we'll try + req, _ := NewRequestWithClientAuthn("POST", deviceAuthURL, clientID, clientSecret, v, authStyle) + r, err = ContextClient(ctx).Do(req) + } + if needsAuthStyleProbe && err == nil { + styleCache.setAuthStyle(deviceAuthURL, authStyle) + } + + if err != nil { + return nil, err + } + + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("oauth2: cannot auth device: %v", err) + } + if code := r.StatusCode; code < 200 || code > 299 { + return nil, &RetrieveError{ + Response: r, + Body: body, + } + } + + da := &DeviceAuthResponse{} + err = json.Unmarshal(body, &da) + if err != nil { + return nil, fmt.Errorf("unmarshal %s", err) + } + + if da.Expiry != 0 { + // Make a small adjustment to account for time taken by the request + da.Expiry = da.Expiry + int64(t.Nanosecond()) + } + return da, nil +} diff --git a/internal/deviceauth_test.go b/internal/deviceauth_test.go new file mode 100644 index 000000000..937158e0e --- /dev/null +++ b/internal/deviceauth_test.go @@ -0,0 +1,88 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestDeviceAuth_ClientAuthnInParams(t *testing.T) { + styleCache := new(AuthStyleCache) + const clientID = "client-id" + const clientSecret = "client-secret" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got, want := r.FormValue("client_id"), clientID; got != want { + t.Errorf("client_id = %q; want %q", got, want) + } + if got, want := r.FormValue("client_secret"), clientSecret; got != want { + t.Errorf("client_secret = %q; want %q", got, want) + } + io.WriteString(w, `{"device_code":"code","user_code":"user_code","verification_uri":"http://example.device.com","expires_in":300,"interval":5}`) + })) + defer ts.Close() + _, err := RetrieveDeviceAuth(context.Background(), clientID, clientSecret, ts.URL, url.Values{}, AuthStyleInParams, styleCache) + if err != nil { + t.Errorf("RetrieveDeviceAuth = %v; want no error", err) + } +} + +func TestDeviceAuth_ClientAuthnInHeader(t *testing.T) { + styleCache := new(AuthStyleCache) + const clientID = "client-id" + const clientSecret = "client-secret" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u, p, ok := r.BasicAuth() + if !ok { + io.WriteString(w, `{"error":"invalid_client"}`) + w.WriteHeader(http.StatusBadRequest) + } + if got, want := u, clientID; got != want { + io.WriteString(w, `{"error":"invalid_client"}`) + w.WriteHeader(http.StatusBadRequest) + } + if got, want := p, clientSecret; got != want { + io.WriteString(w, `{"error":"invalid_client"}`) + w.WriteHeader(http.StatusBadRequest) + } + io.WriteString(w, `{"device_code":"code","user_code":"user_code","verification_uri":"http://example.device.com","expires_in":300,"interval":5}`) + })) + defer ts.Close() + _, err := RetrieveDeviceAuth(context.Background(), clientID, clientSecret, ts.URL, url.Values{}, AuthStyleInHeader, styleCache) + if err != nil { + t.Errorf("RetrieveDeviceAuth = %v; want no error", err) + } +} + +func TestDeviceAuth_ClientAuthnProbe(t *testing.T) { + styleCache := new(AuthStyleCache) + const clientID = "client-id" + const clientSecret = "client-secret" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u, p, ok := r.BasicAuth() + if !ok { + io.WriteString(w, `{"error":"invalid_client"}`) + w.WriteHeader(http.StatusBadRequest) + } + if got, want := u, clientID; got != want { + io.WriteString(w, `{"error":"invalid_client"}`) + w.WriteHeader(http.StatusBadRequest) + } + if got, want := p, clientSecret; got != want { + io.WriteString(w, `{"error":"invalid_client"}`) + w.WriteHeader(http.StatusBadRequest) + } + io.WriteString(w, `{"device_code":"code","user_code":"user_code","verification_uri":"http://example.device.com","expires_in":300,"interval":5}`) + })) + defer ts.Close() + _, err := RetrieveDeviceAuth(context.Background(), clientID, clientSecret, ts.URL, url.Values{}, AuthStyleUnknown, styleCache) + if err != nil { + t.Errorf("RetrieveDeviceAuth = %v; want no error", err) + } +} diff --git a/internal/oauth2.go b/internal/oauth2.go index 14989beaf..b80d6f95b 100644 --- a/internal/oauth2.go +++ b/internal/oauth2.go @@ -10,6 +10,9 @@ import ( "encoding/pem" "errors" "fmt" + "net/http" + "net/url" + "strings" ) // ParseKey converts the binary contents of a private key file @@ -35,3 +38,34 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) { } return parsed, nil } + +// addClientAuthnRequestParams adds client_secret_post client authentication +func addClientAuthnRequestParams(clientID, clientSecret string, v url.Values, authStyle AuthStyle) url.Values { + if authStyle == AuthStyleInParams { + v = cloneURLValues(v) + if clientID != "" { + v.Set("client_id", clientID) + } + if clientSecret != "" { + v.Set("client_secret", clientSecret) + } + } + return v +} + +// addClientAuthnRequestHeaders adds client_secret_basic client authentication +func addClientAuthnRequestHeaders(clientID, clientSecret string, req *http.Request, authStyle AuthStyle) { + if authStyle == AuthStyleInHeader { + req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret)) + } +} + +func NewRequestWithClientAuthn(httpMethod string, endpointURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) { + v = addClientAuthnRequestParams(clientID, clientSecret, v, authStyle) + req, err := http.NewRequest(httpMethod, endpointURL, strings.NewReader(v.Encode())) + if err != nil { + return nil, err + } + addClientAuthnRequestHeaders(clientID, clientSecret, req, authStyle) + return req, nil +} diff --git a/internal/token.go b/internal/token.go index e83ddeef0..f1933de89 100644 --- a/internal/token.go +++ b/internal/token.go @@ -16,7 +16,6 @@ import ( "net/http" "net/url" "strconv" - "strings" "sync" "sync/atomic" "time" @@ -181,23 +180,11 @@ func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) { // the POST body (along with any values in v); false means to send it // in the Authorization header. func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) { - if authStyle == AuthStyleInParams { - v = cloneURLValues(v) - if clientID != "" { - v.Set("client_id", clientID) - } - if clientSecret != "" { - v.Set("client_secret", clientSecret) - } - } - req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode())) + req, err := NewRequestWithClientAuthn("POST", tokenURL, clientID, clientSecret, v, authStyle) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - if authStyle == AuthStyleInHeader { - req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret)) - } return req, nil }