From 5a961eed16b96e5c0f716ed6588be128301cd938 Mon Sep 17 00:00:00 2001 From: Edouard Lacourt <34008166+rizerkrof@users.noreply.github.com> Date: Mon, 18 Dec 2023 15:36:30 +0100 Subject: [PATCH] Validation should return 400 on fail (#30) New BadRequestError type : returns 400 status code on Validation failure Instead of 500 previously --------- Co-authored-by: EwenQuim --- ctx.go | 26 ++++++++------ deserialization.go | 10 +++--- deserialization_test.go | 6 ++-- errors.go | 44 +++++++++++++++++++----- errors_test.go | 20 +++++------ examples/simple-crud/errors_custom.go | 11 ++++++ examples/simple-crud/store/ingredient.go | 1 - examples/simple-crud/views/admin.go | 1 - middleware/basicauth/basicauth.go | 2 +- options_test.go | 2 +- serialization.go | 2 +- serialization_test.go | 4 +-- 12 files changed, 84 insertions(+), 45 deletions(-) create mode 100644 examples/simple-crud/errors_custom.go diff --git a/ctx.go b/ctx.go index a987db20..c443ca0c 100644 --- a/ctx.go +++ b/ctx.go @@ -17,6 +17,7 @@ const ( // Ctx is the context of the request. // It contains the request body, the path parameters, the query parameters, and the http request. +// Please do not use a pointer type as parameter. type Ctx[B any] interface { // Body returns the body of the request. // If (*B) implements [InTransformer], it will be transformed after deserialization. @@ -79,6 +80,7 @@ type Ctx[B any] interface { Pass() ClassicContext } +// NewContext returns a new context. It is used internally by Fuego. You probably want to use Ctx[B] instead. func NewContext[B any](w http.ResponseWriter, r *http.Request, options readOptions) *Context[B] { c := &Context[B]{ ClassicContext: ClassicContext{ @@ -94,6 +96,7 @@ func NewContext[B any](w http.ResponseWriter, r *http.Request, options readOptio return c } +// Context is used internally by Fuego. You probably want to use Ctx[B] instead. Please do not use a pointer type as parameter. type Context[BodyType any] struct { body *BodyType ClassicContext @@ -103,15 +106,14 @@ func (c *Context[B]) Pass() ClassicContext { return c.ClassicContext } -// ClassicContext for the request. BodyType is the type of the request body. Please do not use a pointer type as parameter. +// ClassicContext is used internally by Fuego. Please do not use a pointer type as parameter. type ClassicContext struct { request *http.Request response http.ResponseWriter pathParams map[string]string - fs fs.FS - templates *template.Template - templatesParsed bool + fs fs.FS + templates *template.Template readOptions readOptions } @@ -119,6 +121,7 @@ type ClassicContext struct { func (c ClassicContext) Body() (any, error) { panic("this method should not be called. It probably happened because you passed the context to another controller with the Pass method.") } + func (c ClassicContext) MustBody() any { b, err := c.Body() if err != nil { @@ -146,8 +149,11 @@ type readOptions struct { LogBody bool } -var _ Ctx[any] = &Context[any]{} // Check that Context implements Ctx. -var _ Ctx[any] = &ClassicContext{} // Check that Context implements Ctx. +var ( + _ Ctx[any] = &Context[any]{} // Check that Context implements Ctx. + _ Ctx[string] = &Context[string]{} // Check that Context implements Ctx. + _ Ctx[any] = &ClassicContext{} // Check that Context implements Ctx. +) // Context returns the context of the request. // Same as c.Request().Context(). @@ -174,14 +180,13 @@ func (c ClassicContext) Pass() ClassicContext { // the need to parse the templates on each request but also preventing // to dynamically use new templates. func (c ClassicContext) Render(templateToExecute string, data any, layoutsGlobs ...string) (HTML, error) { - if !c.templatesParsed && - (strings.Contains(templateToExecute, "/") || strings.Contains(templateToExecute, "*")) { + if strings.Contains(templateToExecute, "/") || strings.Contains(templateToExecute, "*") { layoutsGlobs = append(layoutsGlobs, templateToExecute) // To override all blocks defined in the main template cloned := template.Must(c.templates.Clone()) tmpl, err := cloned.ParseFS(c.fs, layoutsGlobs...) if err != nil { - return "", ErrorResponse{ + return "", HTTPError{ StatusCode: http.StatusInternalServerError, Message: fmt.Errorf("error parsing template '%s': %w", layoutsGlobs, err).Error(), MoreInfo: map[string]any{ @@ -191,7 +196,6 @@ func (c ClassicContext) Render(templateToExecute string, data any, layoutsGlobs } } c.templates = template.Must(tmpl.Clone()) - c.templatesParsed = true } // Get only last template name (for example, with partials/nav/main/nav.partial.html, get nav.partial.html) @@ -201,7 +205,7 @@ func (c ClassicContext) Render(templateToExecute string, data any, layoutsGlobs c.response.Header().Set("Content-Type", "text/html; charset=utf-8") err := c.templates.ExecuteTemplate(c.response, templateToExecute, data) if err != nil { - return "", ErrorResponse{ + return "", HTTPError{ StatusCode: http.StatusInternalServerError, Message: fmt.Errorf("error executing template '%s': %w", templateToExecute, err).Error(), MoreInfo: map[string]any{ diff --git a/deserialization.go b/deserialization.go index e2c9d9e2..6a594527 100644 --- a/deserialization.go +++ b/deserialization.go @@ -45,18 +45,18 @@ func readJSON[B any](input io.Reader, options readOptions) (B, error) { } err := dec.Decode(&body) if err != nil { - return body, fmt.Errorf("cannot decode request body: %w", err) + return body, BadRequestError{Message: "cannot decode request body: " + err.Error()} } slog.Debug("Decoded body", "body", body) body, err = transform(body) if err != nil { - return body, fmt.Errorf("cannot transform request body: %w", err) + return body, BadRequestError{Message: "cannot transform request body: " + err.Error()} } err = validate(body) if err != nil { - return body, fmt.Errorf("cannot validate request body: %w", err) + return body, BadRequestError{Message: "cannot validate request body: " + err.Error()} } return body, nil @@ -73,7 +73,7 @@ func readString[B ~string](input io.Reader, options readOptions) (B, error) { // Read the request body. readBody, err := io.ReadAll(input) if err != nil { - return "", fmt.Errorf("cannot read request body: %w", err) + return "", BadRequestError{Message: "cannot read request body: " + err.Error()} } body := B(readBody) @@ -150,7 +150,7 @@ func transform[B any](body B) (B, error) { if inTransformerBody, ok := any(&body).(InTransformer); ok { err := inTransformerBody.InTransform() if err != nil { - return body, fmt.Errorf("cannot transform request body: %w", err) + return body, BadRequestError{Message: "cannot transform request body: " + err.Error()} } body = *any(inTransformerBody).(*B) diff --git a/deserialization_test.go b/deserialization_test.go index f6b8ad16..66846f66 100644 --- a/deserialization_test.go +++ b/deserialization_test.go @@ -26,7 +26,7 @@ func TestReadJSON(t *testing.T) { t.Run("cannot read invalid JSON", func(t *testing.T) { _, err := ReadJSON[TestBody](input) - require.Error(t, err) + require.ErrorAs(t, err, &BadRequestError{}, "Expected a BadRequestError") }) t.Run("cannot deserialize JSON to wrong struct", func(t *testing.T) { @@ -36,7 +36,7 @@ func TestReadJSON(t *testing.T) { // Missing C bool } _, err := ReadJSON[WrongBody](input) - require.Error(t, err) + require.ErrorAs(t, err, &BadRequestError{}, "Expected a BadRequestError") }) } @@ -127,7 +127,7 @@ func TestInTransformStringWithError(t *testing.T) { t.Run("ReadString", func(t *testing.T) { input := strings.NewReader(`coucou`) body, err := ReadString[transformableStringWithError](input) - require.Error(t, err) + require.ErrorAs(t, err, &BadRequestError{}, "Expected a BadRequestError") require.Equal(t, transformableStringWithError("transformed coucou"), body) }) } diff --git a/errors.go b/errors.go index d99131e0..ba804c04 100644 --- a/errors.go +++ b/errors.go @@ -20,37 +20,63 @@ type ErrorWithInfo interface { Info() map[string]any } -// ErrorResponse is the error response used by the serialization part of the framework. -type ErrorResponse struct { +// HTTPError is the error response used by the serialization part of the framework. +type HTTPError struct { + Err error `json:",omitempty"` // backend developer readable error message Message string `json:"error" xml:"Error"` // human readable error message StatusCode int `json:"-" xml:"-"` // http status code MoreInfo map[string]any `json:"info,omitempty" xml:"Info,omitempty"` // additional info } -func (e ErrorResponse) Error() string { +var ( + _ ErrorWithInfo = HTTPError{} + _ ErrorWithStatus = HTTPError{} +) + +func (e HTTPError) Error() string { return e.Message } -var _ ErrorWithStatus = ErrorResponse{} +func (e HTTPError) Info() map[string]any { + return e.MoreInfo +} -func (e ErrorResponse) Status() int { +func (e HTTPError) Status() int { if e.StatusCode == 0 { return http.StatusInternalServerError } return e.StatusCode } -var _ ErrorWithInfo = ErrorResponse{} +// BadRequestError is an error used to return a 400 status code. +type BadRequestError struct { + Err error // developer readable error message + Message string `json:"error" xml:"Error"` // human readable error message + MoreInfo map[string]any `json:"info,omitempty" xml:"Info,omitempty"` // additional info +} + +var ( + _ ErrorWithInfo = BadRequestError{} + _ ErrorWithStatus = BadRequestError{} +) + +func (e BadRequestError) Error() string { + return e.Message +} -func (e ErrorResponse) Info() map[string]any { +func (e BadRequestError) Info() map[string]any { return e.MoreInfo } +func (e BadRequestError) Status() int { + return http.StatusBadRequest +} + // ErrorHandler is the default error handler used by the framework. -// It transforms any error into the unified error type [ErrorResponse], +// It transforms any error into the unified error type [HTTPError], // Using the [ErrorWithStatus] and [ErrorWithInfo] interfaces. func ErrorHandler(err error) error { - errResponse := ErrorResponse{ + errResponse := HTTPError{ Message: err.Error(), } diff --git a/errors_test.go b/errors_test.go index ca0aa279..c83863dd 100644 --- a/errors_test.go +++ b/errors_test.go @@ -20,10 +20,10 @@ func TestErrorHandler(t *testing.T) { err := errors.New("test error") errResponse := ErrorHandler(err) - require.ErrorAs(t, errResponse, &ErrorResponse{}) + require.ErrorAs(t, errResponse, &HTTPError{}) require.Equal(t, "test error", errResponse.Error()) - require.Equal(t, http.StatusInternalServerError, errResponse.(ErrorResponse).Status()) - require.Nil(t, errResponse.(ErrorResponse).Info()) + require.Equal(t, http.StatusInternalServerError, errResponse.(HTTPError).Status()) + require.Nil(t, errResponse.(HTTPError).Info()) }) t.Run("error with status ", func(t *testing.T) { @@ -31,14 +31,14 @@ func TestErrorHandler(t *testing.T) { status: http.StatusNotFound, } errResponse := ErrorHandler(err) - require.ErrorAs(t, errResponse, &ErrorResponse{}) + require.ErrorAs(t, errResponse, &HTTPError{}) require.Equal(t, "test error", errResponse.Error()) - require.Equal(t, http.StatusNotFound, errResponse.(ErrorResponse).Status()) - require.Nil(t, errResponse.(ErrorResponse).Info()) + require.Equal(t, http.StatusNotFound, errResponse.(HTTPError).Status()) + require.Nil(t, errResponse.(HTTPError).Info()) }) t.Run("error with status and info", func(t *testing.T) { - err := ErrorResponse{ + err := HTTPError{ Message: "test error", StatusCode: http.StatusNotFound, MoreInfo: map[string]any{ @@ -46,9 +46,9 @@ func TestErrorHandler(t *testing.T) { }, } errResponse := ErrorHandler(err) - require.ErrorAs(t, errResponse, &ErrorResponse{}) + require.ErrorAs(t, errResponse, &HTTPError{}) require.Equal(t, "test error", errResponse.Error()) - require.Equal(t, http.StatusNotFound, errResponse.(ErrorResponse).Status()) - require.NotNil(t, errResponse.(ErrorResponse).Info()) + require.Equal(t, http.StatusNotFound, errResponse.(HTTPError).Status()) + require.NotNil(t, errResponse.(HTTPError).Info()) }) } diff --git a/examples/simple-crud/errors_custom.go b/examples/simple-crud/errors_custom.go new file mode 100644 index 00000000..7cd92619 --- /dev/null +++ b/examples/simple-crud/errors_custom.go @@ -0,0 +1,11 @@ +package main + +import "net/http" + +type MyError struct { + Err error // developer readable error message +} + +func (e MyError) Status() int { + return http.StatusTeapot +} diff --git a/examples/simple-crud/store/ingredient.go b/examples/simple-crud/store/ingredient.go index fc0d291f..3f0a2220 100644 --- a/examples/simple-crud/store/ingredient.go +++ b/examples/simple-crud/store/ingredient.go @@ -60,5 +60,4 @@ func (i Ingredient) Months() string { } return strings.Join(months, ", ") - } diff --git a/examples/simple-crud/views/admin.go b/examples/simple-crud/views/admin.go index 524cb8b4..6fd6ff56 100644 --- a/examples/simple-crud/views/admin.go +++ b/examples/simple-crud/views/admin.go @@ -155,7 +155,6 @@ func (rs Ressource) adminAddDosing(c fuego.Ctx[store.CreateDosingParams]) (any, } func (rs Ressource) adminIngredients(c fuego.Ctx[any]) (any, error) { - searchParams := components.SearchParams{ Name: c.QueryParam("name"), PerPage: c.QueryParamInt("perPage", 10), diff --git a/middleware/basicauth/basicauth.go b/middleware/basicauth/basicauth.go index cf529107..99c8ff70 100644 --- a/middleware/basicauth/basicauth.go +++ b/middleware/basicauth/basicauth.go @@ -29,7 +29,7 @@ func New(config Config) func(http.Handler) http.Handler { return } - err := fuego.ErrorResponse{ + err := fuego.HTTPError{ Message: "unauthorized", StatusCode: http.StatusUnauthorized, } diff --git a/options_test.go b/options_test.go index 46bc8137..08ba3e96 100644 --- a/options_test.go +++ b/options_test.go @@ -57,6 +57,6 @@ func TestWithXML(t *testing.T) { require.Equal(t, 500, recorder.Code) require.Equal(t, "application/xml", recorder.Header().Get("Content-Type")) - require.Equal(t, "error", recorder.Body.String()) + require.Equal(t, "error", recorder.Body.String()) }) } diff --git a/serialization.go b/serialization.go index 7e403b45..a6c51116 100644 --- a/serialization.go +++ b/serialization.go @@ -93,7 +93,7 @@ func SendJSON(w http.ResponseWriter, ans any) { // If the error implements ErrorWithStatus, the status code will be set. func SendJSONError(w http.ResponseWriter, err error) { status := http.StatusInternalServerError - errorStatus := ErrorResponse{ + errorStatus := HTTPError{ Message: err.Error(), } if errors.As(err, &errorStatus) { diff --git a/serialization_test.go b/serialization_test.go index e6af9c47..1283ef33 100644 --- a/serialization_test.go +++ b/serialization_test.go @@ -54,11 +54,11 @@ func TestXML(t *testing.T) { t.Run("can serialize xml error", func(t *testing.T) { w := httptest.NewRecorder() - err := ErrorResponse{Message: "Hello World"} + err := HTTPError{Message: "Hello World"} SendXMLError(w, err) body := w.Body.String() - require.Equal(t, `Hello World`, body) + require.Equal(t, `Hello World`, body) }) }