From 72cb565d0942d1f9cec6f4a2c31e487ec3f0dde1 Mon Sep 17 00:00:00 2001 From: Jeevanandam M Date: Tue, 5 Nov 2024 20:58:40 -0800 Subject: [PATCH] feat!: add fully composable request and response middleware - PreRequestHook removed - Resty request and response middleware become exported - Any order of middleware insertion is feasible --- README.md | 4 +- client.go | 175 ++++++++++++++++++++++++++----------------- client_test.go | 63 ++++++++++------ middleware.go | 182 +++++++++++++++++++++++++++------------------ middleware_test.go | 25 ++++++- multipart_test.go | 39 ++++++++-- request.go | 14 +++- request_test.go | 26 +++---- resty.go | 27 +++---- util.go | 1 + 10 files changed, 343 insertions(+), 213 deletions(-) diff --git a/README.md b/README.md index b8cfea00..bc7a065e 100644 --- a/README.md +++ b/README.md @@ -517,7 +517,7 @@ Resty provides middleware ability to manipulate for Request and Response. It is client := resty.New() // Registering Request Middleware -client.OnBeforeRequest(func(c *resty.Client, req *resty.Request) error { +client.OnBeforeRequest(func(c *resty.Client, req *resty.Request) error { // TODO update docs // Now you have access to Client and current Request object // manipulate it as per your need @@ -525,7 +525,7 @@ client.OnBeforeRequest(func(c *resty.Client, req *resty.Request) error { }) // Registering Response Middleware -client.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error { +client.OnAfterResponse(func(c *resty.Client, resp *resty.Response) error { // TODO update docs // Now you have access to Client and current Response object // manipulate it as per your need diff --git a/client.go b/client.go index 235a9e00..52e316e0 100644 --- a/client.go +++ b/client.go @@ -87,9 +87,6 @@ type ( // ResponseMiddleware type is for response middleware, called after a response has been received ResponseMiddleware func(*Client, *Response) error - // PreRequestHook type is for the request hook, called right before the request is sent - PreRequestHook func(*Client, *http.Request) error - // DebugLogCallback type is for request and response debug log callback purpose. // It gets called before Resty logs it DebugLogCallback func(*DebugLog) @@ -209,7 +206,6 @@ type Client struct { unescapeQueryParams bool loadBalancer LoadBalancer beforeRequest []RequestMiddleware - udBeforeRequest []RequestMiddleware afterResponse []ResponseMiddleware errorHooks []ErrorHook invalidHooks []ErrorHook @@ -220,9 +216,6 @@ type Client struct { contentDecompressorKeys []string contentDecompressors map[string]ContentDecompressor certWatcherStopChan chan bool - - // TODO don't put mutex now, it may go away - preReqHook PreRequestHook } // User type is to hold an username and password information @@ -611,14 +604,14 @@ func (c *Client) SetDigestAuth(username, password string) *Client { c.lock.Lock() oldTransport := c.httpClient.Transport c.lock.Unlock() - c.OnBeforeRequest(func(c *Client, _ *Request) error { + c.AddRequestMiddleware(func(c *Client, _ *Request) error { c.httpClient.Transport = &digestTransport{ digestCredentials: digestCredentials{username, password}, transport: oldTransport, } return nil }) - c.OnAfterResponse(func(c *Client, _ *Response) error { + c.AddResponseMiddleware(func(c *Client, _ *Response) error { c.httpClient.Transport = oldTransport return nil }) @@ -676,26 +669,78 @@ func (c *Client) NewRequest() *Request { return c.R() } -func (c *Client) beforeRequestMiddlewares() []RequestMiddleware { - c.lock.RLock() - defer c.lock.RUnlock() - return c.udBeforeRequest +// SetRequestMiddlewares method allows Resty users to override the default request +// middlewares sequence +// +// client := New() +// defer client.Close() +// +// client.SetRequestMiddlewares( +// CustomRequest1Middleware, +// CustomRequest2Middleware, +// resty.PrepareRequestMiddleware, // after this, Request.RawRequest is available +// resty.GenerateCurlRequestMiddleware, +// CustomRequest3Middleware, +// CustomRequest4Middleware, +// ) +// +// See, [Client.AddRequestMiddleware] +// +// NOTE: +// - It overwrites the existing request middleware list. +// - Be sure to include Resty request middlewares in the request chain at the appropriate spot. +func (c *Client) SetRequestMiddlewares(middlewares ...RequestMiddleware) *Client { + c.lock.Lock() + defer c.lock.Unlock() + c.beforeRequest = middlewares + return c } -// OnBeforeRequest method appends a request middleware to the before request chain. -// The user-defined middlewares are applied before the default Resty request middlewares. -// After all middlewares have been applied, the request is sent from Resty to the host server. +// SetResponseMiddlewares method allows Resty users to override the default response +// middlewares sequence // -// client.OnBeforeRequest(func(c *resty.Client, r *resty.Request) error { -// // Now you have access to the Client and Request instance -// // manipulate it as per your need +// client := New() +// defer client.Close() // -// return nil // if its successful otherwise return error -// }) -func (c *Client) OnBeforeRequest(m RequestMiddleware) *Client { +// client.SetResponseMiddlewares( +// CustomResponse1Middleware, +// CustomResponse2Middleware, +// resty.AutoParseResponseMiddleware, // before this, body is not read except on debug flow +// CustomResponse3Middleware, +// resty.SaveToFileResponseMiddleware, // See, Request.SetOutputFile +// CustomResponse4Middleware, +// CustomResponse5Middleware, +// ) +// +// See, [Client.AddResponseMiddleware] +// +// NOTE: +// - It overwrites the existing request middleware list. +// - Be sure to include Resty response middlewares in the response chain at the appropriate spot. +func (c *Client) SetResponseMiddlewares(middlewares ...ResponseMiddleware) *Client { + c.lock.Lock() + defer c.lock.Unlock() + c.afterResponse = middlewares + return c +} + +// AddRequestMiddleware method appends a request middleware to the before request chain. +// After all requests, middlewares are applied, and the request is sent to the host server. +// +// client.AddRequestMiddleware(func(c *resty.Client, r *resty.Request) error { +// // Now you have access to the Client and Request instance +// // manipulate it as per your need +// +// return nil // if its successful otherwise return error +// }) +// +// NOTE: +// - Do not use [Client] setter methods within Request middleware; deadlock will happen. +func (c *Client) AddRequestMiddleware(m RequestMiddleware) *Client { c.lock.Lock() defer c.lock.Unlock() - c.udBeforeRequest = append(c.udBeforeRequest, m) + idx := len(c.beforeRequest) - 2 + c.beforeRequest = slices.Insert(c.beforeRequest, idx, m) return c } @@ -705,17 +750,20 @@ func (c *Client) afterResponseMiddlewares() []ResponseMiddleware { return c.afterResponse } -// OnAfterResponse method appends response middleware to the after-response chain. -// Once we receive a response from the host server, the default Resty response middleware -// gets applied, and then the user-assigned response middleware is applied. +// AddResponseMiddleware method appends response middleware to the after-response chain. +// All the response middlewares are applied; once we receive a response +// from the host server. // -// client.OnAfterResponse(func(c *resty.Client, r *resty.Response) error { -// // Now you have access to the Client and Response instance -// // manipulate it as per your need +// client.AddResponseMiddleware(func(c *resty.Client, r *resty.Response) error { +// // Now you have access to the Client and Response instance +// // manipulate it as per your need // -// return nil // if its successful otherwise return error -// }) -func (c *Client) OnAfterResponse(m ResponseMiddleware) *Client { +// return nil // if its successful otherwise return error +// }) +// +// NOTE: +// - Do not use [Client] setter methods within Response middleware; deadlock will happen. +func (c *Client) AddResponseMiddleware(m ResponseMiddleware) *Client { c.lock.Lock() defer c.lock.Unlock() c.afterResponse = append(c.afterResponse, m) @@ -736,6 +784,9 @@ func (c *Client) OnAfterResponse(m ResponseMiddleware) *Client { // // Out of the [Client.OnSuccess], [Client.OnError], [Client.OnInvalid], [Client.OnPanic] // callbacks, exactly one set will be invoked for each call to [Request.Execute] that completes. +// +// NOTE: +// - Do not use [Client] setter methods within OnError hooks; deadlock will happen. func (c *Client) OnError(h ErrorHook) *Client { c.lock.Lock() defer c.lock.Unlock() @@ -748,6 +799,9 @@ func (c *Client) OnError(h ErrorHook) *Client { // // Out of the [Client.OnSuccess], [Client.OnError], [Client.OnInvalid], [Client.OnPanic] // callbacks, exactly one set will be invoked for each call to [Request.Execute] that completes. +// +// NOTE: +// - Do not use [Client] setter methods within OnSuccess hooks; deadlock will happen. func (c *Client) OnSuccess(h SuccessHook) *Client { c.lock.Lock() defer c.lock.Unlock() @@ -760,6 +814,9 @@ func (c *Client) OnSuccess(h SuccessHook) *Client { // // Out of the [Client.OnSuccess], [Client.OnError], [Client.OnInvalid], [Client.OnPanic] // callbacks, exactly one set will be invoked for each call to [Request.Execute] that completes. +// +// NOTE: +// - Do not use [Client] setter methods within OnInvalid hooks; deadlock will happen. func (c *Client) OnInvalid(h ErrorHook) *Client { c.lock.Lock() defer c.lock.Unlock() @@ -775,6 +832,9 @@ func (c *Client) OnInvalid(h ErrorHook) *Client { // // If an [Client.OnSuccess], [Client.OnError], or [Client.OnInvalid] callback panics, // then exactly one rule can be violated. +// +// NOTE: +// - Do not use [Client] setter methods within OnPanic hooks; deadlock will happen. func (c *Client) OnPanic(h ErrorHook) *Client { c.lock.Lock() defer c.lock.Unlock() @@ -782,18 +842,6 @@ func (c *Client) OnPanic(h ErrorHook) *Client { return c } -// SetPreRequestHook method sets the given pre-request function into a resty client. -// It is called right before the request is fired. -// -// NOTE: Only one pre-request hook can be registered. Use [Client.OnBeforeRequest] for multiple. -func (c *Client) SetPreRequestHook(h PreRequestHook) *Client { - if c.preReqHook != nil { - c.log.Warnf("Overwriting an existing pre-request hook: %s", functionName(h)) - } - c.preReqHook = h - return c -} - // ContentTypeEncoders method returns all the registered content type encoders. func (c *Client) ContentTypeEncoders() map[string]ContentTypeEncoder { c.lock.RLock() @@ -1884,7 +1932,11 @@ func (c *Client) DisableGenerateCurlOnDebug() *Client { } // SetGenerateCurlOnDebug method is used to turn on/off the generate CURL command in debug mode -// at the client instance level. +// at the client instance level. It works in conjunction with debug mode. +// +// NOTE: Use with care. +// - Potential to leak sensitive data from [Request] and [Response] in the debug log. +// - Beware of memory usage since the request body is reread. // // It can be overridden at the request level; see [Request.SetGenerateCurlOnDebug] func (c *Client) SetGenerateCurlOnDebug(b bool) *Client { @@ -1993,45 +2045,28 @@ func (c *Client) Close() error { return nil } -func (c *Client) executeBefore(req *Request) error { - var err error - - // user defined on before request methods - // to modify the *resty.Request object - for _, f := range c.beforeRequestMiddlewares() { - if err = f(c, req); err != nil { - return err - } - } - - // resty middlewares +func (c *Client) executeRequestMiddlewares(req *Request) (err error) { + c.lock.RLock() + defer c.lock.RUnlock() for _, f := range c.beforeRequest { if err = f(c, req); err != nil { return err } } - - if hostHeader := req.Header.Get("Host"); hostHeader != "" { - req.RawRequest.Host = hostHeader - } - - // call pre-request if defined - if c.preReqHook != nil { - if err = c.preReqHook(c, req.RawRequest); err != nil { - return err - } - } - return nil } // Executes method executes the given `Request` object and returns // response or error. func (c *Client) execute(req *Request) (*Response, error) { - if err := c.executeBefore(req); err != nil { + if err := c.executeRequestMiddlewares(req); err != nil { return nil, err } + if hostHeader := req.Header.Get("Host"); hostHeader != "" { + req.RawRequest.Host = hostHeader + } + requestDebugLogger(c, req) req.Time = time.Now() diff --git a/client_test.go b/client_test.go index 9d91efa4..f58dcb35 100644 --- a/client_test.go +++ b/client_test.go @@ -163,12 +163,12 @@ func TestClientDigestErrors(t *testing.T) { } } -func TestOnAfterMiddleware(t *testing.T) { +func TestClientResponseMiddleware(t *testing.T) { ts := createGenericServer(t) defer ts.Close() c := dcnl() - c.OnAfterResponse(func(c *Client, res *Response) error { + c.AddResponseMiddleware(func(c *Client, res *Response) error { t.Logf("Request sent at: %v", res.Request.Time) t.Logf("Response Received at: %v", res.ReceivedAt()) @@ -176,7 +176,7 @@ func TestOnAfterMiddleware(t *testing.T) { }) resp, err := c.R(). - SetBody("OnAfterResponse: This is plain text body to server"). + SetBody("ResponseMiddleware: This is plain text body to server"). Put(ts.URL + "/plaintext") assertError(t, err) @@ -472,9 +472,9 @@ func TestClientSetClientRootCertificateFromString(t *testing.T) { assertNotNil(t, transport.TLSClientConfig.ClientCAs) } -func TestClientOnBeforeRequestModification(t *testing.T) { +func TestClientRequestMiddlewareModification(t *testing.T) { tc := dcnl() - tc.OnBeforeRequest(func(c *Client, r *Request) error { + tc.AddRequestMiddleware(func(c *Client, r *Request) error { r.SetAuthToken("This is test auth token") return nil }) @@ -600,33 +600,44 @@ func TestClientSettingsCoverage(t *testing.T) { func TestContentLengthWhenBodyIsNil(t *testing.T) { client := dcnl() - client.SetPreRequestHook(func(c *Client, r *http.Request) error { + fnPreRequestMiddleware1 := func(c *Client, r *Request) error { assertEqual(t, "0", r.Header.Get(hdrContentLengthKey)) return nil - }) + } + client.SetRequestMiddlewares( + PrepareRequestMiddleware, + fnPreRequestMiddleware1, + ) client.R().SetContentLength(true).SetBody(nil).Get("http://localhost") } -func TestClientPreRequestHook(t *testing.T) { +func TestClientPreRequestMiddlewares(t *testing.T) { client := dcnl() - client.SetPreRequestHook(func(c *Client, r *http.Request) error { + + fnPreRequestMiddleware1 := func(c *Client, r *Request) error { c.log.Debugf("I'm in Pre-Request Hook") return nil - }) + } - client.SetPreRequestHook(func(c *Client, r *http.Request) error { + fnPreRequestMiddleware2 := func(c *Client, r *Request) error { c.log.Debugf("I'm Overwriting existing Pre-Request Hook") // Reading Request `N` no of times for i := 0; i < 5; i++ { - b, _ := r.GetBody() + b, _ := r.RawRequest.GetBody() rb, _ := io.ReadAll(b) c.log.Debugf("%s %v", string(rb), len(rb)) assertEqual(t, true, len(rb) >= 45) } return nil - }) + } + + client.SetRequestMiddlewares( + PrepareRequestMiddleware, + fnPreRequestMiddleware1, + fnPreRequestMiddleware2, + ) ts := createPostServer(t) defer ts.Close() @@ -647,18 +658,22 @@ func TestClientPreRequestHook(t *testing.T) { assertEqual(t, `{ "id": "success", "message": "login successful" }`, resp.String()) } -func TestClientPreRequestHookError(t *testing.T) { +func TestClientPreRequestMiddlewareError(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl() - c.SetPreRequestHook(func(c *Client, r *http.Request) error { - return errors.New("error from PreRequestHook") - }) + fnPreRequestMiddleware1 := func(c *Client, r *Request) error { + return errors.New("error from PreRequestMiddleware") + } + c.SetRequestMiddlewares( + PrepareRequestMiddleware, + fnPreRequestMiddleware1, + ) resp, err := c.R().Get(ts.URL) assertNotNil(t, err) - assertEqual(t, "error from PreRequestHook", err.Error()) + assertEqual(t, "error from PreRequestMiddleware", err.Error()) assertNil(t, resp) } @@ -1080,7 +1095,7 @@ func TestClientOnResponseError(t *testing.T) { { name: "before_request_error", setup: func(client *Client) { - client.OnBeforeRequest(func(client *Client, request *Request) error { + client.AddRequestMiddleware(func(client *Client, request *Request) error { return fmt.Errorf("before request") }) }, @@ -1089,7 +1104,7 @@ func TestClientOnResponseError(t *testing.T) { { name: "before_request_error_retry", setup: func(client *Client) { - client.SetRetryCount(3).OnBeforeRequest(func(client *Client, request *Request) error { + client.SetRetryCount(3).AddRequestMiddleware(func(client *Client, request *Request) error { return fmt.Errorf("before request") }) }, @@ -1098,7 +1113,7 @@ func TestClientOnResponseError(t *testing.T) { { name: "after_response_error", setup: func(client *Client) { - client.OnAfterResponse(func(client *Client, response *Response) error { + client.AddResponseMiddleware(func(client *Client, response *Response) error { return fmt.Errorf("after response") }) }, @@ -1108,7 +1123,7 @@ func TestClientOnResponseError(t *testing.T) { { name: "after_response_error_retry", setup: func(client *Client) { - client.SetRetryCount(3).OnAfterResponse(func(client *Client, response *Response) error { + client.SetRetryCount(3).AddResponseMiddleware(func(client *Client, response *Response) error { return fmt.Errorf("after response") }) }, @@ -1118,7 +1133,7 @@ func TestClientOnResponseError(t *testing.T) { { name: "panic with error", setup: func(client *Client) { - client.OnBeforeRequest(func(client *Client, request *Request) error { + client.AddRequestMiddleware(func(client *Client, request *Request) error { panic(fmt.Errorf("before request")) }) }, @@ -1129,7 +1144,7 @@ func TestClientOnResponseError(t *testing.T) { { name: "panic with string", setup: func(client *Client) { - client.OnBeforeRequest(func(client *Client, request *Request) error { + client.AddRequestMiddleware(func(client *Client, request *Request) error { panic("before request") }) }, diff --git a/middleware.go b/middleware.go index a9c88c52..2bd43d54 100644 --- a/middleware.go +++ b/middleware.go @@ -23,6 +23,44 @@ const debugRequestLogKey = "__restyDebugRequestLog" // Request Middleware(s) //_______________________________________________________________________ +// PrepareRequestMiddleware method is used to prepare HTTP requests from +// user provides request values. Request preparation fails if any error occurs +func PrepareRequestMiddleware(c *Client, r *Request) error { + var err error + + if err = parseRequestURL(c, r); err != nil { + return err + } + + // no error returned + parseRequestHeader(c, r) + + if err = parseRequestBody(c, r); err != nil { + return err + } + + if err = createHTTPRequest(c, r); err != nil { + return err + } + + // last one doesn't need if condition + return addCredentials(c, r) +} + +// GenerateCurlRequestMiddleware method is used to perform CURL command +// generation during a request preparation +// +// See, [Client.SetGenerateCurlOnDebug], [Request.SetGenerateCurlOnDebug] +func GenerateCurlRequestMiddleware(c *Client, r *Request) (err error) { + if r.Debug && r.generateCurlOnDebug { + if r.resultCurlCmd == nil { + r.resultCurlCmd = new(string) + } + *r.resultCurlCmd = buildCurlCmd(r) + } + return nil +} + func parseRequestURL(c *Client, r *Request) error { if l := len(c.PathParams()) + len(c.RawPathParams()) + len(r.PathParams) + len(r.RawPathParams); l > 0 { params := make(map[string]string, l) @@ -299,74 +337,6 @@ func addCredentials(c *Client, r *Request) error { return nil } -func createCurlCmd(c *Client, r *Request) (err error) { - if r.Debug && r.generateCurlOnDebug { - if r.resultCurlCmd == nil { - r.resultCurlCmd = new(string) - } - *r.resultCurlCmd = buildCurlCmd(r) - } - return nil -} - -//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Response Middleware(s) -//_______________________________________________________________________ - -func parseResponseBody(c *Client, res *Response) (err error) { - if res.Err != nil || - res.Request.DoNotParseResponse || - res.Request.isSaveResponse { - return // move on - } - - if res.StatusCode() == http.StatusNoContent { - res.Request.Error = nil - return - } - - rct := firstNonEmpty( - res.Request.ForceResponseContentType, - res.Header().Get(hdrContentTypeKey), - res.Request.ExpectResponseContentType, - ) - decKey := inferContentTypeMapKey(rct) - decFunc, found := c.inferContentTypeDecoder(rct, decKey) - if !found { - // the Content-Type decoder is not found; just read all the body bytes - err = res.readAll() - return - } - - // HTTP status code > 199 and < 300, considered as Result - if res.IsSuccess() && res.Request.Result != nil { - res.Request.Error = nil - defer closeq(res.Body) - err = decFunc(res.Body, res.Request.Result) - res.IsRead = true - return - } - - // HTTP status code > 399, considered as Error - if res.IsError() { - // global error type registered at client-instance - if res.Request.Error == nil { - res.Request.Error = c.newErrorInterface() - } - - if res.Request.Error != nil { - defer closeq(res.Body) - err = decFunc(res.Body, res.Request.Error) - res.IsRead = true - return - } - } - - // read all bytes when auto-unmarshal didn't take place - err = res.readAll() - return -} - func handleMultipart(c *Client, r *Request) error { for k, v := range c.FormData() { if _, ok := r.FormData[k]; ok { @@ -453,7 +423,7 @@ func createMultipart(w *multipart.Writer, r *Request) error { if _, err = partWriter.Write(p[:size]); err != nil { return err } - _, err = io.Copy(partWriter, mf.Reader) + _, err = ioCopy(partWriter, mf.Reader) if err != nil { return err } @@ -529,7 +499,70 @@ func handleRequestBody(c *Client, r *Request) error { return nil } -func saveResponseIntoFile(c *Client, res *Response) error { +//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// Response Middleware(s) +//_______________________________________________________________________ + +// AutoParseResponseMiddleware method is used to parse the response body automatically +// based on registered HTTP response `Content-Type` decoder, see [Client.AddContentTypeDecoder]; +// if [Request.SetResult], [Request.SetError], or [Client.SetError] is used +func AutoParseResponseMiddleware(c *Client, res *Response) (err error) { + if res.Err != nil || + res.Request.DoNotParseResponse || + res.Request.isSaveResponse { + return // move on + } + + if res.StatusCode() == http.StatusNoContent { + res.Request.Error = nil + return + } + + rct := firstNonEmpty( + res.Request.ForceResponseContentType, + res.Header().Get(hdrContentTypeKey), + res.Request.ExpectResponseContentType, + ) + decKey := inferContentTypeMapKey(rct) + decFunc, found := c.inferContentTypeDecoder(rct, decKey) + if !found { + // the Content-Type decoder is not found; just read all the body bytes + err = res.readAll() + return + } + + // HTTP status code > 199 and < 300, considered as Result + if res.IsSuccess() && res.Request.Result != nil { + res.Request.Error = nil + defer closeq(res.Body) + err = decFunc(res.Body, res.Request.Result) + res.IsRead = true + return + } + + // HTTP status code > 399, considered as Error + if res.IsError() { + // global error type registered at client-instance + if res.Request.Error == nil { + res.Request.Error = c.newErrorInterface() + } + + if res.Request.Error != nil { + defer closeq(res.Body) + err = decFunc(res.Body, res.Request.Error) + res.IsRead = true + return + } + } + + // read all bytes when auto-unmarshal didn't take place + err = res.readAll() + return +} + +// SaveToFileResponseMiddleware method used to write HTTP response body into +// given file details via [Request.SetOutputFile] +func SaveToFileResponseMiddleware(c *Client, res *Response) error { if res.Err != nil || !res.Request.isSaveResponse { return nil } @@ -549,12 +582,13 @@ func saveResponseIntoFile(c *Client, res *Response) error { if err != nil { return err } - defer closeq(outFile) + defer func() { + closeq(outFile) + closeq(res.Body) + }() // io.Copy reads maximum 32kb size, it is perfect for large file download too - defer closeq(res.Body) - - written, err := io.Copy(outFile, res.Body) + written, err := ioCopy(outFile, res.Body) if err != nil { return err } diff --git a/middleware_test.go b/middleware_test.go index 866dd9b4..edebb114 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -1064,7 +1064,7 @@ func Benchmark_parseRequestBody_MultiPart(b *testing.B) { } } -func TestSaveResponseToFile(t *testing.T) { +func TestMiddlewareSaveToFileErrorCases(t *testing.T) { c := dcnl() tempDir := t.TempDir() @@ -1084,16 +1084,35 @@ func TestSaveResponseToFile(t *testing.T) { // dir create error req1 := c.R() req1.SetOutputFile(filepath.Join(tempDir, "new-res-dir", "sample.txt")) - err1 := saveResponseIntoFile(c, &Response{Request: req1}) + err1 := SaveToFileResponseMiddleware(c, &Response{Request: req1}) assertEqual(t, errDirMsg, err1.Error()) // file create error req2 := c.R() req2.SetOutputFile(filepath.Join(tempDir, "sample.txt")) - err2 := saveResponseIntoFile(c, &Response{Request: req2}) + err2 := SaveToFileResponseMiddleware(c, &Response{Request: req2}) assertEqual(t, errFileMsg, err2.Error()) } +func TestMiddlewareSaveToFileCopyError(t *testing.T) { + c := dcnl() + tempDir := t.TempDir() + + errCopyMsg := "test copy error" + ioCopy = func(dst io.Writer, src io.Reader) (written int64, err error) { + return 0, errors.New(errCopyMsg) + } + t.Cleanup(func() { + ioCopy = io.Copy + }) + + // copy error + req1 := c.R() + req1.SetOutputFile(filepath.Join(tempDir, "new-res-dir", "sample.txt")) + err1 := SaveToFileResponseMiddleware(c, &Response{Request: req1, Body: io.NopCloser(bytes.NewBufferString("Test context"))}) + assertEqual(t, errCopyMsg, err1.Error()) +} + func TestRequestURL_GH797(t *testing.T) { ts := createGetServer(t) defer ts.Close() diff --git a/multipart_test.go b/multipart_test.go index aa612855..b008652b 100644 --- a/multipart_test.go +++ b/multipart_test.go @@ -9,6 +9,7 @@ import ( "bytes" "context" "errors" + "io" "io/fs" "mime/multipart" "net/http" @@ -483,7 +484,7 @@ func (mwe *mpWriterError) Write(p []byte) (int, error) { return 0, errors.New("multipart write error") } -func TestRequest_writeFormData(t *testing.T) { +func TestMulipartRequest_createMultipart(t *testing.T) { mw := multipart.NewWriter(&mpWriterError{}) c := dcnl() @@ -492,13 +493,37 @@ func TestRequest_writeFormData(t *testing.T) { "name2": "value2", }) - err1 := req1.writeFormData(mw) - assertNotNil(t, err1) - assertEqual(t, "multipart write error", err1.Error()) + t.Run("writeFormData", func(t *testing.T) { + err1 := req1.writeFormData(mw) + assertNotNil(t, err1) + assertEqual(t, "multipart write error", err1.Error()) + }) + + t.Run("createMultipart", func(t *testing.T) { + err2 := createMultipart(mw, req1) + assertNotNil(t, err2) + assertEqual(t, "multipart write error", err2.Error()) + }) + + t.Run("io copy error", func(t *testing.T) { + errCopyMsg := "test copy error" + ioCopy = func(dst io.Writer, src io.Reader) (written int64, err error) { + return 0, errors.New(errCopyMsg) + } + t.Cleanup(func() { + ioCopy = io.Copy + }) - err2 := createMultipart(mw, req1) - assertNotNil(t, err2) - assertEqual(t, "multipart write error", err2.Error()) + req1 := c.R(). + SetFile("file", filepath.Join(getTestDataPath(), "test-img.png")). + SetMultipartBoundary("custom-boundary-"+strconv.FormatInt(time.Now().Unix(), 10)). + SetHeader("Content-Type", "image/png") + + mw := multipart.NewWriter(new(bytes.Buffer)) + err := createMultipart(mw, req1) + assertNotNil(t, err) + assertEqual(t, "test copy error", err.Error()) + }) } type returnValueTestWriter struct { diff --git a/request.go b/request.go index 94c818b3..d3d3fc55 100644 --- a/request.go +++ b/request.go @@ -104,7 +104,10 @@ func (r *Request) GenerateCurlCommand() string { return *r.resultCurlCmd } if r.RawRequest == nil { - r.client.executeBefore(r) // mock with r.Get("/") + // mock with r.Get("/") + if err := r.client.executeRequestMiddlewares(r); err != nil { + r.log.Errorf("%v", err) + } } *r.resultCurlCmd = buildCurlCmd(r) return *r.resultCurlCmd @@ -656,14 +659,14 @@ func (r *Request) SetAuthScheme(scheme string) *Request { // [RFC 7616]: https://datatracker.ietf.org/doc/html/rfc7616 func (r *Request) SetDigestAuth(username, password string) *Request { oldTransport := r.client.httpClient.Transport - r.client.OnBeforeRequest(func(c *Client, _ *Request) error { + r.client.AddRequestMiddleware(func(c *Client, _ *Request) error { c.httpClient.Transport = &digestTransport{ digestCredentials: digestCredentials{username, password}, transport: oldTransport, } return nil }) - r.client.OnAfterResponse(func(c *Client, _ *Response) error { + r.client.AddResponseMiddleware(func(c *Client, _ *Response) error { c.httpClient.Transport = oldTransport return nil }) @@ -1066,6 +1069,11 @@ func (r *Request) DisableGenerateCurlOnDebug() *Request { } // SetGenerateCurlOnDebug method is used to turn on/off the generate CURL command in debug mode. +// It works in conjunction with debug mode. +// +// NOTE: Use with care. +// - Potential to leak sensitive data from [Request] and [Response] in the debug log. +// - Beware of memory usage since the request body is reread. // // It overrides the options set by the [Client.SetGenerateCurlOnDebug] func (r *Request) SetGenerateCurlOnDebug(b bool) *Request { diff --git a/request_test.go b/request_test.go index abd3667f..cd958d2b 100644 --- a/request_test.go +++ b/request_test.go @@ -895,17 +895,16 @@ func TestPutJSONString(t *testing.T) { client := dcnl() - client.OnBeforeRequest(func(c *Client, r *Request) error { - r.SetHeader("X-Custom-Request-Middleware", "OnBeforeRequest middleware") + client.AddRequestMiddleware(func(c *Client, r *Request) error { + r.SetHeader("X-Custom-Request-Middleware", "Request middleware") return nil }) - client.OnBeforeRequest(func(c *Client, r *Request) error { - c.SetContentLength(true) - r.SetHeader("X-ContentLength", "OnBeforeRequest ContentLength set") + client.AddRequestMiddleware(func(c *Client, r *Request) error { + r.SetHeader("X-ContentLength", "Request middleware ContentLength set") return nil }) - client.SetDebug(true) + client.SetDebug(true).SetContentLength(true) client.outputLogTo(io.Discard) resp, err := client.R(). @@ -932,23 +931,24 @@ func TestPutXMLString(t *testing.T) { assertEqual(t, `XML response`, resp.String()) } -func TestOnBeforeMiddleware(t *testing.T) { +func TestRequestMiddleware(t *testing.T) { ts := createGenericServer(t) defer ts.Close() c := dcnl() - c.OnBeforeRequest(func(c *Client, r *Request) error { - r.SetHeader("X-Custom-Request-Middleware", "OnBeforeRequest middleware") + c.SetContentLength(true) + + c.AddRequestMiddleware(func(c *Client, r *Request) error { + r.SetHeader("X-Custom-Request-Middleware", "Request middleware") return nil }) - c.OnBeforeRequest(func(c *Client, r *Request) error { - c.SetContentLength(true) - r.SetHeader("X-ContentLength", "OnBeforeRequest ContentLength set") + c.AddRequestMiddleware(func(c *Client, r *Request) error { + r.SetHeader("X-ContentLength", "Request middleware ContentLength set") return nil }) resp, err := c.R(). - SetBody("OnBeforeRequest: This is plain text body to server"). + SetBody("RequestMiddleware: This is plain text body to server"). Put(ts.URL + "/plaintext") assertError(t, err) diff --git a/resty.go b/resty.go index e0d6b588..b1ce59cf 100644 --- a/resty.go +++ b/resty.go @@ -193,24 +193,17 @@ func createClient(hc *http.Client) *Client { c.AddContentDecompressor("deflate", decompressDeflate) c.AddContentDecompressor("gzip", decompressGzip) - // default before request middlewares - c.beforeRequest = []RequestMiddleware{ - parseRequestURL, - parseRequestHeader, - parseRequestBody, - createHTTPRequest, - addCredentials, - createCurlCmd, - } - - // user defined request middlewares - c.udBeforeRequest = []RequestMiddleware{} + // request middlewares + c.SetRequestMiddlewares( + PrepareRequestMiddleware, + GenerateCurlRequestMiddleware, + ) - // default after response middlewares - c.afterResponse = []ResponseMiddleware{ - parseResponseBody, - saveResponseIntoFile, - } + // response middlewares + c.SetResponseMiddlewares( + AutoParseResponseMiddleware, + SaveToFileResponseMiddleware, + ) return c } diff --git a/util.go b/util.go index edf087ef..1beeda1c 100644 --- a/util.go +++ b/util.go @@ -128,6 +128,7 @@ func firstNonEmpty(v ...string) string { var ( mkdirAll = os.MkdirAll createFile = os.Create + ioCopy = io.Copy ) func createDirectory(dir string) (err error) {