Skip to content

Commit

Permalink
Merge pull request #3 from hhuang-rayark/feature/export_auther
Browse files Browse the repository at this point in the history
Allow customized auther in Transport
  • Loading branch information
raymond-chia authored Aug 30, 2023
2 parents ac25440 + 0f1df77 commit 265eb94
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 71 deletions.
25 changes: 13 additions & 12 deletions auther.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,28 @@ type clock interface {
Now() time.Time
}

// auther adds an "OAuth" Authorization header field to requests.
type auther struct {
// DefaultAuther adds an "OAuth" Authorization header field to requests.
type DefaultAuther struct {
config *Config
clock clock
}

func newAuther(config *Config) *auther {
// NewDefaultAuther returns a new DefaultAuther
func NewDefaultAuther(config *Config) *DefaultAuther {
if config == nil {
config = &Config{}
}
if config.Noncer == nil {
config.Noncer = Base64Noncer{}
}
return &auther{
return &DefaultAuther{
config: config,
}
}

// setRequestTokenAuthHeader adds the OAuth1 header for the request token
// request (temporary credential) according to RFC 5849 2.1.
func (a *auther) setRequestTokenAuthHeader(req *http.Request) error {
func (a *DefaultAuther) setRequestTokenAuthHeader(req *http.Request) error {
oauthParams := a.commonOAuthParams()
oauthParams[oauthCallbackParam] = a.config.CallbackURL
params, err := collectParameters(req, oauthParams)
Expand All @@ -78,7 +79,7 @@ func (a *auther) setRequestTokenAuthHeader(req *http.Request) error {

// setAccessTokenAuthHeader sets the OAuth1 header for the access token request
// (token credential) according to RFC 5849 2.3.
func (a *auther) setAccessTokenAuthHeader(req *http.Request, requestToken, requestSecret, verifier string) error {
func (a *DefaultAuther) setAccessTokenAuthHeader(req *http.Request, requestToken, requestSecret, verifier string) error {
oauthParams := a.commonOAuthParams()
oauthParams[oauthTokenParam] = requestToken
oauthParams[oauthVerifierParam] = verifier
Expand All @@ -96,9 +97,9 @@ func (a *auther) setAccessTokenAuthHeader(req *http.Request, requestToken, reque
return nil
}

// setRequestAuthHeader sets the OAuth1 header for making authenticated
// SetRequestAuthHeader sets the OAuth1 header for making authenticated
// requests with an AccessToken (token credential) according to RFC 5849 3.1.
func (a *auther) setRequestAuthHeader(req *http.Request, accessToken *Token) error {
func (a *DefaultAuther) SetRequestAuthHeader(req *http.Request, accessToken *Token) error {
oauthParams := a.commonOAuthParams()
oauthParams[oauthTokenParam] = accessToken.Token
params, err := collectParameters(req, oauthParams)
Expand All @@ -119,7 +120,7 @@ func (a *auther) setRequestAuthHeader(req *http.Request, accessToken *Token) err
// excluding the oauth_signature parameter. This includes the realm parameter
// if it was set in the config. The realm parameter will not be included in
// the signature base string as specified in RFC 5849 3.4.1.3.1.
func (a *auther) commonOAuthParams() map[string]string {
func (a *DefaultAuther) commonOAuthParams() map[string]string {
params := map[string]string{
oauthConsumerKeyParam: a.config.ConsumerKey,
oauthSignatureMethodParam: a.signer().Name(),
Expand All @@ -134,20 +135,20 @@ func (a *auther) commonOAuthParams() map[string]string {
}

// Returns a nonce using the configured Noncer.
func (a *auther) nonce() string {
func (a *DefaultAuther) nonce() string {
return a.config.Noncer.Nonce()
}

// Returns the Unix epoch seconds.
func (a *auther) epoch() int64 {
func (a *DefaultAuther) epoch() int64 {
if a.clock != nil {
return a.clock.Now().Unix()
}
return time.Now().Unix()
}

// Returns the Config's Signer or the default Signer.
func (a *auther) signer() Signer {
func (a *DefaultAuther) signer() Signer {
if a.config.Signer != nil {
return a.config.Signer
}
Expand Down
22 changes: 11 additions & 11 deletions auther_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import (

func TestCommonOAuthParams(t *testing.T) {
cases := []struct {
auther *auther
DefaultAuther *DefaultAuther
expectedParams map[string]string
}{
{
&auther{
&DefaultAuther{
&Config{
ConsumerKey: "some_consumer_key",
Noncer: &fixedNoncer{"some_nonce"},
Expand All @@ -32,7 +32,7 @@ func TestCommonOAuthParams(t *testing.T) {
},
},
{
&auther{
&DefaultAuther{
&Config{
ConsumerKey: "some_consumer_key",
Realm: "photos",
Expand All @@ -52,13 +52,13 @@ func TestCommonOAuthParams(t *testing.T) {
}

for _, c := range cases {
assert.Equal(t, c.expectedParams, c.auther.commonOAuthParams())
assert.Equal(t, c.expectedParams, c.DefaultAuther.commonOAuthParams())
}
}

func TestNonce(t *testing.T) {
auther := newAuther(nil)
nonce := auther.nonce()
DefaultAuther := NewDefaultAuther(nil)
nonce := DefaultAuther.nonce()
// assert that 32 bytes (256 bites) become 44 bytes since a base64 byte
// zeros the 2 high bits. 3 bytes convert to 4 base64 bytes, 40 base64 bytes
// represent the first 30 of 32 bytes, = padding adds another 4 byte group.
Expand All @@ -67,17 +67,17 @@ func TestNonce(t *testing.T) {
}

func TestEpoch(t *testing.T) {
a := newAuther(nil)
a := NewDefaultAuther(nil)
// assert that a real time is used by default
assert.InEpsilon(t, time.Now().Unix(), a.epoch(), 1)
// assert that the fixed clock can be used for testing
a = &auther{clock: &fixedClock{time.Unix(50037133, 0)}}
a = &DefaultAuther{clock: &fixedClock{time.Unix(50037133, 0)}}
assert.Equal(t, int64(50037133), a.epoch())
}

func TestSigner_Default(t *testing.T) {
config := &Config{ConsumerSecret: "consumer_secret"}
a := newAuther(config)
a := NewDefaultAuther(config)
// echo -n "hello world" | openssl dgst -sha1 -hmac "consumer_secret&token_secret" -binary | base64
expectedSignature := "BE0uILOruKfSXd4UzYlLJDfOq08="
// assert that the default signer produces the expected HMAC-SHA1 digest
Expand All @@ -92,7 +92,7 @@ func TestSigner_SHA256(t *testing.T) {
config := &Config{
Signer: &HMAC256Signer{ConsumerSecret: "consumer_secret"},
}
a := newAuther(config)
a := NewDefaultAuther(config)
// echo -n "hello world" | openssl dgst -sha256 -hmac "consumer_secret&token_secret" -binary | base64
expectedSignature := "pW9drXUyErU8DASWbsP2I3XZbju37AW+VzcGdYSeMo8="
// assert that the signer produces the expected HMAC-SHA256 digest
Expand All @@ -118,7 +118,7 @@ func TestSigner_Custom(t *testing.T) {
ConsumerSecret: "consumer_secret",
Signer: &identitySigner{},
}
a := newAuther(config)
a := NewDefaultAuther(config)
// assert that the custom signer is used
method := a.signer().Name()
digest, err := a.signer().Sign("secret", "hello world")
Expand Down
10 changes: 3 additions & 7 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ func (c *Config) Client(ctx context.Context, t *Token) *http.Client {

// NewClient returns a new http Client which signs requests via OAuth1.
func NewClient(ctx context.Context, config *Config, token *Token) *http.Client {
transport := &Transport{
Base: contextTransport(ctx),
source: StaticTokenSource(token),
auther: newAuther(config),
}
transport := newTransport(contextTransport(ctx), StaticTokenSource(token), NewDefaultAuther(config))
return &http.Client{Transport: transport}
}

Expand All @@ -69,7 +65,7 @@ func (c *Config) RequestToken() (requestToken, requestSecret string, err error)
if err != nil {
return "", "", err
}
err = newAuther(c).setRequestTokenAuthHeader(req)
err = NewDefaultAuther(c).setRequestTokenAuthHeader(req)
if err != nil {
return "", "", err
}
Expand Down Expand Up @@ -148,7 +144,7 @@ func (c *Config) AccessToken(requestToken, requestSecret, verifier string) (acce
if err != nil {
return "", "", err
}
err = newAuther(c).setAccessTokenAuthHeader(req, requestToken, requestSecret, verifier)
err = NewDefaultAuther(c).setAccessTokenAuthHeader(req, requestToken, requestSecret, verifier)
if err != nil {
return "", "", err
}
Expand Down
12 changes: 6 additions & 6 deletions reference_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestTwitterRequestTokenAuthHeader(t *testing.T) {
Noncer: &fixedNoncer{expectedNonce},
}

auther := &auther{config, &fixedClock{time.Unix(unixTimestamp, 0)}}
auther := &DefaultAuther{config, &fixedClock{time.Unix(unixTimestamp, 0)}}
req, err := http.NewRequest("POST", config.Endpoint.RequestTokenURL, nil)
assert.Nil(t, err)
err = auther.setRequestTokenAuthHeader(req)
Expand Down Expand Up @@ -74,7 +74,7 @@ func TestTwitterAccessTokenAuthHeader(t *testing.T) {
Noncer: &fixedNoncer{expectedNonce},
}

auther := &auther{config, &fixedClock{time.Unix(unixTimestamp, 0)}}
auther := &DefaultAuther{config, &fixedClock{time.Unix(unixTimestamp, 0)}}
req, err := http.NewRequest("POST", config.Endpoint.AccessTokenURL, nil)
assert.Nil(t, err)
err = auther.setAccessTokenAuthHeader(req, expectedRequestToken, requestTokenSecret, expectedVerifier)
Expand Down Expand Up @@ -111,7 +111,7 @@ var twitterConfig = &Config{
}

func TestTwitterParameterString(t *testing.T) {
auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
auther := &DefaultAuther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
values := url.Values{}
values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!")
// note: the reference example is old and uses api v1 in the URL
Expand All @@ -128,7 +128,7 @@ func TestTwitterParameterString(t *testing.T) {
}

func TestTwitterSignatureBase(t *testing.T) {
auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
auther := &DefaultAuther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
values := url.Values{}
values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!")
// note: the reference example is old and uses api v1 in the URL
Expand All @@ -151,15 +151,15 @@ func TestTwitterRequestAuthHeader(t *testing.T) {
expectedSignature := PercentEncode("tnnArxj06cWHq44gCs1OSKk/jLY=")
expectedTimestamp := "1318622958"

auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
auther := &DefaultAuther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}}
values := url.Values{}
values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!")

accessToken := &Token{expectedTwitterOAuthToken, oauthTokenSecret}
req, err := http.NewRequest("POST", "https://api.twitter.com/1/statuses/update.json?include_entities=true", strings.NewReader(values.Encode()))
assert.Nil(t, err)
req.Header.Set(contentType, formContentType)
err = auther.setRequestAuthHeader(req, accessToken)
err = auther.SetRequestAuthHeader(req, accessToken)
// assert that request is signed and has an access token token
assert.Nil(t, err)
params := parseOAuthParamsOrFail(t, req.Header.Get(authorizationHeaderParam))
Expand Down
50 changes: 43 additions & 7 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
"net/http"
)

type Auther interface {
SetRequestAuthHeader(req *http.Request, accessToken *Token) error
}

// Transport is an http.RoundTripper which makes OAuth1 HTTP requests. It
// wraps a base RoundTripper and adds an Authorization header using the
// token from a TokenSource.
Expand All @@ -18,25 +22,43 @@ type Transport struct {
// source supplies the token to use when signing a request
source TokenSource
// auther adds OAuth1 Authorization headers to requests
auther *auther
auther Auther
}

// NewTransport returns a new Transport and an error
func NewTransport(baseRoundTripper http.RoundTripper, source TokenSource, auther Auther) (*Transport, error) {
t := newTransport(baseRoundTripper, source, auther)
err := t.checkValid()
if err != nil {
return nil, err
}
return t, nil
}

// newTransport returns a new Transport without checking whether there is an error
// newTransport is only for NewTransport & NewClient
func newTransport(baseRoundTripper http.RoundTripper, source TokenSource, auther Auther) *Transport {
return &Transport{
Base: baseRoundTripper,
source: source,
auther: auther,
}
}

// RoundTrip authorizes the request with a signed OAuth1 Authorization header
// using the auther and TokenSource.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.source == nil {
return nil, fmt.Errorf("oauth1: Transport's source is nil")
err := t.checkValid()
if err != nil {
return nil, err
}
accessToken, err := t.source.Token()
if err != nil {
return nil, err
}
if t.auther == nil {
return nil, fmt.Errorf("oauth1: Transport's auther is nil")
}
// RoundTripper should not modify the given request, clone it
req2 := cloneRequest(req)
err = t.auther.setRequestAuthHeader(req2, accessToken)
err = t.auther.SetRequestAuthHeader(req2, accessToken)
if err != nil {
return nil, err
}
Expand All @@ -63,3 +85,17 @@ func cloneRequest(req *http.Request) *http.Request {
}
return r2
}

func (t *Transport) checkValid() error {
if t.source == nil {
return fmt.Errorf("oauth1: Transport's source is nil")
}
if t.auther == nil {
return fmt.Errorf("oauth1: Transport's auther is nil")
}
_, err := t.source.Token()
if err != nil {
return err
}
return nil
}
Loading

0 comments on commit 265eb94

Please sign in to comment.