Skip to content

Commit

Permalink
feat(redirect)!: update error type on no redirect policy and cleanup #…
Browse files Browse the repository at this point in the history
  • Loading branch information
jeevatkm authored Nov 3, 2024
1 parent f6dca4a commit 8471a1a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 30 deletions.
17 changes: 6 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1131,23 +1131,18 @@ func (c *Client) newErrorInterface() any {
// SetRedirectPolicy method sets the redirect policy for the client. Resty provides ready-to-use
// redirect policies. Wanna create one for yourself, refer to `redirect.go`.
//
// client.SetRedirectPolicy(FlexibleRedirectPolicy(20))
// client.SetRedirectPolicy(resty.FlexibleRedirectPolicy(20))
//
// // Need multiple redirect policies together
// client.SetRedirectPolicy(FlexibleRedirectPolicy(20), DomainCheckRedirectPolicy("host1.com", "host2.net"))
func (c *Client) SetRedirectPolicy(policies ...any) *Client {
for _, p := range policies {
if _, ok := p.(RedirectPolicy); !ok {
c.log.Errorf("%v does not implement resty.RedirectPolicy (missing Apply method)",
functionName(p))
}
}

// client.SetRedirectPolicy(resty.FlexibleRedirectPolicy(20), resty.DomainCheckRedirectPolicy("host1.com", "host2.net"))
//
// NOTE: It overwrites the previous redirect policies in the client instance.
func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client {
c.lock.Lock()
defer c.lock.Unlock()
c.httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
for _, p := range policies {
if err := p.(RedirectPolicy).Apply(req, via); err != nil {
if err := p.Apply(req, via); err != nil {
return err
}
}
Expand Down
18 changes: 10 additions & 8 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,21 @@ func TestClientRedirectPolicy(t *testing.T) {
ts := createRedirectServer(t)
defer ts.Close()

c := dcnl().SetRedirectPolicy(FlexibleRedirectPolicy(20))
_, err := c.R().Get(ts.URL + "/redirect-1")
c := dcnl().SetRedirectPolicy(FlexibleRedirectPolicy(20), DomainCheckRedirectPolicy("127.0.0.1"))
_, err := c.R().
SetHeader("Name1", "Value1").
SetHeader("Name2", "Value2").
SetHeader("Name3", "Value3").
Get(ts.URL + "/redirect-1")

assertEqual(t, true, (err.Error() == "Get /redirect-21: stopped after 20 redirects" ||
err.Error() == "Get \"/redirect-21\": stopped after 20 redirects"))

c.SetRedirectPolicy(NoRedirectPolicy())
_, err = c.R().Get(ts.URL + "/redirect-1")
assertEqual(t, true, (err.Error() == "Get /redirect-2: resty: auto redirect is disabled" ||
err.Error() == "Get \"/redirect-2\": resty: auto redirect is disabled"))
res, err := c.R().Get(ts.URL + "/redirect-1")
assertNil(t, err)
assertEqual(t, http.StatusTemporaryRedirect, res.StatusCode())
assertEqual(t, `<a href="/redirect-2">Temporary Redirect</a>.`, res.String())
}

func TestClientTimeout(t *testing.T) {
Expand Down Expand Up @@ -485,9 +490,6 @@ func TestClientSettingsCoverage(t *testing.T) {
c.SetAuthToken(authToken)
assertEqual(t, authToken, c.AuthToken())

type brokenRedirectPolicy struct{}
c.SetRedirectPolicy(&brokenRedirectPolicy{})

c.SetCloseConnection(true)

c.DisableDebug()
Expand Down
19 changes: 8 additions & 11 deletions redirect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) 2015-2024 Jeevanandam M ([email protected]), All rights reserved.
// Copyright (c) 2015-present Jeevanandam M ([email protected]), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
// SPDX-License-Identifier: MIT

package resty

Expand All @@ -12,8 +13,6 @@ import (
"strings"
)

var ErrAutoRedirectDisabled = errors.New("resty: auto redirect is disabled")

type (
// RedirectPolicy to regulate the redirects in the Resty client.
// Objects implementing the [RedirectPolicy] interface can be registered as
Expand All @@ -35,12 +34,12 @@ func (f RedirectPolicyFunc) Apply(req *http.Request, via []*http.Request) error
return f(req, via)
}

// NoRedirectPolicy is used to disable redirects in the Resty client
// NoRedirectPolicy is used to disable the redirects in the Resty client
//
// resty.SetRedirectPolicy(NoRedirectPolicy())
// resty.SetRedirectPolicy(resty.NoRedirectPolicy())
func NoRedirectPolicy() RedirectPolicy {
return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
return ErrAutoRedirectDisabled
return http.ErrUseLastResponse
})
}

Expand All @@ -60,22 +59,20 @@ func FlexibleRedirectPolicy(noOfRedirect int) RedirectPolicy {
// DomainCheckRedirectPolicy method is convenient for defining domain name redirect rules in Resty clients.
// Redirect is allowed only for the host mentioned in the policy.
//
// resty.SetRedirectPolicy(DomainCheckRedirectPolicy("host1.com", "host2.org", "host3.net"))
// resty.SetRedirectPolicy(resty.DomainCheckRedirectPolicy("host1.com", "host2.org", "host3.net"))
func DomainCheckRedirectPolicy(hostnames ...string) RedirectPolicy {
hosts := make(map[string]bool)
for _, h := range hostnames {
hosts[strings.ToLower(h)] = true
}

fn := RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
if ok := hosts[getHostname(req.URL.Host)]; !ok {
return errors.New("redirect is not allowed as per DomainCheckRedirectPolicy")
}

checkHostAndAddHeaders(req, via[0])
return nil
})

return fn
}

func getHostname(host string) (hostname string) {
Expand Down

0 comments on commit 8471a1a

Please sign in to comment.