diff --git a/insrequester/README.md b/insrequester/README.md index cb5617f..93fa352 100644 --- a/insrequester/README.md +++ b/insrequester/README.md @@ -65,6 +65,16 @@ timeoutSeconds := 30 requester.WithTimeout(timeoutSeconds) // this timeout overrides the default timeout ``` +#### Default Headers +For applying default headers to all requests, you can use the WithDefaultHeaders method: + +```go +headers := insrequester.Headers{{"Authorization": "Bearer token"}} +requester.WithHeaders(headers) +``` +It should be noted that you can still override these default headers by providing the same header in the request entity. + + ### Loading Middlewares After configuring the desired resilience features, load the configured middlewares using the Load method: diff --git a/insrequester/requester.go b/insrequester/requester.go index 846e1cc..7a9a8db 100644 --- a/insrequester/requester.go +++ b/insrequester/requester.go @@ -39,12 +39,15 @@ type Requester interface { WithRetry(config RetryConfig) *Request WithCircuitbreaker(config CircuitBreakerConfig) *Request WithTimeout(timeoutSeconds int) *Request + WithHeaders(headers Headers) *Request Load() *Request } +type Headers []map[string]interface{} + // RequestEntity contains required information for sending http request. type RequestEntity struct { - Headers []map[string]interface{} + Headers Headers Endpoint string Body []byte } @@ -53,6 +56,7 @@ type Request struct { timeout int runner goresilience.Runner middlewares []goresilience.Middleware + headers Headers } // Get sends HTTP get request to the given endpoint and returns *http.Response and an error. @@ -95,6 +99,7 @@ func (r *Request) sendRequest(httpMethod string, re RequestEntity) (*http.Respon } req.Close = true + re.Headers = append(r.headers, re.Headers...) // RequestEntity headers will override Requester level headers. re.applyHeadersToRequest(req) res, outerErr = (&http.Client{Timeout: time.Duration(r.timeout) * time.Second}).Do(req) @@ -197,6 +202,11 @@ func (r *Request) WithTimeout(timeoutSeconds int) *Request { return r } +func (r *Request) WithHeaders(headers Headers) *Request { + r.headers = headers + return r +} + func (r *Request) Load() *Request { r.runner = goresilience.RunnerChain(r.middlewares...) return r diff --git a/insrequester/requester_mock.go b/insrequester/requester_mock.go index a8f9c44..f2a8680 100644 --- a/insrequester/requester_mock.go +++ b/insrequester/requester_mock.go @@ -1,7 +1,7 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: requester.go +// Source: ./insrequester/requester.go -// Package mock_requester is a generated GoMock package. +// Package insrequester is a generated GoMock package. package insrequester import ( @@ -37,7 +37,7 @@ func (m *MockRequester) EXPECT() *MockRequesterMockRecorder { // Delete mocks base method. func (m *MockRequester) Delete(re RequestEntity) (*http.Response, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, http.MethodDelete, re) + ret := m.ctrl.Call(m, "Delete", re) ret0, _ := ret[0].(*http.Response) ret1, _ := ret[1].(error) return ret0, ret1 @@ -46,13 +46,13 @@ func (m *MockRequester) Delete(re RequestEntity) (*http.Response, error) { // Delete indicates an expected call of Delete. func (mr *MockRequesterMockRecorder) Delete(re interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, http.MethodDelete, reflect.TypeOf((*MockRequester)(nil).Delete), re) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockRequester)(nil).Delete), re) } // Get mocks base method. func (m *MockRequester) Get(re RequestEntity) (*http.Response, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, http.MethodGet, re) + ret := m.ctrl.Call(m, "Get", re) ret0, _ := ret[0].(*http.Response) ret1, _ := ret[1].(error) return ret0, ret1 @@ -61,7 +61,7 @@ func (m *MockRequester) Get(re RequestEntity) (*http.Response, error) { // Get indicates an expected call of Get. func (mr *MockRequesterMockRecorder) Get(re interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, http.MethodGet, reflect.TypeOf((*MockRequester)(nil).Get), re) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockRequester)(nil).Get), re) } // Load mocks base method. @@ -81,7 +81,7 @@ func (mr *MockRequesterMockRecorder) Load() *gomock.Call { // Post mocks base method. func (m *MockRequester) Post(re RequestEntity) (*http.Response, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, http.MethodPost, re) + ret := m.ctrl.Call(m, "Post", re) ret0, _ := ret[0].(*http.Response) ret1, _ := ret[1].(error) return ret0, ret1 @@ -90,13 +90,13 @@ func (m *MockRequester) Post(re RequestEntity) (*http.Response, error) { // Post indicates an expected call of Post. func (mr *MockRequesterMockRecorder) Post(re interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, http.MethodPost, reflect.TypeOf((*MockRequester)(nil).Post), re) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Post", reflect.TypeOf((*MockRequester)(nil).Post), re) } // Put mocks base method. func (m *MockRequester) Put(re RequestEntity) (*http.Response, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, http.MethodPut, re) + ret := m.ctrl.Call(m, "Put", re) ret0, _ := ret[0].(*http.Response) ret1, _ := ret[1].(error) return ret0, ret1 @@ -105,7 +105,7 @@ func (m *MockRequester) Put(re RequestEntity) (*http.Response, error) { // Put indicates an expected call of Put. func (mr *MockRequesterMockRecorder) Put(re interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, http.MethodPut, reflect.TypeOf((*MockRequester)(nil).Put), re) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockRequester)(nil).Put), re) } // WithCircuitbreaker mocks base method. @@ -122,6 +122,20 @@ func (mr *MockRequesterMockRecorder) WithCircuitbreaker(config interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithCircuitbreaker", reflect.TypeOf((*MockRequester)(nil).WithCircuitbreaker), config) } +// WithHeaders mocks base method. +func (m *MockRequester) WithHeaders(headers Headers) *Request { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithHeaders", headers) + ret0, _ := ret[0].(*Request) + return ret0 +} + +// WithHeaders indicates an expected call of WithHeaders. +func (mr *MockRequesterMockRecorder) WithHeaders(headers interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithHeaders", reflect.TypeOf((*MockRequester)(nil).WithHeaders), headers) +} + // WithRetry mocks base method. func (m *MockRequester) WithRetry(config RetryConfig) *Request { m.ctrl.T.Helper() diff --git a/insrequester/requester_test.go b/insrequester/requester_test.go index e5c4fff..f41cb0d 100644 --- a/insrequester/requester_test.go +++ b/insrequester/requester_test.go @@ -68,4 +68,48 @@ func TestRequest_Get(t *testing.T) { _, err = r.Get(req) assert.ErrorIs(t, err, errors.ErrCircuitOpen) }) + + t.Run("it_should_apply_headers_properly", func(t *testing.T) { + var receivedUserAgent string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedUserAgent = r.Header.Get("User-Agent") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status": "OK"}`)) + })) + + defer ts.Close() + + userAgent := "test-user-agent" + r := NewRequester().WithHeaders(Headers{{"User-Agent": userAgent}}) + res, err := r.Get(RequestEntity{Endpoint: ts.URL}) + + assert.NoError(t, err) + assert.Equal(t, receivedUserAgent, userAgent) + assert.Equal(t, http.StatusOK, res.StatusCode) + }) + + t.Run("it_should_override_Requester_level_header_if_RequestEntity_headers_set", func(t *testing.T) { + var receivedUserAgent string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedUserAgent = r.Header.Get("User-Agent") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status": "OK"}`)) + })) + + defer ts.Close() + + oldUserAgent := "old-user-agent" + r := NewRequester().WithHeaders(Headers{{"User-Agent": oldUserAgent}}) + + newUserAgent := "new-user-agent" + req := RequestEntity{ + Endpoint: ts.URL, + Headers: Headers{{"User-Agent": newUserAgent}}, + } + res, err := r.Get(req) + + assert.NoError(t, err) + assert.Equal(t, receivedUserAgent, newUserAgent) + assert.Equal(t, http.StatusOK, res.StatusCode) + }) }