Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set mapping keys to "decoded" when custom unmarshaler is used #426

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,19 @@ func (md *MetaData) PrimitiveDecode(primValue Primitive, v any) error {
return md.unify(primValue.undecoded, rvalue(v))
}

// markDecodedRecursive is a helper to mark any key under the given tmap
// as decoded, recursing as needed
func markDecodedRecursive(md *MetaData, tmap map[string]any) {
for key := range tmap {
md.decoded[md.context.add(key).String()] = struct{}{}
if tmap, ok := tmap[key].(map[string]any); ok {
md.context = append(md.context, key)
markDecodedRecursive(md, tmap)
md.context = md.context[0 : len(md.context)-1]
}
}
}

// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
Expand All @@ -222,6 +235,11 @@ func (md *MetaData) unify(data any, rv reflect.Value) error {
if err != nil {
return md.parseErr(err)
}
// assume the Unmarshaler did it's job and mark all
// keys under this map decoded
if tmap, ok := data.(map[string]any); ok {
markDecodedRecursive(md, tmap)
}
return nil
}
if v, ok := rvi.(encoding.TextUnmarshaler); ok {
Expand Down
54 changes: 53 additions & 1 deletion decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ func TestCustomEncode(t *testing.T) {
// Test for #341
func TestCustomDecode(t *testing.T) {
var outer Outer
_, err := Decode(`
meta, err := Decode(`
Int = 10
Enum = "OTHER_VALUE"
Slice = ["text1", "text2"]
Expand All @@ -1036,6 +1036,9 @@ func TestCustomDecode(t *testing.T) {
if fmt.Sprint(outer.Slice.value) != fmt.Sprint([]string{"text1", "text2"}) {
t.Errorf("\nhave:\n%v\nwant:\n%v\n", outer.Slice.value, []string{"text1", "text2"})
}
if len(meta.Undecoded()) > 0 {
t.Errorf("\ncustom decode leaves unencoded fields: %v\n", meta.Undecoded())
}
}

// TODO: this should be improved for v2:
Expand Down Expand Up @@ -1144,3 +1147,52 @@ func BenchmarkKey(b *testing.B) {
k.String()
}
}

type CustomStruct struct {
Foo string
TblB int64
TblInlineC int64
}

func (cs *CustomStruct) UnmarshalTOML(data interface{}) error {
d, _ := data.(map[string]interface{})
cs.Foo = d["foo"].(string)
cs.TblB = d["tbl"].(map[string]interface{})["b"].(int64)
cs.TblInlineC = d["tbl"].(map[string]interface{})["inline"].(map[string]interface{})["c"].(int64)

return nil
}

func TestDecodeCustomStructMarkedDecoded(t *testing.T) {
var cs CustomStruct
meta, err := Decode(`
foo = "bar"
a = 1
arr = [2]

[tbl]
b = 3

inline = {c = 4}
`, &cs)
if err != nil {
t.Fatalf("Decode failed: %s", err)
}

if cs.Foo != "bar" {
t.Errorf("\nhave:\n%v\nwant:\n%v\n", cs.Foo, "bar")
}
if cs.TblB != 3 {
t.Errorf("\nhave:\n%v\nwant:\n%v\n", cs.TblB, 3)
}
if cs.TblInlineC != 4 {
t.Errorf("\nhave:\n%v\nwant:\n%v\n", cs.TblB, 4)
}
// Note that even though the custom unmarshaler did not decode
// all fields as far as the metadata is concerned they are handlded.
// It is the job of the unmarshaler to ensure this or we would need
// a more powerful interface like UnmarshalTOML(data any, md *MetaData)
if len(meta.Undecoded()) > 0 {
t.Errorf("\ncustom decode leaves unencoded fields: %v\n", meta.Undecoded())
}
}
Loading