From 08171029e735b4832aaa5d968293edac7acf2f56 Mon Sep 17 00:00:00 2001 From: "Jeevanandam M." Date: Mon, 4 Nov 2024 19:32:30 -0800 Subject: [PATCH] feat!: add TLSClientConfiger interface to support TLSClientConfig for custom round tripper #377 #810 (#901) --- cert_watcher_test.go | 2 +- client.go | 83 ++++++++++++++++++++++-------- client_test.go | 119 +++++++++++++++++++++++++++++++++---------- request_test.go | 2 +- 4 files changed, 157 insertions(+), 49 deletions(-) diff --git a/cert_watcher_test.go b/cert_watcher_test.go index e3ccd554..7f2d4b98 100644 --- a/cert_watcher_test.go +++ b/cert_watcher_test.go @@ -94,7 +94,7 @@ func TestClient_SetRootCertificateWatcher(t *testing.T) { assertNil(t, err) // Reset TLS config to ensure that previous root cert is not re-used - tr, err := client.Transport() + tr, err := client.HTTPTransport() assertNil(t, err) tr.TLSClientConfig = nil client.SetTransport(tr) diff --git a/client.go b/client.go index f42f36a9..235a9e00 100644 --- a/client.go +++ b/client.go @@ -99,6 +99,13 @@ type ( // SuccessHook type is for reacting to request success SuccessHook func(*Client, *Response) + + // TLSClientConfiger interface is to configure TLS Client configuration on custom transport + // implemented using [http.RoundTripper] + TLSClientConfiger interface { + TLSClientConfig() *tls.Config + SetTLSClientConfig(*tls.Config) error + } ) // TransportSettings struct is used to define custom dialer and transport @@ -1297,6 +1304,18 @@ func (c *Client) AddRetryHook(hook RetryHookFunc) *Client { return c } +// TLSClientConfig method returns the [tls.Config] from underlying client transport +// otherwise returns nil +func (c *Client) TLSClientConfig() *tls.Config { + cfg, err := c.tlsConfig() + if err != nil { + c.lock.RLock() + c.log.Errorf("%v", err) + c.lock.RUnlock() + } + return cfg +} + // SetTLSClientConfig method sets TLSClientConfig for underlying client Transport. // // For Example: @@ -1308,15 +1327,26 @@ func (c *Client) AddRetryHook(hook RetryHookFunc) *Client { // client.SetTLSClientConfig(&tls.Config{ InsecureSkipVerify: true }) // // NOTE: This method overwrites existing [http.Transport.TLSClientConfig] -func (c *Client) SetTLSClientConfig(config *tls.Config) *Client { - transport, err := c.Transport() - if err != nil { - c.log.Errorf("%v", err) - return c - } +func (c *Client) SetTLSClientConfig(tlsConfig *tls.Config) *Client { c.lock.Lock() defer c.lock.Unlock() - transport.TLSClientConfig = config + + // TLSClientConfiger interface handling + if tc, ok := c.httpClient.Transport.(TLSClientConfiger); ok { + if err := tc.SetTLSClientConfig(tlsConfig); err != nil { + c.log.Errorf("%v", err) + } + return c + } + + // default standard transport handling + transport, ok := c.httpClient.Transport.(*http.Transport) + if !ok { + c.log.Errorf("SetTLSClientConfig: %v", ErrNotHttpTransportType) + return c + } + transport.TLSClientConfig = tlsConfig + return c } @@ -1333,7 +1363,7 @@ func (c *Client) ProxyURL() *url.URL { // // OR you could also set Proxy via environment variable, refer to [http.ProxyFromEnvironment] func (c *Client) SetProxy(proxyURL string) *Client { - transport, err := c.Transport() + transport, err := c.HTTPTransport() if err != nil { c.log.Errorf("%v", err) return c @@ -1356,7 +1386,7 @@ func (c *Client) SetProxy(proxyURL string) *Client { // // client.RemoveProxy() func (c *Client) RemoveProxy() *Client { - transport, err := c.Transport() + transport, err := c.HTTPTransport() if err != nil { c.log.Errorf("%v", err) return c @@ -1554,11 +1584,9 @@ func (c *Client) SetOutputDirectory(dirPath string) *Client { return c } -// Transport method returns [http.Transport] currently in use or error -// in case the currently used `transport` is not a [http.Transport]. -// -// Since v2.8.0 has become exported method. -func (c *Client) Transport() (*http.Transport, error) { +// HTTPTransport method does type assertion and returns [http.Transport] +// from the client instance, if type assertion fails it returns an error +func (c *Client) HTTPTransport() (*http.Transport, error) { c.lock.RLock() defer c.lock.RUnlock() if transport, ok := c.httpClient.Transport.(*http.Transport); ok { @@ -1567,6 +1595,14 @@ func (c *Client) Transport() (*http.Transport, error) { return nil, ErrNotHttpTransportType } +// Transport method returns underlying client transport referance as-is +// i.e., [http.RoundTripper] +func (c *Client) Transport() http.RoundTripper { + c.lock.RLock() + defer c.lock.RUnlock() + return c.httpClient.Transport +} + // SetTransport method sets custom [http.Transport] or any [http.RoundTripper] // compatible interface implementation in the Resty client. // @@ -1579,8 +1615,9 @@ func (c *Client) Transport() (*http.Transport, error) { // client.SetTransport(transport) // // NOTE: -// - If transport is not the type of `*http.Transport`, then you may not be able to -// take advantage of some of the Resty client settings. +// - If transport is not the type of [http.Transport], you may lose the +// ability to set a few Resty client settings. However, if you implement +// [TLSClientConfiger] interface, then TLS client config is possible to set. // - It overwrites the Resty client transport instance and its configurations. func (c *Client) SetTransport(transport http.RoundTripper) *Client { c.lock.Lock() @@ -2041,12 +2078,18 @@ func (c *Client) execute(req *Request) (*Response, error) { // getting TLS client config if not exists then create one func (c *Client) tlsConfig() (*tls.Config, error) { - transport, err := c.Transport() - if err != nil { - return nil, err - } c.lock.Lock() defer c.lock.Unlock() + + if tc, ok := c.httpClient.Transport.(TLSClientConfiger); ok { + return tc.TLSClientConfig(), nil + } + + transport, ok := c.httpClient.Transport.(*http.Transport) + if !ok { + return nil, ErrNotHttpTransportType + } + if transport.TLSClientConfig == nil { transport.TLSClientConfig = &tls.Config{} } diff --git a/client_test.go b/client_test.go index 86ee49e1..9d91efa4 100644 --- a/client_test.go +++ b/client_test.go @@ -264,7 +264,7 @@ func TestClientSetCertificates(t *testing.T) { client := dcnl() client.SetCertificates(tls.Certificate{}) - transport, err := client.Transport() + transport, err := client.HTTPTransport() assertNil(t, err) assertEqual(t, 1, len(transport.TLSClientConfig.Certificates)) @@ -275,7 +275,7 @@ func TestClientSetRootCertificate(t *testing.T) { client := dcnl() client.SetRootCertificate(filepath.Join(getTestDataPath(), "sample-root.pem")) - transport, err := client.Transport() + transport, err := client.HTTPTransport() assertNil(t, err) assertNotNil(t, transport.TLSClientConfig.RootCAs) @@ -285,7 +285,7 @@ func TestClientSetRootCertificate(t *testing.T) { client := dcnl() client.SetRootCertificate(filepath.Join(getTestDataPath(), "not-exists-sample-root.pem")) - transport, err := client.Transport() + transport, err := client.HTTPTransport() assertNil(t, err) assertNil(t, transport.TLSClientConfig) @@ -298,13 +298,20 @@ func TestClientSetRootCertificate(t *testing.T) { client.SetRootCertificateFromString(string(rootPemData)) - transport, err := client.Transport() + transport, err := client.HTTPTransport() assertNil(t, err) assertNotNil(t, transport.TLSClientConfig.RootCAs) }) } +type CustomRoundTripper1 struct{} + +// RoundTrip just for test +func (rt *CustomRoundTripper1) RoundTrip(_ *http.Request) (*http.Response, error) { + return &http.Response{}, nil +} + func TestClientCACertificateFromStringErrorTls(t *testing.T) { t.Run("root cert string", func(t *testing.T) { client := NewWithClient(&http.Client{}) @@ -312,9 +319,9 @@ func TestClientCACertificateFromStringErrorTls(t *testing.T) { rootPemData, err := os.ReadFile(filepath.Join(getTestDataPath(), "sample-root.pem")) assertNil(t, err) - rt := &CustomRoundTripper{} + rt := &CustomRoundTripper1{} client.SetTransport(rt) - transport, err := client.Transport() + transport, err := client.HTTPTransport() client.SetRootCertificateFromString(string(rootPemData)) @@ -329,9 +336,9 @@ func TestClientCACertificateFromStringErrorTls(t *testing.T) { rootPemData, err := os.ReadFile(filepath.Join(getTestDataPath(), "sample-root.pem")) assertNil(t, err) - rt := &CustomRoundTripper{} + rt := &CustomRoundTripper1{} client.SetTransport(rt) - transport, err := client.Transport() + transport, err := client.HTTPTransport() client.SetClientRootCertificateFromString(string(rootPemData)) @@ -341,11 +348,78 @@ func TestClientCACertificateFromStringErrorTls(t *testing.T) { }) } +// CustomRoundTripper2 just for test +type CustomRoundTripper2 struct { + tlsConfig *tls.Config + returnErr bool +} + +// RoundTrip just for test +func (rt *CustomRoundTripper2) RoundTrip(_ *http.Request) (*http.Response, error) { + return &http.Response{}, nil +} + +func (rt *CustomRoundTripper2) TLSClientConfig() *tls.Config { + return rt.tlsConfig +} +func (rt *CustomRoundTripper2) SetTLSClientConfig(tlsConfig *tls.Config) error { + if rt.returnErr { + return errors.New("test mock error") + } + rt.tlsConfig = tlsConfig + return nil +} + +func TestClientTLSConfigerInterface(t *testing.T) { + + t.Run("assert transport and custom roundtripper", func(t *testing.T) { + c := dcnl() + + assertNotNil(t, c.Transport()) + assertEqual(t, "http.Transport", inferType(c.Transport()).String()) + + ct := &CustomRoundTripper2{} + c.SetTransport(ct) + assertNotNil(t, c.Transport()) + assertEqual(t, "resty.CustomRoundTripper2", inferType(c.Transport()).String()) + }) + + t.Run("get and set tls config", func(t *testing.T) { + c := dcnl() + + ct := &CustomRoundTripper2{} + c.SetTransport(ct) + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + c.SetTLSClientConfig(tlsConfig) + assertEqual(t, tlsConfig, c.TLSClientConfig()) + }) + + t.Run("get tls config error", func(t *testing.T) { + c := dcnl() + + ct := &CustomRoundTripper1{} + c.SetTransport(ct) + assertNil(t, c.TLSClientConfig()) + }) + + t.Run("set tls config error", func(t *testing.T) { + c := dcnl() + + ct := &CustomRoundTripper2{returnErr: true} + c.SetTransport(ct) + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + c.SetTLSClientConfig(tlsConfig) + assertNil(t, c.TLSClientConfig()) + }) +} + func TestClientSetClientRootCertificate(t *testing.T) { client := dcnl() client.SetClientRootCertificate(filepath.Join(getTestDataPath(), "sample-root.pem")) - transport, err := client.Transport() + transport, err := client.HTTPTransport() assertNil(t, err) assertNotNil(t, transport.TLSClientConfig.ClientCAs) @@ -355,7 +429,7 @@ func TestClientSetClientRootCertificateNotExists(t *testing.T) { client := dcnl() client.SetClientRootCertificate(filepath.Join(getTestDataPath(), "not-exists-sample-root.pem")) - transport, err := client.Transport() + transport, err := client.HTTPTransport() assertNil(t, err) assertNil(t, transport.TLSClientConfig) @@ -368,7 +442,7 @@ func TestClientSetClientRootCertificateWatcher(t *testing.T) { PoolInterval: time.Second * 1, }) - transport, err := client.Transport() + transport, err := client.HTTPTransport() assertNil(t, err) assertNotNil(t, transport.TLSClientConfig.ClientCAs) @@ -378,7 +452,7 @@ func TestClientSetClientRootCertificateWatcher(t *testing.T) { client := dcnl() client.SetClientRootCertificateWatcher(filepath.Join(getTestDataPath(), "not-exists-sample-root.pem"), nil) - transport, err := client.Transport() + transport, err := client.HTTPTransport() assertNil(t, err) assertNil(t, transport.TLSClientConfig) @@ -392,7 +466,7 @@ func TestClientSetClientRootCertificateFromString(t *testing.T) { client.SetClientRootCertificateFromString(string(rootPemData)) - transport, err := client.Transport() + transport, err := client.HTTPTransport() assertNil(t, err) assertNotNil(t, transport.TLSClientConfig.ClientCAs) @@ -444,7 +518,7 @@ func TestClientSetTransport(t *testing.T) { }, } client.SetTransport(transport) - transportInUse, err := client.Transport() + transportInUse, err := client.HTTPTransport() assertNil(t, err) assertEqual(t, true, transport == transportInUse) @@ -502,8 +576,8 @@ func TestClientSettingsCoverage(t *testing.T) { // [Start] Custom Transport scenario ct := dcnl() - ct.SetTransport(&CustomRoundTripper{}) - _, err := ct.Transport() + ct.SetTransport(&CustomRoundTripper1{}) + _, err := ct.HTTPTransport() assertNotNil(t, err) assertEqual(t, ErrNotHttpTransportType, err) @@ -685,10 +759,10 @@ func TestClientRoundTripper(t *testing.T) { c := NewWithClient(&http.Client{}) c.outputLogTo(io.Discard) - rt := &CustomRoundTripper{} + rt := &CustomRoundTripper2{} c.SetTransport(rt) - ct, err := c.Transport() + ct, err := c.HTTPTransport() assertNotNil(t, err) assertNil(t, ct) assertEqual(t, ErrNotHttpTransportType, err) @@ -734,15 +808,6 @@ func TestClientDebugBodySizeLimit(t *testing.T) { } } -// CustomRoundTripper just for test -type CustomRoundTripper struct { -} - -// RoundTrip just for test -func (rt *CustomRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) { - return &http.Response{}, nil -} - func TestGzipCompress(t *testing.T) { ts := createGenericServer(t) defer ts.Close() diff --git a/request_test.go b/request_test.go index 3ca22f0a..abd3667f 100644 --- a/request_test.go +++ b/request_test.go @@ -1108,7 +1108,7 @@ func TestRawFileUploadByBody(t *testing.T) { func TestProxySetting(t *testing.T) { c := dcnl() - transport, err := c.Transport() + transport, err := c.HTTPTransport() assertNil(t, err)