Skip to content

Commit

Permalink
refactor(enhancement): use std lib functions (#912)
Browse files Browse the repository at this point in the history
- curl cmd request body error scenario and removed pointer on string
- redirect add error prefix and remove default header
- test case additions
  • Loading branch information
jeevatkm authored Nov 16, 2024
1 parent a13791b commit e134807
Show file tree
Hide file tree
Showing 12 changed files with 102 additions and 37 deletions.
15 changes: 3 additions & 12 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -915,10 +915,11 @@ func (c *Client) ContentDecompressors() map[string]ContentDecompressor {
//
// [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110
func (c *Client) AddContentDecompressor(k string, d ContentDecompressor) *Client {
c.insertFirstContentDecompressor(k)

c.lock.Lock()
defer c.lock.Unlock()
if !slices.Contains(c.contentDecompressorKeys, k) {
c.contentDecompressorKeys = slices.Insert(c.contentDecompressorKeys, 0, k)
}
c.contentDecompressors[k] = d
return c
}
Expand Down Expand Up @@ -955,16 +956,6 @@ func (c *Client) SetContentDecompressorKeys(keys []string) *Client {
return c
}

func (c *Client) insertFirstContentDecompressor(k string) {
c.lock.Lock()
defer c.lock.Unlock()
if !slices.Contains(c.contentDecompressorKeys, k) {
c.contentDecompressorKeys = append(c.contentDecompressorKeys, "")
copy(c.contentDecompressorKeys[1:], c.contentDecompressorKeys)
c.contentDecompressorKeys[0] = k
}
}

// IsDebug method returns `true` if the client is in debug mode; otherwise, it is `false`.
func (c *Client) IsDebug() bool {
c.lock.RLock()
Expand Down
28 changes: 25 additions & 3 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,7 @@ func TestClientRedirectPolicy(t *testing.T) {
SetHeader("Name3", "Value3").
Get(ts.URL + "/redirect-1")

assertEqual(t, true, (err.Error() == "Get /redirect-21: stopped after 20 redirects" ||
err.Error() == "Get \"/redirect-21\": stopped after 20 redirects"))
assertEqual(t, true, err.Error() == "Get \"/redirect-21\": resty: stopped after 20 redirects")

c.SetRedirectPolicy(NoRedirectPolicy())
res, err := c.R().Get(ts.URL + "/redirect-1")
Expand Down Expand Up @@ -804,7 +803,7 @@ func TestClientDebugBodySizeLimit(t *testing.T) {
// JSON, does not exceed limit.
{url: ts.URL + "/json", want: "{\n \"TestGet\": \"JSON response\"\n}"},
// Invalid JSON, does not exceed limit.
{url: ts.URL + "/json-invalid", want: "Debug: Response.fmtBodyString: invalid character 'T' looking for beginning of value"},
{url: ts.URL + "/json-invalid", want: "DebugLog: Response.fmtBodyString: invalid character 'T' looking for beginning of value"},
// Text, exceeds limit.
{url: ts.URL + "/long-text", want: "RESPONSE TOO LARGE"},
// JSON, exceeds limit.
Expand Down Expand Up @@ -1446,6 +1445,29 @@ func TestResponseBodyLimit(t *testing.T) {
})
}

func TestClient_executeReadAllError(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

ioReadAll = func(_ io.Reader) ([]byte, error) {
return nil, errors.New("test case error")
}
t.Cleanup(func() {
ioReadAll = io.ReadAll
})

c := dcnld()

resp, err := c.R().
SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)).
Get(ts.URL + "/json")

assertNotNil(t, err)
assertEqual(t, "test case error", err.Error())
assertEqual(t, http.StatusOK, resp.StatusCode())
assertEqual(t, "", resp.String())
}

func TestClientDebugf(t *testing.T) {
t.Run("Debug mode enabled", func(t *testing.T) {
var b bytes.Buffer
Expand Down
17 changes: 11 additions & 6 deletions curl.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
// Copyright (c) 2015-present Jeevanandam M ([email protected]), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
// SPDX-License-Identifier: MIT

package resty

import (
Expand All @@ -19,22 +24,22 @@ func buildCurlCmd(req *Request) string {
}

// 2. Generate curl cookies
// TODO validate this block of code, I think its not required since cookie captured via Headers
if cookieJar := req.client.CookieJar(); cookieJar != nil {
if cookies := cookieJar.Cookies(req.RawRequest.URL); len(cookies) > 0 {
curl += "-H " + cmdQuote(dumpCurlCookies(cookies)) + " "
}
}

// 3. Generate curl body
// 3. Generate curl body except for io.Reader and multipart request
if req.RawRequest.GetBody != nil {
body, err := req.RawRequest.GetBody()
if err != nil {
if err == nil {
buf, _ := io.ReadAll(body)
curl += "-d " + cmdQuote(string(bytes.TrimRight(buf, "\n"))) + " "
} else {
req.log.Errorf("curl: %v", err)
return ""
curl += "-d ''"
}
buf, _ := io.ReadAll(body)
curl += "-d " + cmdQuote(string(bytes.TrimRight(buf, "\n"))) + " "
}

urlString := cmdQuote(req.RawRequest.URL.String())
Expand Down
46 changes: 46 additions & 0 deletions curl_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
// Copyright (c) 2015-present Jeevanandam M ([email protected]), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
// SPDX-License-Identifier: MIT

package resty

import (
"bytes"
"errors"
"io"
"net/http"
"net/http/cookiejar"
"strings"
Expand Down Expand Up @@ -206,6 +213,45 @@ func TestCurl_buildCurlCmd(t *testing.T) {
}
}

func TestCurlRequestGetBodyError(t *testing.T) {
c := dcnl().
EnableDebug().
SetRequestMiddlewares(
PrepareRequestMiddleware,
func(_ *Client, r *Request) error {
r.RawRequest.GetBody = func() (io.ReadCloser, error) {
return nil, errors.New("test case error")
}
return nil
},
)

req := c.R().
SetBody(map[string]string{
"name": "Resty",
}).
SetCookies(
[]*http.Cookie{
{Name: "count", Value: "1"},
},
).
SetMethod(MethodPost)

assertEqual(t, "", req.GenerateCurlCommand())

curlCmdUnexecuted := req.EnableGenerateCurlOnDebug().GenerateCurlCommand()
req.DisableGenerateCurlOnDebug()

if !strings.Contains(curlCmdUnexecuted, "Cookie: count=1") ||
!strings.Contains(curlCmdUnexecuted, "curl -X POST") ||
!strings.Contains(curlCmdUnexecuted, `-d ''`) {
t.Fatal("Incomplete curl:", curlCmdUnexecuted)
} else {
t.Log("curlCmdUnexecuted: \n", curlCmdUnexecuted)
}

}

func TestCurlMiscTestCoverage(t *testing.T) {
cookieStr := dumpCurlCookies([]*http.Cookie{
{Name: "count", Value: "1"},
Expand Down
1 change: 1 addition & 0 deletions load_balancer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) 2015-present Jeevanandam M ([email protected]), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
// SPDX-License-Identifier: MIT

package resty

Expand Down
1 change: 1 addition & 0 deletions load_balancer_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) 2015-present Jeevanandam M ([email protected]), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
// SPDX-License-Identifier: MIT

package resty

Expand Down
5 changes: 2 additions & 3 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ func PrepareRequestMiddleware(c *Client, r *Request) error {
// See, [Client.SetGenerateCurlOnDebug], [Request.SetGenerateCurlOnDebug]
func GenerateCurlRequestMiddleware(c *Client, r *Request) (err error) {
if r.Debug && r.generateCurlOnDebug {
if r.resultCurlCmd == nil {
r.resultCurlCmd = new(string)
if isStringEmpty(r.resultCurlCmd) {
r.resultCurlCmd = buildCurlCmd(r)
}
*r.resultCurlCmd = buildCurlCmd(r)
}
return nil
}
Expand Down
4 changes: 1 addition & 3 deletions redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func NoRedirectPolicy() RedirectPolicy {
func FlexibleRedirectPolicy(noOfRedirect int) RedirectPolicy {
return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
if len(via) >= noOfRedirect {
return fmt.Errorf("stopped after %d redirects", noOfRedirect)
return fmt.Errorf("resty: stopped after %d redirects", noOfRedirect)
}
checkHostAndAddHeaders(req, via[0])
return nil
Expand Down Expand Up @@ -95,7 +95,5 @@ func checkHostAndAddHeaders(cur *http.Request, pre *http.Request) {
for key, val := range pre.Header {
cur.Header[key] = val
}
} else { // only library User-Agent header is added
cur.Header.Set(hdrUserAgentKey, hdrUserAgentValue)
}
}
11 changes: 5 additions & 6 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ type Request struct {
multipartBoundary string
multipartFields []*MultipartField
retryConditions []RetryConditionFunc
resultCurlCmd *string
resultCurlCmd string
generateCurlOnDebug bool
unescapeQueryParams bool
multipartErrChan chan error
Expand All @@ -103,17 +103,17 @@ func (r *Request) GenerateCurlCommand() string {
if !(r.Debug && r.generateCurlOnDebug) {
return ""
}
if r.resultCurlCmd != nil {
return *r.resultCurlCmd
if len(r.resultCurlCmd) > 0 {
return r.resultCurlCmd
}
if r.RawRequest == nil {
// mock with r.Get("/")
if err := r.client.executeRequestMiddlewares(r); err != nil {
r.log.Errorf("%v", err)
}
}
*r.resultCurlCmd = buildCurlCmd(r)
return *r.resultCurlCmd
r.resultCurlCmd = buildCurlCmd(r)
return r.resultCurlCmd
}

// SetMethod method used to set the HTTP verb for the request
Expand Down Expand Up @@ -1424,7 +1424,6 @@ func (r *Request) Clone(ctx context.Context) *Request {
rr.Time = time.Time{}
rr.Attempt = 0
rr.initTraceIfEnabled()
rr.resultCurlCmd = new(string)
r.values = make(map[string]any)
r.multipartErrChan = nil

Expand Down
8 changes: 5 additions & 3 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func (r *Response) fmtBodyString(sl int) string {
defer releaseBuffer(out)
err := json.Indent(out, r.bodyBytes, "", " ")
if err != nil {
r.Request.log.Errorf("Debug: Response.fmtBodyString: %v", err)
r.Request.log.Errorf("DebugLog: Response.fmtBodyString: %v", err)
return ""
}
return out.String()
Expand All @@ -199,6 +199,8 @@ func (r *Response) readIfRequired() {
}
}

var ioReadAll = io.ReadAll

// auto-unmarshal didn't happen, so fallback to
// old behavior of reading response as body bytes
func (r *Response) readAll() (err error) {
Expand All @@ -207,9 +209,9 @@ func (r *Response) readAll() (err error) {
}

if _, ok := r.Body.(*copyReadCloser); ok {
_, err = io.ReadAll(r.Body)
_, err = ioReadAll(r.Body)
} else {
r.bodyBytes, err = io.ReadAll(r.Body)
r.bodyBytes, err = ioReadAll(r.Body)
closeq(r.Body)
r.Body = &nopReadCloser{r: bytes.NewReader(r.bodyBytes)}
}
Expand Down
1 change: 1 addition & 0 deletions stream.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) 2015-present Jeevanandam M ([email protected]), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
// SPDX-License-Identifier: MIT

package resty

Expand Down
2 changes: 1 addition & 1 deletion util.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ func requestDebugLogger(c *Client, r *Request) {

if r.Debug && r.generateCurlOnDebug {
reqLog += "~~~ REQUEST(CURL) ~~~\n" +
fmt.Sprintf(" %v\n", *r.resultCurlCmd)
fmt.Sprintf(" %v\n", r.resultCurlCmd)
}

reqLog += "~~~ REQUEST ~~~\n" +
Expand Down

0 comments on commit e134807

Please sign in to comment.