Skip to content

Commit

Permalink
allow full location to be passed to tp.Client auth helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
btoews committed Nov 16, 2023
1 parent 4a84c62 commit 276bc49
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
20 changes: 13 additions & 7 deletions tp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -42,14 +43,18 @@ func WithHTTP(h *http.Client) ClientOption {
}

// WithBearerAuthentication specifies a token to be sent in requests to the
// specified host in the `Authorization: Bearer` header.
func WithBearerAuthentication(hostname, token string) ClientOption {
return WithAuthentication(hostname, "Bearer "+token)
// specified third party in the `Authorization: Bearer` header.
func WithBearerAuthentication(tpLocation, token string) ClientOption {
return WithAuthentication(tpLocation, "Bearer "+token)
}

// WithBearerAuthentication specifies a token to be sent in requests to the
// specified host in the `Authorization` header.
func WithAuthentication(hostname, token string) ClientOption {
// specified third party in the `Authorization` header.
func WithAuthentication(tpLocation, token string) ClientOption {
if u, err := url.Parse(tpLocation); err == nil && u.IsAbs() {
tpLocation = u.Hostname()
}

return func(c *Client) {
if c.http == nil {
cpy := *http.DefaultClient
Expand All @@ -58,11 +63,11 @@ func WithAuthentication(hostname, token string) ClientOption {

switch t := c.http.Transport.(type) {
case *authenticatedHTTP:
t.auth[hostname] = token
t.auth[tpLocation] = token
default:
c.http.Transport = &authenticatedHTTP{
t: t,
auth: map[string]string{hostname: token},
auth: map[string]string{tpLocation: token},
}
}
}
Expand Down Expand Up @@ -354,6 +359,7 @@ type authenticatedHTTP struct {
}

func (a *authenticatedHTTP) RoundTrip(r *http.Request) (*http.Response, error) {

if cred := a.auth[r.URL.Hostname()]; cred != "" {
r.Header.Set("Authorization", cred)
}
Expand Down
2 changes: 1 addition & 1 deletion tp/immediate_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func ExampleTP_RespondDischarge() {
fmt.Printf("validation error without 3p discharge token: %v\n", err)

client := NewClient(firstPartyLocation,
WithBearerAuthentication("127.0.0.1", "trustno1"),
WithBearerAuthentication(tp.Location, "trustno1"),
)

firstPartyMacaroon, err = client.FetchDischargeTokens(context.Background(), firstPartyMacaroon)
Expand Down
8 changes: 2 additions & 6 deletions tp/tp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -75,12 +74,9 @@ func TestTP(t *testing.T) {
tp.RespondDischarge(w, r)
})

u, err := url.Parse(tp.Location)
assert.NoError(t, err)

hdr := genFP(t, tp)
c := NewClient(firstPartyLocation,
WithBearerAuthentication(u.Hostname(), "my-token"),
WithBearerAuthentication(tp.Location, "my-token"),
)
hdr, err = c.FetchDischargeTokens(context.Background(), hdr)
assert.NoError(t, err)
Expand All @@ -101,7 +97,7 @@ func TestTP(t *testing.T) {

hdr := genFP(t, tp)
c := NewClient(firstPartyLocation,
WithBearerAuthentication("wrong.com", "my-token"),
WithBearerAuthentication("https://wrong.com", "my-token"),
)
hdr, err = c.FetchDischargeTokens(context.Background(), hdr)
assert.NoError(t, err)
Expand Down

0 comments on commit 276bc49

Please sign in to comment.