diff --git a/client.go b/client.go index 09ca508..73164d5 100644 --- a/client.go +++ b/client.go @@ -15,9 +15,9 @@ var ( ) type Client struct { - ClientError error - session http.Client - clientTransport transport + ClientError error + session http.Client + Transport *transport } // NewClient constructs a new client given a URL to a Postgrest instance. @@ -31,11 +31,12 @@ func NewClient(rawURL, schema string, headers map[string]string) *Client { t := transport{ header: http.Header{}, baseURL: *baseURL, + Parent: nil, } c := Client{ - session: http.Client{Transport: t}, - clientTransport: t, + session: http.Client{Transport: &t}, + Transport: &t, } if schema == "" { @@ -43,22 +44,22 @@ func NewClient(rawURL, schema string, headers map[string]string) *Client { } // Set required headers - c.clientTransport.header.Set("Accept", "application/json") - c.clientTransport.header.Set("Content-Type", "application/json") - c.clientTransport.header.Set("Accept-Profile", schema) - c.clientTransport.header.Set("Content-Profile", schema) - c.clientTransport.header.Set("X-Client-Info", "postgrest-go/"+version) + c.Transport.header.Set("Accept", "application/json") + c.Transport.header.Set("Content-Type", "application/json") + c.Transport.header.Set("Accept-Profile", schema) + c.Transport.header.Set("Content-Profile", schema) + c.Transport.header.Set("X-Client-Info", "postgrest-go/"+version) // Set optional headers if they exist for key, value := range headers { - c.clientTransport.header.Set(key, value) + c.Transport.header.Set(key, value) } return &c } func (c *Client) Ping() bool { - req, err := http.NewRequest("GET", path.Join(c.clientTransport.baseURL.Path, ""), nil) + req, err := http.NewRequest("GET", path.Join(c.Transport.baseURL.Path, ""), nil) if err != nil { c.ClientError = err @@ -81,23 +82,16 @@ func (c *Client) Ping() bool { return true } -// SetApiKey sets api key header for subsequent requests. -func (c *Client) SetApiKey(apiKey string) *Client { - c.clientTransport.header.Set("apikey", apiKey) - return c -} - - -// SetAuthToken sets authorization header for subsequent requests. -func (c *Client) SetAuthToken(authToken string) *Client { - c.clientTransport.header.Set("Authorization", "Bearer "+authToken) - return c +// TokenAuth sets authorization headers for subsequent requests. +func (c *Client) TokenAuth(token string) *Client { + c.Transport.header.Set("Authorization", "Bearer "+token) + c.Transport.header.Set("apikey", token) } // ChangeSchema modifies the schema for subsequent requests. func (c *Client) ChangeSchema(schema string) *Client { - c.clientTransport.header.Set("Accept-Profile", schema) - c.clientTransport.header.Set("Content-Profile", schema) + c.Transport.header.Set("Accept-Profile", schema) + c.Transport.header.Set("Content-Profile", schema) return c } @@ -121,7 +115,7 @@ func (c *Client) Rpc(name string, count string, rpcBody interface{}) string { } readerBody := bytes.NewBuffer(byteBody) - url := path.Join(c.clientTransport.baseURL.Path, "rpc", name) + url := path.Join(c.Transport.baseURL.Path, "rpc", name) req, err := http.NewRequest("POST", url, readerBody) if err != nil { c.ClientError = err @@ -158,6 +152,7 @@ func (c *Client) Rpc(name string, count string, rpcBody interface{}) string { type transport struct { header http.Header baseURL url.URL + Parent http.RoundTripper } func (t transport) RoundTrip(req *http.Request) (*http.Response, error) { @@ -168,5 +163,11 @@ func (t transport) RoundTrip(req *http.Request) (*http.Response, error) { } req.URL = t.baseURL.ResolveReference(req.URL) + + // This is only needed with usage of httpmock in testing. It would be better to initialize + // t.Parent with http.DefaultTransport and then use t.Parent.RoundTrip(req) + if t.Parent != nil { + return t.Parent.RoundTrip(req) + } return http.DefaultTransport.RoundTrip(req) } diff --git a/execute.go b/execute.go index 56ab3e7..792be3c 100644 --- a/execute.go +++ b/execute.go @@ -31,7 +31,7 @@ func executeHelper(client *Client, method string, body []byte, urlFragments []st } readerBody := bytes.NewBuffer(body) - baseUrl := path.Join(append([]string{client.clientTransport.baseURL.Path}, urlFragments...)...) + baseUrl := path.Join(append([]string{client.Transport.baseURL.Path}, urlFragments...)...) req, err := http.NewRequest(method, baseUrl, readerBody) if err != nil { return nil, 0, fmt.Errorf("error creating request: %s", err.Error())