Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature client transport config #40

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 27 additions & 26 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ var (
)

type Client struct {
ClientError error
session http.Client
clientTransport transport
ClientError error
session http.Client
Transport *transport
}

// NewClient constructs a new client given a URL to a Postgrest instance.
Expand All @@ -31,34 +31,35 @@ func NewClient(rawURL, schema string, headers map[string]string) *Client {
t := transport{
header: http.Header{},
baseURL: *baseURL,
Parent: nil,
}

c := Client{
session: http.Client{Transport: t},
clientTransport: t,
session: http.Client{Transport: &t},
Transport: &t,
}

if schema == "" {
schema = "public"
}

// Set required headers
c.clientTransport.header.Set("Accept", "application/json")
c.clientTransport.header.Set("Content-Type", "application/json")
c.clientTransport.header.Set("Accept-Profile", schema)
c.clientTransport.header.Set("Content-Profile", schema)
c.clientTransport.header.Set("X-Client-Info", "postgrest-go/"+version)
c.Transport.header.Set("Accept", "application/json")
c.Transport.header.Set("Content-Type", "application/json")
c.Transport.header.Set("Accept-Profile", schema)
c.Transport.header.Set("Content-Profile", schema)
c.Transport.header.Set("X-Client-Info", "postgrest-go/"+version)

// Set optional headers if they exist
for key, value := range headers {
c.clientTransport.header.Set(key, value)
c.Transport.header.Set(key, value)
}

return &c
}

func (c *Client) Ping() bool {
req, err := http.NewRequest("GET", path.Join(c.clientTransport.baseURL.Path, ""), nil)
req, err := http.NewRequest("GET", path.Join(c.Transport.baseURL.Path, ""), nil)
if err != nil {
c.ClientError = err

Expand All @@ -81,23 +82,16 @@ func (c *Client) Ping() bool {
return true
}

// SetApiKey sets api key header for subsequent requests.
func (c *Client) SetApiKey(apiKey string) *Client {
c.clientTransport.header.Set("apikey", apiKey)
return c
}


// SetAuthToken sets authorization header for subsequent requests.
func (c *Client) SetAuthToken(authToken string) *Client {
c.clientTransport.header.Set("Authorization", "Bearer "+authToken)
return c
// TokenAuth sets authorization headers for subsequent requests.
func (c *Client) TokenAuth(token string) *Client {
c.Transport.header.Set("Authorization", "Bearer "+token)
c.Transport.header.Set("apikey", token)
}

// ChangeSchema modifies the schema for subsequent requests.
func (c *Client) ChangeSchema(schema string) *Client {
c.clientTransport.header.Set("Accept-Profile", schema)
c.clientTransport.header.Set("Content-Profile", schema)
c.Transport.header.Set("Accept-Profile", schema)
c.Transport.header.Set("Content-Profile", schema)
return c
}

Expand All @@ -121,7 +115,7 @@ func (c *Client) Rpc(name string, count string, rpcBody interface{}) string {
}

readerBody := bytes.NewBuffer(byteBody)
url := path.Join(c.clientTransport.baseURL.Path, "rpc", name)
url := path.Join(c.Transport.baseURL.Path, "rpc", name)
req, err := http.NewRequest("POST", url, readerBody)
if err != nil {
c.ClientError = err
Expand Down Expand Up @@ -158,6 +152,7 @@ func (c *Client) Rpc(name string, count string, rpcBody interface{}) string {
type transport struct {
header http.Header
baseURL url.URL
Parent http.RoundTripper
}

func (t transport) RoundTrip(req *http.Request) (*http.Response, error) {
Expand All @@ -168,5 +163,11 @@ func (t transport) RoundTrip(req *http.Request) (*http.Response, error) {
}

req.URL = t.baseURL.ResolveReference(req.URL)

// This is only needed with usage of httpmock in testing. It would be better to initialize
// t.Parent with http.DefaultTransport and then use t.Parent.RoundTrip(req)
if t.Parent != nil {
return t.Parent.RoundTrip(req)
}
return http.DefaultTransport.RoundTrip(req)
}
2 changes: 1 addition & 1 deletion execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func executeHelper(client *Client, method string, body []byte, urlFragments []st
}

readerBody := bytes.NewBuffer(body)
baseUrl := path.Join(append([]string{client.clientTransport.baseURL.Path}, urlFragments...)...)
baseUrl := path.Join(append([]string{client.Transport.baseURL.Path}, urlFragments...)...)
req, err := http.NewRequest(method, baseUrl, readerBody)
if err != nil {
return nil, 0, fmt.Errorf("error creating request: %s", err.Error())
Expand Down
Loading