Skip to content

Commit

Permalink
feat: parse multipart.Form.Value into struct
Browse files Browse the repository at this point in the history
  • Loading branch information
b-kamphorst committed Jan 13, 2025
1 parent 1ce45cb commit 14b6613
Show file tree
Hide file tree
Showing 3 changed files with 263 additions and 33 deletions.
29 changes: 16 additions & 13 deletions formdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ type MultipartFormFiles[T any] struct {
data *T
}

func (m *MultipartFormFiles[T]) Data() *T {
return m.data
}

type MimeTypeValidator struct {
accept []string
}
Expand Down Expand Up @@ -85,13 +89,9 @@ func (v MimeTypeValidator) Validate(fh *multipart.FileHeader, location string) (
}
}

func (m *MultipartFormFiles[T]) Data() *T {
return m.data
}

// Decodes multipart.Form data into *T, returning []*ErrorDetail if any
// Schema is used to check for validation constraints
func (m *MultipartFormFiles[T]) Decode(opMediaType *MediaType) []error {
func (m *MultipartFormFiles[T]) Decode(opMediaType *MediaType, formValueParser func(val reflect.Value)) []error {
var (
dataType = reflect.TypeOf(m.data).Elem()
value = reflect.New(dataType)
Expand Down Expand Up @@ -120,11 +120,9 @@ func (m *MultipartFormFiles[T]) Decode(opMediaType *MediaType) []error {
continue
}
field.Set(reflect.ValueOf(files))

default:
continue
}
}
formValueParser(value)
m.data = value.Interface().(*T)
return errors
}
Expand Down Expand Up @@ -200,7 +198,7 @@ func formDataFieldName(f reflect.StructField) string {
return name
}

func multiPartFormFileSchema(t reflect.Type) *Schema {
func multiPartFormFileSchema(r Registry, t reflect.Type) *Schema {
nFields := t.NumField()
schema := &Schema{
Type: "object",
Expand All @@ -221,8 +219,9 @@ func multiPartFormFileSchema(t reflect.Type) *Schema {
Items: multiPartFileSchema(f),
}
default:
schema.Properties[name] = SchemaFromField(r, f, name)

// Should we panic if [T] struct defines fields with unsupported types ?
continue
}

if _, ok := f.Tag.Lookup("required"); ok && boolTag(f, "required", false) {
Expand All @@ -249,9 +248,13 @@ func multiPartContentEncoding(t reflect.Type) map[string]*Encoding {
for i := 0; i < nFields; i++ {
f := t.Field(i)
name := formDataFieldName(f)
contentType := f.Tag.Get("contentType")
if contentType == "" {
contentType = "application/octet-stream"

contentType := "text/plain"
if f.Type == reflect.TypeOf(FormFile{}) || f.Type == reflect.TypeOf([]FormFile{}) {
contentType = f.Tag.Get("contentType")
if contentType == "" {
contentType = "application/octet-stream"
}
}
encoding[name] = &Encoding{
ContentType: contentType,
Expand Down
102 changes: 82 additions & 20 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"errors"
"fmt"
"io"
"mime/multipart"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -97,7 +98,6 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p
if f.Anonymous {
return nil
}

pfi := &paramFieldInfo{
Type: f.Type,
}
Expand Down Expand Up @@ -125,6 +125,10 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p
} else if h := f.Tag.Get("header"); h != "" {
pfi.Loc = "header"
name = h
} else if fo := f.Tag.Get("form"); fo != "" {
// TODO: clearify in README that "form" tag is REQUIRED
pfi.Loc = "form"
name = fo
} else if c := f.Tag.Get("cookie"); c != "" {
pfi.Loc = "cookie"
name = c
Expand Down Expand Up @@ -720,9 +724,64 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}

if rbt.isMultipart() {
if cErr := processMultipartMsgBody(op, ctx, v, rbt, rawBodyIndex, res); cErr != nil {
writeErr(api, ctx, cErr, *res)
return
// Read form
form, err := readForm(ctx)

if err != nil {
res.Errors = append(res.Errors, err)
} else {
var formValueParser func(val reflect.Value)
if rbt == rbtMultipart {
formValueParser = func(val reflect.Value) {}
} else {
rawBodyF := v.FieldByIndex(rawBodyIndex)
rawBodyDataF := rawBodyF.FieldByName("data")
rawBodyDataT := rawBodyDataF.Type()

rawBodyInputParams := findParams(oapi.Components.Schemas, &op, rawBodyDataT)
formValueParser = func(val reflect.Value) {
rawBodyInputParams.Every(val, func(f reflect.Value, p *paramFieldInfo) {
f = reflect.Indirect(f)
if f.Kind() == reflect.Invalid {
return

Check warning on line 746 in huma.go

View check run for this annotation

Codecov / codecov/patch

huma.go#L746

Added line #L746 was not covered by tests
}

pb.Reset()
pb.Push(p.Loc)
pb.Push(p.Name)

value, ok := form.Value[p.Name]
if !ok {
_, isFile := form.File[p.Name]
if !op.SkipValidateParams && p.Required && !isFile {
res.Add(pb, "", "required "+p.Loc+" parameter is missing")
}
return
}

// Validation should fail if multiple values are
// provided but the type of f is not a slice.
if len(value) > 1 && f.Type().Kind() != reflect.Slice {
res.Add(pb, value, "expected at most one value, but received multiple values")
return
}
s := splittableString{Raw: value[0], Splitted: value}
pv, err := parseInto(f, s, *p)
if err != nil {
res.Add(pb, value, err.Error())

Check warning on line 771 in huma.go

View check run for this annotation

Codecov / codecov/patch

huma.go#L771

Added line #L771 was not covered by tests
}

if !op.SkipValidateParams {
Validate(oapi.Components.Schemas, p.Schema, pb, ModeWriteToServer, pv, res)
}
})
}
}

if cErr := processMultipartMsgBody(form, op, v, rbt, rawBodyIndex, formValueParser); cErr != nil {
writeErr(api, ctx, cErr, *res)
return
}
}
} else {
// Read body
Expand Down Expand Up @@ -927,7 +986,7 @@ func processInputType(inputType reflect.Type, op *Operation, registry Registry)
if f, ok := inputType.FieldByName("RawBody"); ok {
rawBodyIndex = f.Index
initRequestBody(op, setRequestBodyRequired)
rbt = setRequestBodyFromRawBody(op, f)
rbt = setRequestBodyFromRawBody(op, registry, f)
}

if op.RequestBody != nil {
Expand Down Expand Up @@ -998,7 +1057,7 @@ func (r rawBodyType) isMultipart() bool {
}

// setRequestBodyFromRawBody configures op.RequestBody from the RawBody field.
func setRequestBodyFromRawBody(op *Operation, fRawBody reflect.StructField) rawBodyType {
func setRequestBodyFromRawBody(op *Operation, r Registry, fRawBody reflect.StructField) rawBodyType {
rbt := rbtOther
contentType := "application/octet-stream"
if fRawBody.Type.String() == "multipart.Form" {
Expand Down Expand Up @@ -1050,7 +1109,7 @@ func setRequestBodyFromRawBody(op *Operation, fRawBody reflect.StructField) rawB
panic("Expected type MultipartFormFiles[T] to have a 'data *T' generic pointer field")

Check warning on line 1109 in huma.go

View check run for this annotation

Codecov / codecov/patch

huma.go#L1109

Added line #L1109 was not covered by tests
}
op.RequestBody.Content["multipart/form-data"] = &MediaType{
Schema: multiPartFormFileSchema(dataField.Type.Elem()),
Schema: multiPartFormFileSchema(r, dataField.Type.Elem()),
Encoding: multiPartContentEncoding(dataField.Type.Elem()),
}
op.RequestBody.Required = false
Expand Down Expand Up @@ -1512,29 +1571,21 @@ func writeErr(api API, ctx Context, cErr *contextError, res ValidateResult) {
}
}

func processMultipartMsgBody(op Operation, ctx Context, inputValue reflect.Value, rbt rawBodyType, rawBodyIndex []int, res *ValidateResult) *contextError {
form, err := ctx.GetMultipartForm()
if err != nil {
res.Errors = append(res.Errors, &ErrorDetail{
Location: "body",
Message: "cannot read multipart form: " + err.Error(),
})
return nil
}
f := inputValue
for _, i := range rawBodyIndex {
f = f.Field(i)
}
func processMultipartMsgBody(form *multipart.Form, op Operation, v reflect.Value, rbt rawBodyType, rawBodyIndex []int, formValueParser func(val reflect.Value)) *contextError {
f := v.FieldByIndex(rawBodyIndex)
switch rbt {
case rbtMultipart:
// f is of type multipart.Form
f.Set(reflect.ValueOf(*form))
case rbtMultipartDecoded:
// f is of type MultipartFormFiles[T]
f.FieldByName("Form").Set(reflect.ValueOf(form))
r := f.Addr().
MethodByName("Decode").
Call(
[]reflect.Value{
reflect.ValueOf(op.RequestBody.Content["multipart/form-data"]),
reflect.ValueOf(formValueParser),
})
errs := r[0].Interface().([]error)
if errs != nil {
Expand All @@ -1544,6 +1595,17 @@ func processMultipartMsgBody(op Operation, ctx Context, inputValue reflect.Value
return nil
}

func readForm(ctx Context) (*multipart.Form, *ErrorDetail) {
form, err := ctx.GetMultipartForm()
if err != nil {
return form, &ErrorDetail{
Location: "body",
Message: "cannot read multipart form: " + err.Error(),
}
}
return form, nil
}

type intoUnmarshaler = func(data []byte, v any) error

// processRegularMsgBody parses the raw body with unmarshaler and validates it
Expand Down
Loading

0 comments on commit 14b6613

Please sign in to comment.