diff --git a/internal/serde/reflect.go b/internal/serde/reflect.go index 4be703d..9433184 100644 --- a/internal/serde/reflect.go +++ b/internal/serde/reflect.go @@ -294,13 +294,23 @@ func deserializePointedAt(d *Deserializer, t reflect.Type) reflect.Value { } func serializeMap(s *Serializer, t reflect.Type, p unsafe.Pointer) { - size := 0 r := reflect.NewAt(t, p).Elem() + if r.IsNil() { - size = -1 - } else { - size = r.Len() + serializeVarint(s, 0) + return } + + mapptr := r.UnsafePointer() + + id, new := s.assignPointerID(mapptr) + serializeVarint(s, int(id)) + if !new { + return + } + + size := r.Len() + serializeVarint(s, size) // TODO: allocs @@ -316,13 +326,27 @@ func serializeMap(s *Serializer, t reflect.Type, p unsafe.Pointer) { } func deserializeMap(d *Deserializer, t reflect.Type, p unsafe.Pointer) { + r := reflect.NewAt(t, p) + + ptr, id := d.readPtr() + if id == 0 { + // nil map + return + } + if ptr != nil { + // already deserialized at ptr + existing := reflect.NewAt(t, ptr).Elem() + r.Elem().Set(existing) + return + } + n := deserializeVarint(d) if n < 0 { // nil map return } nv := reflect.MakeMapWithSize(t, n) - r := reflect.NewAt(t, p) r.Elem().Set(nv) + d.store(id, p) for i := 0; i < n; i++ { k := reflect.New(t.Key()) DeserializeAny(d, t.Key(), k.UnsafePointer()) diff --git a/internal/serde/scan.go b/internal/serde/scan.go index e79b1f1..69a0914 100644 --- a/internal/serde/scan.go +++ b/internal/serde/scan.go @@ -149,6 +149,7 @@ func (c *containers) fixup(i int) { c.remove(i + 1) return } + c.remove(i + 1) // Array fully contains next container. Nothing to do return } diff --git a/serde/serde_test.go b/serde/serde_test.go index b39913c..90057d0 100644 --- a/serde/serde_test.go +++ b/serde/serde_test.go @@ -241,6 +241,32 @@ func TestReflectCustom(t *testing.T) { } func TestReflectSharing(t *testing.T) { + testReflect(t, "maps of ints", func(t *testing.T) { + m := map[int]int{1: 2, 3: 4} + + type X struct { + a map[int]int + b map[int]int + } + + x := X{ + a: m, + b: m, + } + + // make sure map is shared beforehand + x.a[5] = 6 + assertEqual(t, 6, x.b[5]) + + serde.RegisterType[X]() + + out := assertRoundTrip(t, x) + + // check map is shared after + out.a[7] = 8 + assertEqual(t, 8, out.b[7]) + }) + testReflect(t, "slice backing array", func(t *testing.T) { data := make([]int, 10) for i := range data {