From 50b80144c6feb709b8dcde6a75ded5dc712dde4d Mon Sep 17 00:00:00 2001 From: Antoine Pourchet Date: Thu, 11 Jan 2024 09:47:39 -0700 Subject: [PATCH] added more tests --- standard.go | 1 - standard_test.go | 147 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 standard_test.go diff --git a/standard.go b/standard.go index 954df1d..3438e19 100644 --- a/standard.go +++ b/standard.go @@ -62,7 +62,6 @@ func StandardResponseWriter() func(w http.ResponseWriter, res any, err error) { w.WriteHeader(http.StatusOK) encoder := json.NewEncoder(w) - encoder.SetIndent("", " ") if sendError := encoder.Encode(res); sendError != nil { log.Println("Error writing response:", sendError) } diff --git a/standard_test.go b/standard_test.go new file mode 100644 index 0000000..bacabb0 --- /dev/null +++ b/standard_test.go @@ -0,0 +1,147 @@ +package httpwrap + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type header struct { + Integer int `http:"header=integer"` +} + +type query struct { + String string `http:"query=string"` +} + +type typedContext struct { + Integer int + String string + Bool bool +} + +type typedResponse struct { + Value int `json:"value"` +} + +func TestStandardWrapper(t *testing.T) { + var rw *httptest.ResponseRecorder + var req *http.Request + + t.Run("simple error", func(t *testing.T) { + wrapper := NewStandardWrapper() + noErrorHandler := wrapper.Wrap(func() error { + return nil + }) + errorHandler := wrapper.Wrap(func() error { + return NewHTTPError(http.StatusForbidden, "Forbidden.") + }) + + rw = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/endpoint", nil) + noErrorHandler.ServeHTTP(rw, req) + require.Equal(t, http.StatusOK, rw.Result().StatusCode) + + rw = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/endpoint", nil) + errorHandler.ServeHTTP(rw, req) + require.Equal(t, http.StatusForbidden, rw.Result().StatusCode) + }) + + t.Run("no middleware", func(t *testing.T) { + wrapper := NewStandardWrapper() + handler := wrapper.Wrap(func(p1 header, p2 query) error { + require.Equal(t, 12, p1.Integer) + require.Equal(t, "abc", p2.String) + return nil + }) + + rw := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/endpoint?string=abc", nil) + req.Header.Set("integer", "12") + + handler.ServeHTTP(rw, req) + require.Equal(t, http.StatusOK, rw.Result().StatusCode) + }) + + t.Run("json response", func(t *testing.T) { + wrapper := NewStandardWrapper() + handler := wrapper.Wrap(func(p1 header, p2 query) (typedResponse, error) { + require.Equal(t, 12, p1.Integer) + require.Equal(t, "abc", p2.String) + return typedResponse{Value: 42}, nil + }) + + rw := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/endpoint?string=abc", nil) + req.Header.Set("integer", "12") + + handler.ServeHTTP(rw, req) + statusCode, body := readResponseRecorder(t, rw) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, `{"value":42}`, body) + }) + + t.Run("with middleware", func(t *testing.T) { + wrapper := NewStandardWrapper(). + Before(func() typedContext { + return typedContext{ + Integer: 13, + String: "abc", + Bool: true, + } + }) + + handler := wrapper.Wrap(func(p1 header, p2 query, context typedContext) (typedResponse, error) { + require.Equal(t, 12, p1.Integer) + require.Equal(t, "abc", p2.String) + + require.Equal(t, 13, context.Integer) + require.Equal(t, "abc", context.String) + require.Equal(t, true, context.Bool) + + return typedResponse{Value: 42}, nil + }) + + rw := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/endpoint?string=abc", nil) + req.Header.Set("integer", "12") + + handler.ServeHTTP(rw, req) + statusCode, body := readResponseRecorder(t, rw) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, `{"value":42}`, body) + }) + + t.Run("middleware shortcircuit", func(t *testing.T) { + wrapper := NewStandardWrapper(). + Before(func() error { + return NewHTTPError(http.StatusUnauthorized, "Unauthorized.") + }) + + handler := wrapper.Wrap(func(p1 header, p2 query, context typedContext) error { + require.Fail(t, "Handler should never be called.") + return nil + }) + + rw := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/endpoint?string=abc", nil) + req.Header.Set("integer", "12") + + handler.ServeHTTP(rw, req) + statusCode, body := readResponseRecorder(t, rw) + require.Equal(t, http.StatusUnauthorized, statusCode) + require.Equal(t, "Unauthorized.", body) + }) +} + +func readResponseRecorder(t *testing.T, rw *httptest.ResponseRecorder) (int, string) { + result := rw.Result() + body, err := io.ReadAll(result.Body) + require.NoError(t, err) + return result.StatusCode, strings.TrimSpace(string(body)) +}