diff --git a/error.go b/error.go index e02cab1..99a4360 100644 --- a/error.go +++ b/error.go @@ -34,3 +34,8 @@ func (err httpError) WriteBody(writer io.Writer) error { _, writeError := io.WriteString(writer, err.body) return writeError } + +// NewNoopError returns an HTTPError that will completely bypass the +// deserialization logic. This can be used when the endpoint or middleware +// operates directly on the native http.ResponseWriter. +func NewNoopError() HTTPError { return NewHTTPError(0, "") } diff --git a/standard_test.go b/standard_test.go index bacabb0..213770d 100644 --- a/standard_test.go +++ b/standard_test.go @@ -137,6 +137,29 @@ func TestStandardWrapper(t *testing.T) { require.Equal(t, http.StatusUnauthorized, statusCode) require.Equal(t, "Unauthorized.", body) }) + + t.Run("middleware shortcircuit with nooperror", func(t *testing.T) { + wrapper := NewStandardWrapper(). + Before(func(rw http.ResponseWriter) error { + rw.WriteHeader(http.StatusCreated) + rw.Write([]byte("HELLO WORLD")) + return NewNoopError() + }) + + handler := wrapper.Wrap(func(p1 header, p2 query, context typedContext) error { + require.Fail(t, "Handler should never be called.") + return nil + }) + + rw := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/endpoint?string=abc", nil) + req.Header.Set("integer", "12") + + handler.ServeHTTP(rw, req) + statusCode, body := readResponseRecorder(t, rw) + require.Equal(t, http.StatusCreated, statusCode) + require.Equal(t, "HELLO WORLD", body) + }) } func readResponseRecorder(t *testing.T, rw *httptest.ResponseRecorder) (int, string) {