Skip to content

Commit

Permalink
feat!: add TLSClientConfiger interface to support TLSClientConfig for…
Browse files Browse the repository at this point in the history
… custom round tripper #377
  • Loading branch information
jeevatkm committed Nov 5, 2024
1 parent 8471a1a commit 5a7a5ad
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 49 deletions.
2 changes: 1 addition & 1 deletion cert_watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestClient_SetRootCertificateWatcher(t *testing.T) {
assertNil(t, err)

// Reset TLS config to ensure that previous root cert is not re-used
tr, err := client.Transport()
tr, err := client.HTTPTransport()
assertNil(t, err)
tr.TLSClientConfig = nil
client.SetTransport(tr)
Expand Down
83 changes: 63 additions & 20 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ type (

// SuccessHook type is for reacting to request success
SuccessHook func(*Client, *Response)

// TLSClientConfiger interface is to configure TLS Client configuration on custom transport
// implemented using [http.RoundTripper]
TLSClientConfiger interface {
TLSClientConfig() *tls.Config
SetTLSClientConfig(*tls.Config) error
}
)

// TransportSettings struct is used to define custom dialer and transport
Expand Down Expand Up @@ -1297,6 +1304,18 @@ func (c *Client) AddRetryHook(hook RetryHookFunc) *Client {
return c
}

// TLSClientConfig method returns the [tls.Config] from underlying client transport
// otherwise returns nil
func (c *Client) TLSClientConfig() *tls.Config {
cfg, err := c.tlsConfig()
if err != nil {
c.lock.RLock()
c.log.Errorf("%v", err)
c.lock.RUnlock()
}
return cfg
}

// SetTLSClientConfig method sets TLSClientConfig for underlying client Transport.
//
// For Example:
Expand All @@ -1308,15 +1327,26 @@ func (c *Client) AddRetryHook(hook RetryHookFunc) *Client {
// client.SetTLSClientConfig(&tls.Config{ InsecureSkipVerify: true })
//
// NOTE: This method overwrites existing [http.Transport.TLSClientConfig]
func (c *Client) SetTLSClientConfig(config *tls.Config) *Client {
transport, err := c.Transport()
if err != nil {
c.log.Errorf("%v", err)
return c
}
func (c *Client) SetTLSClientConfig(tlsConfig *tls.Config) *Client {
c.lock.Lock()
defer c.lock.Unlock()
transport.TLSClientConfig = config

// TLSClientConfiger interface handling
if tc, ok := c.httpClient.Transport.(TLSClientConfiger); ok {
if err := tc.SetTLSClientConfig(tlsConfig); err != nil {
c.log.Errorf("%v", err)
}
return c
}

// default standard transport handling
transport, ok := c.httpClient.Transport.(*http.Transport)
if !ok {
c.log.Errorf("SetTLSClientConfig: %v", ErrNotHttpTransportType)
return c
}
transport.TLSClientConfig = tlsConfig

return c
}

Expand All @@ -1333,7 +1363,7 @@ func (c *Client) ProxyURL() *url.URL {
//
// OR you could also set Proxy via environment variable, refer to [http.ProxyFromEnvironment]
func (c *Client) SetProxy(proxyURL string) *Client {
transport, err := c.Transport()
transport, err := c.HTTPTransport()
if err != nil {
c.log.Errorf("%v", err)
return c
Expand All @@ -1356,7 +1386,7 @@ func (c *Client) SetProxy(proxyURL string) *Client {
//
// client.RemoveProxy()
func (c *Client) RemoveProxy() *Client {
transport, err := c.Transport()
transport, err := c.HTTPTransport()
if err != nil {
c.log.Errorf("%v", err)
return c
Expand Down Expand Up @@ -1554,11 +1584,9 @@ func (c *Client) SetOutputDirectory(dirPath string) *Client {
return c
}

// Transport method returns [http.Transport] currently in use or error
// in case the currently used `transport` is not a [http.Transport].
//
// Since v2.8.0 has become exported method.
func (c *Client) Transport() (*http.Transport, error) {
// HTTPTransport method does type assertion and returns [http.Transport]
// from the client instance, if type assertion fails it returns an error
func (c *Client) HTTPTransport() (*http.Transport, error) {
c.lock.RLock()
defer c.lock.RUnlock()
if transport, ok := c.httpClient.Transport.(*http.Transport); ok {
Expand All @@ -1567,6 +1595,14 @@ func (c *Client) Transport() (*http.Transport, error) {
return nil, ErrNotHttpTransportType
}

// Transport method returns underlying client transport referance as-is
// i.e., [http.RoundTripper]
func (c *Client) Transport() http.RoundTripper {
c.lock.RLock()
defer c.lock.RUnlock()
return c.httpClient.Transport
}

// SetTransport method sets custom [http.Transport] or any [http.RoundTripper]
// compatible interface implementation in the Resty client.
//
Expand All @@ -1579,8 +1615,9 @@ func (c *Client) Transport() (*http.Transport, error) {
// client.SetTransport(transport)
//
// NOTE:
// - If transport is not the type of `*http.Transport`, then you may not be able to
// take advantage of some of the Resty client settings.
// - If transport is not the type of [http.Transport], you may lose the
// ability to set a few Resty client settings. However, if you implement
// [TLSClientConfiger] interface, then TLS client config is possible to set.
// - It overwrites the Resty client transport instance and its configurations.
func (c *Client) SetTransport(transport http.RoundTripper) *Client {
c.lock.Lock()
Expand Down Expand Up @@ -2041,12 +2078,18 @@ func (c *Client) execute(req *Request) (*Response, error) {

// getting TLS client config if not exists then create one
func (c *Client) tlsConfig() (*tls.Config, error) {
transport, err := c.Transport()
if err != nil {
return nil, err
}
c.lock.Lock()
defer c.lock.Unlock()

if tc, ok := c.httpClient.Transport.(TLSClientConfiger); ok {
return tc.TLSClientConfig(), nil
}

transport, ok := c.httpClient.Transport.(*http.Transport)
if !ok {
return nil, ErrNotHttpTransportType
}

if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}
Expand Down
119 changes: 92 additions & 27 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func TestClientSetCertificates(t *testing.T) {
client := dcnl()
client.SetCertificates(tls.Certificate{})

transport, err := client.Transport()
transport, err := client.HTTPTransport()

assertNil(t, err)
assertEqual(t, 1, len(transport.TLSClientConfig.Certificates))
Expand All @@ -275,7 +275,7 @@ func TestClientSetRootCertificate(t *testing.T) {
client := dcnl()
client.SetRootCertificate(filepath.Join(getTestDataPath(), "sample-root.pem"))

transport, err := client.Transport()
transport, err := client.HTTPTransport()

assertNil(t, err)
assertNotNil(t, transport.TLSClientConfig.RootCAs)
Expand All @@ -285,7 +285,7 @@ func TestClientSetRootCertificate(t *testing.T) {
client := dcnl()
client.SetRootCertificate(filepath.Join(getTestDataPath(), "not-exists-sample-root.pem"))

transport, err := client.Transport()
transport, err := client.HTTPTransport()

assertNil(t, err)
assertNil(t, transport.TLSClientConfig)
Expand All @@ -298,23 +298,30 @@ func TestClientSetRootCertificate(t *testing.T) {

client.SetRootCertificateFromString(string(rootPemData))

transport, err := client.Transport()
transport, err := client.HTTPTransport()

assertNil(t, err)
assertNotNil(t, transport.TLSClientConfig.RootCAs)
})
}

type CustomRoundTripper1 struct{}

// RoundTrip just for test
func (rt *CustomRoundTripper1) RoundTrip(_ *http.Request) (*http.Response, error) {
return &http.Response{}, nil
}

func TestClientCACertificateFromStringErrorTls(t *testing.T) {
t.Run("root cert string", func(t *testing.T) {
client := NewWithClient(&http.Client{})
client.outputLogTo(io.Discard)

rootPemData, err := os.ReadFile(filepath.Join(getTestDataPath(), "sample-root.pem"))
assertNil(t, err)
rt := &CustomRoundTripper{}
rt := &CustomRoundTripper1{}
client.SetTransport(rt)
transport, err := client.Transport()
transport, err := client.HTTPTransport()

client.SetRootCertificateFromString(string(rootPemData))

Expand All @@ -329,9 +336,9 @@ func TestClientCACertificateFromStringErrorTls(t *testing.T) {

rootPemData, err := os.ReadFile(filepath.Join(getTestDataPath(), "sample-root.pem"))
assertNil(t, err)
rt := &CustomRoundTripper{}
rt := &CustomRoundTripper1{}
client.SetTransport(rt)
transport, err := client.Transport()
transport, err := client.HTTPTransport()

client.SetClientRootCertificateFromString(string(rootPemData))

Expand All @@ -341,11 +348,78 @@ func TestClientCACertificateFromStringErrorTls(t *testing.T) {
})
}

// CustomRoundTripper2 just for test
type CustomRoundTripper2 struct {
tlsConfig *tls.Config
returnErr bool
}

// RoundTrip just for test
func (rt *CustomRoundTripper2) RoundTrip(_ *http.Request) (*http.Response, error) {
return &http.Response{}, nil
}

func (rt *CustomRoundTripper2) TLSClientConfig() *tls.Config {
return rt.tlsConfig
}
func (rt *CustomRoundTripper2) SetTLSClientConfig(tlsConfig *tls.Config) error {
if rt.returnErr {
return errors.New("test mock error")
}
rt.tlsConfig = tlsConfig
return nil
}

func TestClientTLSConfigerInterface(t *testing.T) {

t.Run("assert transport and custom roundtripper", func(t *testing.T) {
c := dcnl()

assertNotNil(t, c.Transport())
assertEqual(t, "http.Transport", inferType(c.Transport()).String())

ct := &CustomRoundTripper2{}
c.SetTransport(ct)
assertNotNil(t, c.Transport())
assertEqual(t, "resty.CustomRoundTripper2", inferType(c.Transport()).String())
})

t.Run("get and set tls config", func(t *testing.T) {
c := dcnl()

ct := &CustomRoundTripper2{}
c.SetTransport(ct)

tlsConfig := &tls.Config{InsecureSkipVerify: true}
c.SetTLSClientConfig(tlsConfig)
assertEqual(t, tlsConfig, c.TLSClientConfig())
})

t.Run("get tls config error", func(t *testing.T) {
c := dcnl()

ct := &CustomRoundTripper1{}
c.SetTransport(ct)
assertNil(t, c.TLSClientConfig())
})

t.Run("set tls config error", func(t *testing.T) {
c := dcnl()

ct := &CustomRoundTripper2{returnErr: true}
c.SetTransport(ct)

tlsConfig := &tls.Config{InsecureSkipVerify: true}
c.SetTLSClientConfig(tlsConfig)
assertNil(t, c.TLSClientConfig())
})
}

func TestClientSetClientRootCertificate(t *testing.T) {
client := dcnl()
client.SetClientRootCertificate(filepath.Join(getTestDataPath(), "sample-root.pem"))

transport, err := client.Transport()
transport, err := client.HTTPTransport()

assertNil(t, err)
assertNotNil(t, transport.TLSClientConfig.ClientCAs)
Expand All @@ -355,7 +429,7 @@ func TestClientSetClientRootCertificateNotExists(t *testing.T) {
client := dcnl()
client.SetClientRootCertificate(filepath.Join(getTestDataPath(), "not-exists-sample-root.pem"))

transport, err := client.Transport()
transport, err := client.HTTPTransport()

assertNil(t, err)
assertNil(t, transport.TLSClientConfig)
Expand All @@ -368,7 +442,7 @@ func TestClientSetClientRootCertificateWatcher(t *testing.T) {
PoolInterval: time.Second * 1,
})

transport, err := client.Transport()
transport, err := client.HTTPTransport()

assertNil(t, err)
assertNotNil(t, transport.TLSClientConfig.ClientCAs)
Expand All @@ -378,7 +452,7 @@ func TestClientSetClientRootCertificateWatcher(t *testing.T) {
client := dcnl()
client.SetClientRootCertificateWatcher(filepath.Join(getTestDataPath(), "not-exists-sample-root.pem"), nil)

transport, err := client.Transport()
transport, err := client.HTTPTransport()

assertNil(t, err)
assertNil(t, transport.TLSClientConfig)
Expand All @@ -392,7 +466,7 @@ func TestClientSetClientRootCertificateFromString(t *testing.T) {

client.SetClientRootCertificateFromString(string(rootPemData))

transport, err := client.Transport()
transport, err := client.HTTPTransport()

assertNil(t, err)
assertNotNil(t, transport.TLSClientConfig.ClientCAs)
Expand Down Expand Up @@ -444,7 +518,7 @@ func TestClientSetTransport(t *testing.T) {
},
}
client.SetTransport(transport)
transportInUse, err := client.Transport()
transportInUse, err := client.HTTPTransport()

assertNil(t, err)
assertEqual(t, true, transport == transportInUse)
Expand Down Expand Up @@ -502,8 +576,8 @@ func TestClientSettingsCoverage(t *testing.T) {

// [Start] Custom Transport scenario
ct := dcnl()
ct.SetTransport(&CustomRoundTripper{})
_, err := ct.Transport()
ct.SetTransport(&CustomRoundTripper1{})
_, err := ct.HTTPTransport()
assertNotNil(t, err)
assertEqual(t, ErrNotHttpTransportType, err)

Expand Down Expand Up @@ -685,10 +759,10 @@ func TestClientRoundTripper(t *testing.T) {
c := NewWithClient(&http.Client{})
c.outputLogTo(io.Discard)

rt := &CustomRoundTripper{}
rt := &CustomRoundTripper2{}
c.SetTransport(rt)

ct, err := c.Transport()
ct, err := c.HTTPTransport()
assertNotNil(t, err)
assertNil(t, ct)
assertEqual(t, ErrNotHttpTransportType, err)
Expand Down Expand Up @@ -734,15 +808,6 @@ func TestClientDebugBodySizeLimit(t *testing.T) {
}
}

// CustomRoundTripper just for test
type CustomRoundTripper struct {
}

// RoundTrip just for test
func (rt *CustomRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
return &http.Response{}, nil
}

func TestGzipCompress(t *testing.T) {
ts := createGenericServer(t)
defer ts.Close()
Expand Down
Loading

0 comments on commit 5a7a5ad

Please sign in to comment.