diff --git a/goyaml.v2/patch_test.go b/goyaml.v2/patch_test.go new file mode 100644 index 0000000..d82655d --- /dev/null +++ b/goyaml.v2/patch_test.go @@ -0,0 +1,23 @@ +package yaml_test + +import ( + "testing" + + yaml "sigs.k8s.io/yaml/goyaml.v2" +) + +func TestIssue117(t *testing.T) { + data := []byte(` +a: +<<: +- +? +- +`) + + x := map[string]interface{}{} + err := yaml.Unmarshal([]byte(data), &x) + if err == nil { + t.Errorf("expected error, got none") + } +} diff --git a/goyaml.v3/decode.go b/goyaml.v3/decode.go index 0173b69..02e2b17 100644 --- a/goyaml.v3/decode.go +++ b/goyaml.v3/decode.go @@ -832,10 +832,10 @@ func (d *decoder) mapping(n *Node, out reflect.Value) (good bool) { if d.unmarshal(n.Content[i], k) { if mergedFields != nil { ki := k.Interface() - if mergedFields[ki] { + if d.getPossiblyUnhashableKey(mergedFields, ki) { continue } - mergedFields[ki] = true + d.setPossiblyUnhashableKey(mergedFields, ki, true) } kkind := k.Kind() if kkind == reflect.Interface { @@ -956,6 +956,24 @@ func failWantMap() { failf("map merge requires map or sequence of maps as the value") } +func (d *decoder) setPossiblyUnhashableKey(m map[interface{}]bool, key interface{}, value bool) { + defer func() { + if err := recover(); err != nil { + failf("%v", err) + } + }() + m[key] = value +} + +func (d *decoder) getPossiblyUnhashableKey(m map[interface{}]bool, key interface{}) bool { + defer func() { + if err := recover(); err != nil { + failf("%v", err) + } + }() + return m[key] +} + func (d *decoder) merge(parent *Node, merge *Node, out reflect.Value) { mergedFields := d.mergedFields if mergedFields == nil { @@ -963,7 +981,7 @@ func (d *decoder) merge(parent *Node, merge *Node, out reflect.Value) { for i := 0; i < len(parent.Content); i += 2 { k := reflect.New(ifaceType).Elem() if d.unmarshal(parent.Content[i], k) { - d.mergedFields[k.Interface()] = true + d.setPossiblyUnhashableKey(d.mergedFields, k.Interface(), true) } } } diff --git a/goyaml.v3/patch_test.go b/goyaml.v3/patch_test.go index 7220301..3b76276 100644 --- a/goyaml.v3/patch_test.go +++ b/goyaml.v3/patch_test.go @@ -18,6 +18,7 @@ package yaml_test import ( "bytes" + "testing" . "gopkg.in/check.v1" yaml "sigs.k8s.io/yaml/goyaml.v3" @@ -158,3 +159,19 @@ func (s *S) TestNewLinePreserved(c *C) { // the newline at the start of the file should be preserved c.Assert(string(data), Equals, "_: |4\n\n a:\n b:\n c: d\n") } + +func TestIssue117(t *testing.T) { + data := []byte(` +a: +<<: +- +? +- +`) + + x := map[string]interface{}{} + err := yaml.Unmarshal([]byte(data), &x) + if err == nil { + t.Errorf("expected error, got none") + } +}