From 25cec2f721904d8d181c7dc73c3511708f8e3d50 Mon Sep 17 00:00:00 2001 From: Zachary Lozano Date: Mon, 26 Oct 2020 17:21:28 -0500 Subject: [PATCH] Add support for error aggregation for request/response validation (#259) --- openapi3/errors.go | 43 ++++++ openapi3/schema.go | 187 +++++++++++++++++++++---- openapi3/schema_test.go | 154 +++++++++++++++++++- openapi3/schema_validation_settings.go | 5 + openapi3filter/options.go | 1 + openapi3filter/validate_request.go | 61 +++++++- openapi3filter/validate_response.go | 8 +- 7 files changed, 420 insertions(+), 39 deletions(-) create mode 100644 openapi3/errors.go diff --git a/openapi3/errors.go b/openapi3/errors.go new file mode 100644 index 000000000..ce52cd483 --- /dev/null +++ b/openapi3/errors.go @@ -0,0 +1,43 @@ +package openapi3 + +import ( + "bytes" + "errors" +) + +// MultiError is a collection of errors, intended for when +// multiple issues need to be reported upstream +type MultiError []error + +func (me MultiError) Error() string { + buff := &bytes.Buffer{} + for _, e := range me { + buff.WriteString(e.Error()) + buff.WriteString(" | ") + } + return buff.String() +} + +//Is allows you to determine if a generic error is in fact a MultiError using `errors.Is()` +//It will also return true if any of the contained errors match target +func (me MultiError) Is(target error) bool { + if _, ok := target.(MultiError); ok { + return true + } + for _, e := range me { + if errors.Is(e, target) { + return true + } + } + return false +} + +//As allows you to use `errors.As()` to set target to the first error within the multi error that matches the target type +func (me MultiError) As(target interface{}) bool { + for _, e := range me { + if errors.As(e, target) { + return true + } + } + return false +} diff --git a/openapi3/schema.go b/openapi3/schema.go index a4e3b8d00..41d6fd909 100644 --- a/openapi3/schema.go +++ b/openapi3/schema.go @@ -802,19 +802,24 @@ func (schema *Schema) VisitJSONNumber(value float64) error { return schema.visitJSONNumber(settings, value) } -func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value float64) (err error) { +func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value float64) error { + var me MultiError schemaType := schema.Type if schemaType == "integer" { if bigFloat := big.NewFloat(value); !bigFloat.IsInt() { if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "type", Reason: "Value must be an integer", } + if !settings.multiError { + return err + } + me = append(me, err) } } else if schemaType != "" && schemaType != "number" { return schema.expectedType(settings, "number, integer") @@ -825,12 +830,16 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "exclusiveMinimum", Reason: fmt.Sprintf("Number must be more than %g", *schema.Min), } + if !settings.multiError { + return err + } + me = append(me, err) } // "exclusiveMaximum" @@ -838,12 +847,16 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "exclusiveMaximum", Reason: fmt.Sprintf("Number must be less than %g", *schema.Max), } + if !settings.multiError { + return err + } + me = append(me, err) } // "minimum" @@ -851,12 +864,16 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "minimum", Reason: fmt.Sprintf("Number must be at least %g", *v), } + if !settings.multiError { + return err + } + me = append(me, err) } // "maximum" @@ -864,12 +881,16 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "maximum", Reason: fmt.Sprintf("Number must be most %g", *v), } + if !settings.multiError { + return err + } + me = append(me, err) } // "multipleOf" @@ -880,14 +901,23 @@ func (schema *Schema) visitJSONNumber(settings *schemaValidationSettings, value if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "multipleOf", } + if !settings.multiError { + return err + } + me = append(me, err) } } - return + + if len(me) > 0 { + return me + } + + return nil } func (schema *Schema) VisitJSONString(value string) error { @@ -895,11 +925,13 @@ func (schema *Schema) VisitJSONString(value string) error { return schema.visitJSONString(settings, value) } -func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value string) (err error) { +func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value string) error { if schemaType := schema.Type; schemaType != "" && schemaType != "string" { return schema.expectedType(settings, "string") } + var me MultiError + // "minLength" and "maxLength" minLength := schema.MinLength maxLength := schema.MaxLength @@ -917,23 +949,31 @@ func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "minLength", Reason: fmt.Sprintf("Minimum string length is %d", minLength), } + if !settings.multiError { + return err + } + me = append(me, err) } if maxLength != nil && length > int64(*maxLength) { if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "maxLength", Reason: fmt.Sprintf("Maximum string length is %d", *maxLength), } + if !settings.multiError { + return err + } + me = append(me, err) } } @@ -970,15 +1010,24 @@ func (schema *Schema) visitJSONString(settings *schemaValidationSettings, value if schema.Pattern != "" { field = "pattern" } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: field, Reason: cp.ErrReason, } + if !settings.multiError { + return err + } + me = append(me, err) } } - return + + if len(me) > 0 { + return me + } + + return nil } func (schema *Schema) VisitJSONArray(value []interface{}) error { @@ -986,11 +1035,13 @@ func (schema *Schema) VisitJSONArray(value []interface{}) error { return schema.visitJSONArray(settings, value) } -func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value []interface{}) (err error) { +func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value []interface{}) error { if schemaType := schema.Type; schemaType != "" && schemaType != "array" { return schema.expectedType(settings, "array") } + var me MultiError + lenValue := int64(len(value)) // "minItems" @@ -998,12 +1049,16 @@ func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value [ if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "minItems", Reason: fmt.Sprintf("Minimum number of items is %d", v), } + if !settings.multiError { + return err + } + me = append(me, err) } // "maxItems" @@ -1011,12 +1066,16 @@ func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value [ if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "maxItems", Reason: fmt.Sprintf("Maximum number of items is %d", *v), } + if !settings.multiError { + return err + } + me = append(me, err) } // "uniqueItems" @@ -1027,12 +1086,16 @@ func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value [ if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "uniqueItems", Reason: fmt.Sprintf("Duplicate items found"), } + if !settings.multiError { + return err + } + me = append(me, err) } // "items" @@ -1042,12 +1105,25 @@ func (schema *Schema) visitJSONArray(settings *schemaValidationSettings, value [ return foundUnresolvedRef(itemSchemaRef.Ref) } for i, item := range value { - if err := itemSchema.VisitJSON(item); err != nil { - return markSchemaErrorIndex(err, i) + if err := itemSchema.visitJSON(settings, item); err != nil { + err = markSchemaErrorIndex(err, i) + if !settings.multiError { + return err + } + if itemMe, ok := err.(MultiError); ok { + me = append(me, itemMe...) + } else { + me = append(me, err) + } } } } - return + + if len(me) > 0 { + return me + } + + return nil } func (schema *Schema) VisitJSONObject(value map[string]interface{}) error { @@ -1055,11 +1131,13 @@ func (schema *Schema) VisitJSONObject(value map[string]interface{}) error { return schema.visitJSONObject(settings, value) } -func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value map[string]interface{}) (err error) { +func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value map[string]interface{}) error { if schemaType := schema.Type; schemaType != "" && schemaType != "object" { return schema.expectedType(settings, "object") } + var me MultiError + // "properties" properties := schema.Properties lenValue := int64(len(value)) @@ -1069,12 +1147,16 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "minProperties", Reason: fmt.Sprintf("There must be at least %d properties", v), } + if !settings.multiError { + return err + } + me = append(me, err) } // "maxProperties" @@ -1082,12 +1164,16 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "maxProperties", Reason: fmt.Sprintf("There must be at most %d properties", *v), } + if !settings.multiError { + return err + } + me = append(me, err) } // "additionalProperties" @@ -1103,11 +1189,19 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value if p == nil { return foundUnresolvedRef(propertyRef.Ref) } - if err := p.VisitJSON(v); err != nil { + if err := p.visitJSON(settings, v); err != nil { if settings.failfast { return errSchema } - return markSchemaErrorKey(err, k) + err = markSchemaErrorKey(err, k) + if !settings.multiError { + return err + } + if v, ok := err.(MultiError); ok { + me = append(me, v...) + continue + } + me = append(me, err) } continue } @@ -1115,11 +1209,19 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value allowed := schema.AdditionalPropertiesAllowed if additionalProperties != nil || allowed == nil || (allowed != nil && *allowed) { if additionalProperties != nil { - if err := additionalProperties.VisitJSON(v); err != nil { + if err := additionalProperties.visitJSON(settings, v); err != nil { if settings.failfast { return errSchema } - return markSchemaErrorKey(err, k) + err = markSchemaErrorKey(err, k) + if !settings.multiError { + return err + } + if v, ok := err.(MultiError); ok { + me = append(me, v...) + continue + } + me = append(me, err) } } continue @@ -1127,12 +1229,16 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value if settings.failfast { return errSchema } - return &SchemaError{ + err := &SchemaError{ Value: value, Schema: schema, SchemaField: "properties", Reason: fmt.Sprintf("Property '%s' is unsupported", k), } + if !settings.multiError { + return err + } + me = append(me, err) } // "required" @@ -1147,15 +1253,24 @@ func (schema *Schema) visitJSONObject(settings *schemaValidationSettings, value if settings.failfast { return errSchema } - return markSchemaErrorKey(&SchemaError{ + err := markSchemaErrorKey(&SchemaError{ Value: value, Schema: schema, SchemaField: "required", Reason: fmt.Sprintf("Property '%s' is missing", k), }, k) + if !settings.multiError { + return err + } + me = append(me, err) } } - return + + if len(me) > 0 { + return me + } + + return nil } func (schema *Schema) expectedType(settings *schemaValidationSettings, typ string) error { @@ -1184,6 +1299,12 @@ func markSchemaErrorKey(err error, key string) error { v.reversePath = append(v.reversePath, key) return v } + if v, ok := err.(MultiError); ok { + for _, e := range v { + _ = markSchemaErrorKey(e, key) + } + return v + } return err } @@ -1192,6 +1313,12 @@ func markSchemaErrorIndex(err error, index int) error { v.reversePath = append(v.reversePath, strconv.FormatInt(int64(index), 10)) return v } + if v, ok := err.(MultiError); ok { + for _, e := range v { + _ = markSchemaErrorIndex(e, index) + } + return v + } return err } diff --git a/openapi3/schema_test.go b/openapi3/schema_test.go index 10e1d0589..6650ad547 100644 --- a/openapi3/schema_test.go +++ b/openapi3/schema_test.go @@ -4,7 +4,9 @@ import ( "context" "encoding/base64" "encoding/json" + "fmt" "math" + "reflect" "strings" "testing" @@ -59,13 +61,13 @@ func testSchema(t *testing.T, example schemaExample) func(*testing.T) { } } -func validateSchema(t *testing.T, schema *Schema, value interface{}) error { +func validateSchema(t *testing.T, schema *Schema, value interface{}, opts ...SchemaValidationOption) error { data, err := json.Marshal(value) require.NoError(t, err) var val interface{} err = json.Unmarshal(data, &val) require.NoError(t, err) - return schema.VisitJSON(val) + return schema.VisitJSON(val, opts...) } var schemaExamples = []schemaExample{ @@ -1042,3 +1044,151 @@ var schemaErrorExamples = []schemaErrorExample{ Want: "NEST", }, } + +type schemaMultiErrorExample struct { + Title string + Schema *Schema + Values []interface{} + ExpectedErrors []MultiError +} + +func TestSchemasMultiError(t *testing.T) { + for _, example := range schemaMultiErrorExamples { + t.Run(example.Title, testSchemaMultiError(t, example)) + } +} + +func testSchemaMultiError(t *testing.T, example schemaMultiErrorExample) func(*testing.T) { + return func(t *testing.T) { + schema := example.Schema + for i, value := range example.Values { + err := validateSchema(t, schema, value, MultiErrors()) + require.Error(t, err) + require.IsType(t, MultiError{}, err) + + merr, _ := err.(MultiError) + expected := example.ExpectedErrors[i] + require.True(t, len(merr) > 0) + require.Len(t, merr, len(expected)) + for _, e := range merr { + require.IsType(t, &SchemaError{}, e) + var found bool + scherr, _ := e.(*SchemaError) + for _, expectedErr := range expected { + expectedScherr, _ := expectedErr.(*SchemaError) + if reflect.DeepEqual(expectedScherr.reversePath, scherr.reversePath) && + expectedScherr.SchemaField == scherr.SchemaField { + found = true + break + } + } + require.True(t, found, fmt.Sprintf("Missing %s error on %s", scherr.SchemaField, strings.Join(scherr.JSONPointer(), "."))) + } + } + } +} + +var schemaMultiErrorExamples = []schemaMultiErrorExample{ + { + Title: "STRING", + Schema: NewStringSchema(). + WithMinLength(2). + WithMaxLength(3). + WithPattern("^[abc]+$"), + Values: []interface{}{ + "f", + "foobar", + }, + ExpectedErrors: []MultiError{ + {&SchemaError{SchemaField: "minLength"}, &SchemaError{SchemaField: "pattern"}}, + {&SchemaError{SchemaField: "maxLength"}, &SchemaError{SchemaField: "pattern"}}, + }, + }, + { + Title: "NUMBER", + Schema: NewIntegerSchema(). + WithMin(1). + WithMax(10), + Values: []interface{}{ + 0.5, + 10.1, + }, + ExpectedErrors: []MultiError{ + {&SchemaError{SchemaField: "type"}, &SchemaError{SchemaField: "minimum"}}, + {&SchemaError{SchemaField: "type"}, &SchemaError{SchemaField: "maximum"}}, + }, + }, + { + Title: "ARRAY: simple", + Schema: NewArraySchema(). + WithMinItems(2). + WithMaxItems(2). + WithItems(NewStringSchema(). + WithPattern("^[abc]+$")), + Values: []interface{}{ + []interface{}{"foo"}, + []interface{}{"foo", "bar", "fizz"}, + }, + ExpectedErrors: []MultiError{ + { + &SchemaError{SchemaField: "minItems"}, + &SchemaError{SchemaField: "pattern", reversePath: []string{"0"}}, + }, + { + &SchemaError{SchemaField: "maxItems"}, + &SchemaError{SchemaField: "pattern", reversePath: []string{"0"}}, + &SchemaError{SchemaField: "pattern", reversePath: []string{"1"}}, + &SchemaError{SchemaField: "pattern", reversePath: []string{"2"}}, + }, + }, + }, + { + Title: "ARRAY: object", + Schema: NewArraySchema(). + WithItems(NewObjectSchema(). + WithProperties(map[string]*Schema{ + "key1": NewStringSchema(), + "key2": NewIntegerSchema(), + }), + ), + Values: []interface{}{ + []interface{}{ + map[string]interface{}{ + "key1": 100, // not a string + "key2": "not an integer", + }, + }, + }, + ExpectedErrors: []MultiError{ + { + &SchemaError{SchemaField: "type", reversePath: []string{"key1", "0"}}, + &SchemaError{SchemaField: "type", reversePath: []string{"key2", "0"}}, + }, + }, + }, + { + Title: "OBJECT", + Schema: NewObjectSchema(). + WithProperties(map[string]*Schema{ + "key1": NewStringSchema(), + "key2": NewIntegerSchema(), + "key3": NewArraySchema(). + WithItems(NewStringSchema(). + WithPattern("^[abc]+$")), + }), + Values: []interface{}{ + map[string]interface{}{ + "key1": 100, // not a string + "key2": "not an integer", + "key3": []interface{}{"abc", "def"}, + }, + }, + ExpectedErrors: []MultiError{ + { + &SchemaError{SchemaField: "type", reversePath: []string{"key1"}}, + &SchemaError{SchemaField: "type", reversePath: []string{"key2"}}, + &SchemaError{SchemaField: "pattern", reversePath: []string{"1", "key3"}}, + }, + }, + }, +} diff --git a/openapi3/schema_validation_settings.go b/openapi3/schema_validation_settings.go index 6c073cd43..71db5f237 100644 --- a/openapi3/schema_validation_settings.go +++ b/openapi3/schema_validation_settings.go @@ -5,6 +5,7 @@ type SchemaValidationOption func(*schemaValidationSettings) type schemaValidationSettings struct { failfast bool + multiError bool asreq, asrep bool // exclusive (XOR) fields } @@ -13,6 +14,10 @@ func FailFast() SchemaValidationOption { return func(s *schemaValidationSettings) { s.failfast = true } } +func MultiErrors() SchemaValidationOption { + return func(s *schemaValidationSettings) { s.multiError = true } +} + func VisitAsRequest() SchemaValidationOption { return func(s *schemaValidationSettings) { s.asreq, s.asrep = true, false } } diff --git a/openapi3filter/options.go b/openapi3filter/options.go index 510b77756..60e5475f1 100644 --- a/openapi3filter/options.go +++ b/openapi3filter/options.go @@ -10,5 +10,6 @@ type Options struct { ExcludeRequestBody bool ExcludeResponseBody bool IncludeResponseStatus bool + MultiError bool AuthenticationFunc func(c context.Context, input *AuthenticationInput) error } diff --git a/openapi3filter/validate_request.go b/openapi3filter/validate_request.go index 69fc58bd1..0af54c299 100644 --- a/openapi3filter/validate_request.go +++ b/openapi3filter/validate_request.go @@ -22,6 +22,11 @@ var ErrInvalidRequired = errors.New("must have a value") // Note: One can tune the behavior of uniqueItems: true verification // by registering a custom function with openapi3.RegisterArrayUniqueItemsChecker func ValidateRequest(c context.Context, input *RequestValidationInput) error { + var ( + err error + me openapi3.MultiError + ) + options := input.Options if options == nil { options = DefaultOptions @@ -45,24 +50,37 @@ func ValidateRequest(c context.Context, input *RequestValidationInput) error { continue } } - if err := ValidateParameter(c, input, parameter); err != nil { + + if err = ValidateParameter(c, input, parameter); err != nil && !options.MultiError { return err } + + if err != nil { + me = append(me, err) + } } // For each parameter of the Operation for _, parameter := range operationParameters { - if err := ValidateParameter(c, input, parameter.Value); err != nil { + if err = ValidateParameter(c, input, parameter.Value); err != nil && !options.MultiError { return err } + + if err != nil { + me = append(me, err) + } } // RequestBody requestBody := operation.RequestBody if requestBody != nil && !options.ExcludeRequestBody { - if err := ValidateRequestBody(c, input, requestBody.Value); err != nil { + if err = ValidateRequestBody(c, input, requestBody.Value); err != nil && !options.MultiError { return err } + + if err != nil { + me = append(me, err) + } } // Security @@ -76,10 +94,19 @@ func ValidateRequest(c context.Context, input *RequestValidationInput) error { security = &route.Swagger.Security } if security != nil { - if err := ValidateSecurityRequirements(c, input, *security); err != nil { + if err = ValidateSecurityRequirements(c, input, *security); err != nil && !options.MultiError { return err } + + if err != nil { + me = append(me, err) + } } + + if len(me) > 0 { + return me + } + return nil } @@ -95,6 +122,11 @@ func ValidateParameter(c context.Context, input *RequestValidationInput, paramet return nil } + options := input.Options + if options == nil { + options = DefaultOptions + } + var value interface{} var err error var schema *openapi3.Schema @@ -121,7 +153,13 @@ func ValidateParameter(c context.Context, input *RequestValidationInput, paramet // A parameter's schema is not defined so skip validation of a parameter's value. return nil } - if err = schema.VisitJSON(value); err != nil { + + var opts []openapi3.SchemaValidationOption + if options.MultiError { + opts = make([]openapi3.SchemaValidationOption, 0, 1) + opts = append(opts, openapi3.MultiErrors()) + } + if err = schema.VisitJSON(value, opts...); err != nil { return &RequestError{Input: input, Parameter: parameter, Err: err} } return nil @@ -137,6 +175,11 @@ func ValidateRequestBody(c context.Context, input *RequestValidationInput, reque data []byte ) + options := input.Options + if options == nil { + options = DefaultOptions + } + if req.Body != http.NoBody && req.Body != nil { defer req.Body.Close() var err error @@ -191,8 +234,14 @@ func ValidateRequestBody(c context.Context, input *RequestValidationInput, reque } } + opts := make([]openapi3.SchemaValidationOption, 0, 2) // 2 potential opts here + opts = append(opts, openapi3.VisitAsRequest()) + if options.MultiError { + opts = append(opts, openapi3.MultiErrors()) + } + // Validate JSON with the schema - if err := contentType.Schema.Value.VisitJSON(value, openapi3.VisitAsRequest()); err != nil { + if err := contentType.Schema.Value.VisitJSON(value, opts...); err != nil { return &RequestError{ Input: input, RequestBody: requestBody, diff --git a/openapi3filter/validate_response.go b/openapi3filter/validate_response.go index 9a458aa1b..f203802a4 100644 --- a/openapi3filter/validate_response.go +++ b/openapi3filter/validate_response.go @@ -128,8 +128,14 @@ func ValidateResponse(c context.Context, input *ResponseValidationInput) error { } } + opts := make([]openapi3.SchemaValidationOption, 0, 2) // 2 potential opts here + opts = append(opts, openapi3.VisitAsRequest()) + if options.MultiError { + opts = append(opts, openapi3.MultiErrors()) + } + // Validate data with the schema. - if err := contentType.Schema.Value.VisitJSON(value, openapi3.VisitAsResponse()); err != nil { + if err := contentType.Schema.Value.VisitJSON(value, opts...); err != nil { return &ResponseError{ Input: input, Reason: "response body doesn't match the schema",