diff --git a/infer.go b/infer.go index 805d95d..66bb2e2 100644 --- a/infer.go +++ b/infer.go @@ -15,6 +15,7 @@ func Infer(data any) (bigquery.Schema, error) { func inferObject(data reflect.Value) (bigquery.Schema, error) { var schema bigquery.Schema + var embedded bigquery.Schema switch data.Kind() { case reflect.Ptr, reflect.Interface: @@ -29,11 +30,11 @@ func inferObject(data reflect.Value) (bigquery.Schema, error) { fieldInfo := data.Type().Field(i) if fieldInfo.Anonymous { - embeddedSchema, err := inferObject(field) + resp, err := inferObject(field) if err != nil { return nil, err } - schema = append(schema, embeddedSchema...) + embedded = append(embedded, resp...) continue } @@ -69,6 +70,19 @@ func inferObject(data reflect.Value) (bigquery.Schema, error) { return nil, fmt.Errorf("invalid data: %v: %w", data.Kind(), ErrUnsupportedObject) } + for _, field := range embedded { + var found bool + for _, f := range schema { + if f.Name == field.Name { + found = true + break + } + } + if !found { + schema = append(schema, field) + } + } + return schema, nil } diff --git a/infer_test.go b/infer_test.go index 6d31e22..967f29c 100644 --- a/infer_test.go +++ b/infer_test.go @@ -476,21 +476,33 @@ func TestInferArray(t *testing.T) { func TestInferMixIn(t *testing.T) { type Nest struct { - Str string + Prev string + Str string + Next string } type mix struct { + Prev int Nest - Int int + Next int + Int int } schemas := gt.R1(bqs.Infer(mix{})).NoError(t) - gt.A(t, schemas).Length(2). + gt.A(t, schemas).Length(4). At(0, func(t testing.TB, v *bigquery.FieldSchema) { - gt.Equal(t, v.Name, "Str") - gt.Equal(t, v.Type, bigquery.StringFieldType) + gt.Equal(t, v.Name, "Prev") + gt.Equal(t, v.Type, bigquery.IntegerFieldType) }). At(1, func(t testing.TB, v *bigquery.FieldSchema) { + gt.Equal(t, v.Name, "Next") + gt.Equal(t, v.Type, bigquery.IntegerFieldType) + }). + At(2, func(t testing.TB, v *bigquery.FieldSchema) { gt.Equal(t, v.Name, "Int") gt.Equal(t, v.Type, bigquery.IntegerFieldType) + }). + At(3, func(t testing.TB, v *bigquery.FieldSchema) { + gt.Equal(t, v.Name, "Str") + gt.Equal(t, v.Type, bigquery.StringFieldType) }) }