Skip to content

Commit

Permalink
added more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
apourchet committed Jan 11, 2024
1 parent ef563dc commit 50b8014
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 1 deletion.
1 change: 0 additions & 1 deletion standard.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ func StandardResponseWriter() func(w http.ResponseWriter, res any, err error) {

w.WriteHeader(http.StatusOK)
encoder := json.NewEncoder(w)
encoder.SetIndent("", " ")
if sendError := encoder.Encode(res); sendError != nil {
log.Println("Error writing response:", sendError)
}
Expand Down
147 changes: 147 additions & 0 deletions standard_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package httpwrap

import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

type header struct {
Integer int `http:"header=integer"`
}

type query struct {
String string `http:"query=string"`
}

type typedContext struct {
Integer int
String string
Bool bool
}

type typedResponse struct {
Value int `json:"value"`
}

func TestStandardWrapper(t *testing.T) {
var rw *httptest.ResponseRecorder
var req *http.Request

t.Run("simple error", func(t *testing.T) {
wrapper := NewStandardWrapper()
noErrorHandler := wrapper.Wrap(func() error {
return nil
})
errorHandler := wrapper.Wrap(func() error {
return NewHTTPError(http.StatusForbidden, "Forbidden.")
})

rw = httptest.NewRecorder()
req = httptest.NewRequest("GET", "/endpoint", nil)
noErrorHandler.ServeHTTP(rw, req)
require.Equal(t, http.StatusOK, rw.Result().StatusCode)

rw = httptest.NewRecorder()
req = httptest.NewRequest("GET", "/endpoint", nil)
errorHandler.ServeHTTP(rw, req)
require.Equal(t, http.StatusForbidden, rw.Result().StatusCode)
})

t.Run("no middleware", func(t *testing.T) {
wrapper := NewStandardWrapper()
handler := wrapper.Wrap(func(p1 header, p2 query) error {
require.Equal(t, 12, p1.Integer)
require.Equal(t, "abc", p2.String)
return nil
})

rw := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/endpoint?string=abc", nil)
req.Header.Set("integer", "12")

handler.ServeHTTP(rw, req)
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
})

t.Run("json response", func(t *testing.T) {
wrapper := NewStandardWrapper()
handler := wrapper.Wrap(func(p1 header, p2 query) (typedResponse, error) {
require.Equal(t, 12, p1.Integer)
require.Equal(t, "abc", p2.String)
return typedResponse{Value: 42}, nil
})

rw := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/endpoint?string=abc", nil)
req.Header.Set("integer", "12")

handler.ServeHTTP(rw, req)
statusCode, body := readResponseRecorder(t, rw)
require.Equal(t, http.StatusOK, statusCode)
require.Equal(t, `{"value":42}`, body)
})

t.Run("with middleware", func(t *testing.T) {
wrapper := NewStandardWrapper().
Before(func() typedContext {
return typedContext{
Integer: 13,
String: "abc",
Bool: true,
}
})

handler := wrapper.Wrap(func(p1 header, p2 query, context typedContext) (typedResponse, error) {
require.Equal(t, 12, p1.Integer)
require.Equal(t, "abc", p2.String)

require.Equal(t, 13, context.Integer)
require.Equal(t, "abc", context.String)
require.Equal(t, true, context.Bool)

return typedResponse{Value: 42}, nil
})

rw := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/endpoint?string=abc", nil)
req.Header.Set("integer", "12")

handler.ServeHTTP(rw, req)
statusCode, body := readResponseRecorder(t, rw)
require.Equal(t, http.StatusOK, statusCode)
require.Equal(t, `{"value":42}`, body)
})

t.Run("middleware shortcircuit", func(t *testing.T) {
wrapper := NewStandardWrapper().
Before(func() error {
return NewHTTPError(http.StatusUnauthorized, "Unauthorized.")
})

handler := wrapper.Wrap(func(p1 header, p2 query, context typedContext) error {
require.Fail(t, "Handler should never be called.")
return nil
})

rw := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/endpoint?string=abc", nil)
req.Header.Set("integer", "12")

handler.ServeHTTP(rw, req)
statusCode, body := readResponseRecorder(t, rw)
require.Equal(t, http.StatusUnauthorized, statusCode)
require.Equal(t, "Unauthorized.", body)
})
}

func readResponseRecorder(t *testing.T, rw *httptest.ResponseRecorder) (int, string) {
result := rw.Result()
body, err := io.ReadAll(result.Body)
require.NoError(t, err)
return result.StatusCode, strings.TrimSpace(string(body))
}

0 comments on commit 50b8014

Please sign in to comment.