From f2e655a49ef6d1d1a809ef01fcf701d0a7d9ade7 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Tue, 19 Sep 2023 19:33:13 -0400 Subject: [PATCH] Rewrite scan --- internal/serde/reflect.go | 39 ++-- internal/serde/scan.go | 371 +++++++++++++++++++++----------------- internal/serde/serde.go | 8 +- internal/serde/typemap.go | 3 + serde/serde_test.go | 45 +++++ 5 files changed, 281 insertions(+), 185 deletions(-) diff --git a/internal/serde/reflect.go b/internal/serde/reflect.go index 846868b..05c340a 100644 --- a/internal/serde/reflect.go +++ b/internal/serde/reflect.go @@ -209,7 +209,7 @@ func DeserializeAny(d *Deserializer, t reflect.Type, p unsafe.Pointer) { } func serializePointedAt(s *Serializer, t reflect.Type, p unsafe.Pointer) { - // fmt.Printf("Serialize pointed at: %d (%s)\n", p, t) + fmt.Printf("Serialize pointed at: %d (%s)\n", p, t) // If this is a nil pointer, write it as such. if p == nil { // fmt.Printf("\t=>NIL\n") @@ -219,9 +219,9 @@ func serializePointedAt(s *Serializer, t reflect.Type, p unsafe.Pointer) { id, new := s.assignPointerID(p) serializeVarint(s, int(id)) - // fmt.Printf("\t=>Assigned ID %d\n", id) + fmt.Printf("\t=>Assigned ID %d\n", id) if !new { - // fmt.Printf("\t=>Already seen\n") + fmt.Printf("\t=>Already seen\n") // This exact pointer has already been serialized. Write its ID // and move on. return @@ -231,30 +231,37 @@ func serializePointedAt(s *Serializer, t reflect.Type, p unsafe.Pointer) { // Now, this is pointer that is seen for the first time. // Check the region of this pointer. - r := s.regions.regionOf(p) + r := s.containers.of(p) // If this pointer does not belong to any region or is the container of // the region, write a negative offset to flag it is on its own, and // write its data. - if !r.valid() || (r.offset(p) == 0 && t == r.typ) { - // fmt.Printf("\t=>Is container (region %t)\n", r.Valid()) + if !r.valid() { + fmt.Printf("\t=>Is standalone\n") serializeVarint(s, -1) SerializeAny(s, t, p) return } // The pointer points into a memory region. - offset := r.offset(p) + offset := int(r.offset(p)) serializeVarint(s, offset) - // fmt.Printf("\t=>Offset in container: %d\n", offset) + fmt.Printf("\t=>Offset in container: %d\n", offset) // Write the type of the container. serializeType(s, r.typ) - // fmt.Printf("\t=>Container at: %d (%s)\n", r.Pointer(), r.Type()) + fmt.Printf("\t=>Container at: %d (%s)\n", r.addr, r.typ) // Serialize the parent. - serializePointedAt(s, r.typ, r.start) + + if offset == 0 { + serializeVarint(s, int(id)) + serializeVarint(s, -1) + SerializeAny(s, r.typ, r.addr) + return + } + serializePointedAt(s, r.typ, r.addr) } func deserializePointedAt(d *Deserializer, t reflect.Type) reflect.Value { @@ -264,17 +271,17 @@ func deserializePointedAt(d *Deserializer, t reflect.Type) reflect.Value { // reflect.Value that contains a *T (where T is given by the argument // t). - // fmt.Printf("Deserialize pointed at: %s\n", t) + fmt.Printf("Deserialize pointed at: %s\n", t) ptr, id := d.readPtr() - // fmt.Printf("\t=> ptr=%d, id=%d\n", ptr, id) + fmt.Printf("\t=> ptr=%d, id=%d\n", ptr, id) if ptr != nil || id == 0 { // pointer already seen or nil - // fmt.Printf("\t=>Returning existing data\n") + fmt.Printf("\t=>Returning existing data\n") return reflect.NewAt(t, ptr) } offset := deserializeVarint(d) - // fmt.Printf("\t=>Read offset %d\n", offset) + fmt.Printf("\t=>Read offset %d\n", offset) // Negative offset means this is either a container or a standalone // value. 
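Note (not part of the patch, an illustrative sketch only): serializePointedAt now encodes an interior pointer as its pointer ID, the byte offset inside the container that owns it, the container's type, and finally the container itself; only pointers that fall outside any known container get the -1 offset followed by their own data. The standalone program below shows the offset arithmetic that underlies this encoding, using only the standard library:

package main

import (
	"fmt"
	"unsafe"
)

func main() {
	// The struct is the "container"; a pointer to one of its fields is an
	// interior pointer and gets serialized as an offset into that container.
	type T struct {
		A int64
		B int32
	}
	var v T

	base := unsafe.Pointer(&v) // container address
	p := unsafe.Pointer(&v.B)  // interior pointer being serialized

	// This offset is what gets written after the pointer ID; the container
	// type and the container's own data follow it.
	offset := uintptr(p) - uintptr(base)
	fmt.Println(offset) // 8
}

On the decoding side, deserializePointedAt reverses this: it materializes the container first, then returns the container's base address plus the stored offset.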
@@ -291,7 +298,7 @@ func deserializePointedAt(d *Deserializer, t reflect.Type) reflect.Value { // then return the pointer itself with an offset. ct := deserializeType(d) - // fmt.Printf("\t=>Container type: %s\n", ct) + fmt.Printf("\t=>Container type: %s\n", ct) // cp is a pointer to the container cp := deserializePointedAt(d, ct) @@ -299,7 +306,7 @@ func deserializePointedAt(d *Deserializer, t reflect.Type) reflect.Value { // Create the pointer with an offset into the container. ep := unsafe.Add(cp.UnsafePointer(), offset) r := reflect.NewAt(t, ep) - d.store(id, ep) + // d.store(id, ep) // fmt.Printf("\t=>Returning id=%d ep=%d\n", id, ep) return r } diff --git a/internal/serde/scan.go b/internal/serde/scan.go index a7f422a..c99a8de 100644 --- a/internal/serde/scan.go +++ b/internal/serde/scan.go @@ -2,203 +2,245 @@ package serde import ( "fmt" + "log" "reflect" "sort" "unsafe" ) -type regions []region +type container struct { + addr unsafe.Pointer + typ reflect.Type +} -func (r *regions) dump() { - fmt.Println("========== MEMORY REGIONS ==========") - fmt.Println("Found", len(*r), "regions.") - for i, r := range *r { - fmt.Printf("#%d: [%d-%d] %d %s\n", i, r.start, r.end, r.size(), r.typ) - } - fmt.Println("====================================") +// Returns true iff at least one byte of the address space is shared between c +// and x (opposite of [disjoints]). +func (c container) overlaps(x container) bool { + return !c.disjoints(x) } -// debug function to ensure the state hold its invariants. panic if they don't. -func (r *regions) validate() { - s := *r - if len(s) == 0 { - return - } +// Returns true iff there is not a single byte of the address space is shared +// between c and x (opposite of [overlaps]). +func (c container) disjoints(x container) bool { + return (uintptr(c.addr)+c.size()) <= uintptr(x.addr) || + (uintptr(x.addr)+x.size()) <= uintptr(c.addr) +} - for i := 0; i < len(s); i++ { - if uintptr(s[i].start) > uintptr(s[i].end) { - panic(fmt.Errorf("region #%d has invalid bounds: start=%d end=%d delta=%d", i, s[i].start, s[i].end, s[i].size())) - } - if s[i].typ == nil { - panic(fmt.Errorf("region #%d has nil type", i)) - } - if i == 0 { - continue - } - if uintptr(s[i].start) <= uintptr(s[i-1].end) { - r.dump() - panic(fmt.Errorf("region #%d and #%d overlap", i-1, i)) - } - } +// Returns true iff x is fully included in c. +func (c container) contains(x container) bool { + return uintptr(x.addr) >= uintptr(c.addr) && uintptr(x.addr)+x.size() <= uintptr(c.addr)+c.size() } -// size computes the amount of bytes coverred by all known regions. -func (r *regions) size() int { - n := 0 - for _, r := range *r { - n += r.size() - } - return n +// Returns true iff c starts before x. +func (c container) before(x container) bool { + return uintptr(c.addr) <= uintptr(x.addr) } -func (r *regions) regionOf(p unsafe.Pointer) region { - // fmt.Printf("Searching regions for %d\n", p) - addr := uintptr(p) - s := *r - if len(s) == 0 { - // fmt.Printf("\t=> No regions\n") - return region{} - } +func (c container) after(x container) bool { + return uintptr(c.addr) > uintptr(x.addr) +} - i := sort.Search(len(s), func(i int) bool { - return uintptr(s[i].start) >= addr - }) - // fmt.Printf("\t=> i = %d\n", i) +// Size in bytes of c. 
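Aside (illustrative only, not part of the patch): the new container type replaces the old inclusive start/end regions with a half-open range [addr, addr+typ.Size()), so overlap, disjointness and containment reduce to plain interval arithmetic. A self-contained sketch of those predicates, mirroring but not reusing the methods introduced here:

package main

import (
	"fmt"
	"reflect"
	"unsafe"
)

// container covers the half-open address range [addr, addr+typ.Size()).
type container struct {
	addr unsafe.Pointer
	typ  reflect.Type
}

func (c container) size() uintptr { return c.typ.Size() }

// disjoints: the ranges share no byte, i.e. one ends at or before the other starts.
func (c container) disjoints(x container) bool {
	return uintptr(c.addr)+c.size() <= uintptr(x.addr) ||
		uintptr(x.addr)+x.size() <= uintptr(c.addr)
}

func (c container) overlaps(x container) bool { return !c.disjoints(x) }

// contains: x's range is fully inside c's range.
func (c container) contains(x container) bool {
	return uintptr(x.addr) >= uintptr(c.addr) &&
		uintptr(x.addr)+x.size() <= uintptr(c.addr)+c.size()
}

func main() {
	var data [10]int64
	whole := container{addr: unsafe.Pointer(&data), typ: reflect.TypeOf(data)}
	// A view over the last three elements of the same array.
	tail := container{addr: unsafe.Pointer(&data[7]), typ: reflect.TypeOf([3]int64{})}

	fmt.Println(whole.overlaps(tail)) // true
	fmt.Println(whole.contains(tail)) // true
	fmt.Println(tail.contains(whole)) // false
}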
+func (c container) size() uintptr { + return c.typ.Size() +} + +func (c container) isStruct() bool { + return c.typ.Kind() == reflect.Struct +} + +func (c container) isArray() bool { + return c.typ.Kind() == reflect.Array +} - if i < len(s) && uintptr(s[i].start) == addr { - return s[i] +func (c container) valid() bool { + return c.typ != nil +} + +func (c container) has(p unsafe.Pointer) bool { + return uintptr(p) >= uintptr(c.addr) && uintptr(p) < (uintptr(c.addr)+c.size()) +} + +func (c container) offset(p unsafe.Pointer) uintptr { + return uintptr(p) - uintptr(c.addr) +} + +func (c container) compare(p unsafe.Pointer) int { + if c.has(p) { + return 0 + } + if uintptr(p) < uintptr(c.addr) { + return -1 } + return 1 +} - if i > 0 { - i-- +func (c container) String() string { + return fmt.Sprintf("[%d-%d[ %d %s", c.addr, uintptr(c.addr)+c.size(), c.size(), c.typ) +} + +type containers []container + +func (c *containers) dump() { + s := *c + log.Printf("====================== CONTAINERS ======================") + log.Printf("Count: %d", len(s)) + for i, x := range s { + log.Printf("#%d: %s", i, x) } - if uintptr(s[i].start) > addr || uintptr(s[i].end) < addr { - return region{} + log.Printf("========================================================") +} + +func (c *containers) of(p unsafe.Pointer) container { + s := *c + i, found := sort.Find(len(s), func(i int) int { + return s[i].compare(p) + }) + if !found { + return container{} } return s[i] - } -func (r *regions) add(t reflect.Type, start unsafe.Pointer) { - size := t.Size() - if size == 0 { +func (c *containers) add(t reflect.Type, p unsafe.Pointer) { + if t.Size() == 0 { return } - end := unsafe.Add(start, size-1) + if p == nil { + panic("tried to add nil pointer") + } + switch t.Kind() { + case reflect.Struct, reflect.Array: + default: + panic(fmt.Errorf("tried to add non struct or array container: %s (%s)", t, t.Kind())) + } - // fmt.Printf("Adding [%d-%d[ %d %s\n", startAddr, endAddr, endAddr-startAddr, t) - startSize := r.size() defer func() { - //r.Dump() - r.validate() - endSize := r.size() - if endSize < startSize { - panic(fmt.Errorf("regions shrunk (%d -> %d)", startSize, endSize)) + r := recover() + if r != nil { + c.dump() + panic(r) } }() - s := *r - - if len(s) == 0 { - *r = append(s, region{ - start: start, - end: end, - typ: t, - }) - return + x := container{addr: p, typ: t} + i := c.insert(x) + c.fixup(i) + if i > 0 { + c.fixup(i - 1) } - // Invariants: - // (1) len(s) > 0 - // (2) s is sorted by start address - // (3) s contains no overlapping range + c.dump() +} - i := sort.Search(len(s), func(i int) bool { - return uintptr(s[i].start) >= uintptr(start) - }) - //fmt.Println("\ti =", i) +func (c *containers) fixup(i int) { + s := *c - if i < len(s) && uintptr(s[i].start) == uintptr(start) { - // Pointer is present in the set. If it's contained in the - // region that already exists, we are done. - if uintptr(s[i].end) >= uintptr(end) { - return - } + log.Println("fixup:", i) - // Otherwise extend the region. - s[i].end = end - s[i].typ = t + if i == len(s)-1 { + return + } - // To maintain invariant (3), keep extending the selected region - // until it becomes the last one or the next range is disjoint. - r.extend(i) + x := s[i] + next := s[i+1] + + if !x.overlaps(next) { + // Not at least an overlap, nothing to do. + log.Println("=> no overlap") return } - // Pointer did not point to the beginning of a region. - // Attempt to grow the previous region. 
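Aside (illustrative only, not part of the patch): containers.of depends on the slice staying sorted by start address, which lets sort.Find locate the container covering an arbitrary interior pointer via the three-way compare defined above. The sketch below uses hypothetical span/of names and plain uintptr values in place of the real container type:

package main

import (
	"fmt"
	"sort"
)

// span covers [start, start+size); the slice passed to of is sorted by start.
type span struct{ start, size uintptr }

// compare orders the target address p against this span: 0 if p is inside,
// negative if p comes before it, positive if p comes after it.
func (s span) compare(p uintptr) int {
	switch {
	case p >= s.start && p < s.start+s.size:
		return 0
	case p < s.start:
		return -1
	default:
		return 1
	}
}

// of returns the span that contains p, if any.
func of(spans []span, p uintptr) (span, bool) {
	i, found := sort.Find(len(spans), func(i int) int { return spans[i].compare(p) })
	if !found {
		return span{}, false
	}
	return spans[i], true
}

func main() {
	spans := []span{{start: 100, size: 24}, {start: 200, size: 8}}
	fmt.Println(of(spans, 110)) // {100 24} true
	fmt.Println(of(spans, 150)) // {0 0} false
}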
- if i > 0 { - if uintptr(start) <= uintptr(s[i-1].end) { - if uintptr(end) >= uintptr(s[i-1].end) { - s[i-1].end = end - r.extend(i - 1) - } + if x.contains(next) { + log.Println("=>contains") + if x.isStruct() { + // Struct fully contains next element. Remove the next + // element and nothing else to do. + c.remove(i + 1) return } + // Array fully contains next container. Nothing to do + return } - // Attempt to grow the next region. - if i+1 < len(s) { - if uintptr(end) >= uintptr(s[i+1].start) { - s[i+1].start = start - if uintptr(end) >= uintptr(s[i+1].end) { - s[i+1].end = end - } - s[i+1].typ = t - r.extend(i + 1) - return - } + // There is some overlap. The only thing we accept to merge are arrays + // of the same type. + if !x.isArray() || !next.isArray() || x.typ.Elem() != next.typ.Elem() { + panic(fmt.Errorf("only support merging arrays of same type (%s, %s)", x.typ, next.typ)) } - // Just insert it. - s = append(s, region{}) - copy(s[i+1:], s[i:]) - s[i] = region{start: start, end: end, typ: t} - *r = s - r.extend(i) + c.merge(i) + + // Do it again in case the merge connected new areas. + c.fixup(i) } -// extend attempts to grow region i by swallowing any region after it, as long -// as it would make one continous region. It is used after a modification of -// region i to maintain the invariants. -func (r *regions) extend(i int) { - s := *r - grown := 0 - for next := i + 1; next < len(s) && uintptr(s[i].end) >= uintptr(s[next].start); next++ { - s[i].end = s[next].end - grown++ +func (c *containers) merge(i int) { + s := *c + a := s[i] + b := s[i+1] + + elemSize := a.typ.Elem().Size() + + // sanity check alignment + if (uintptr(b.addr)-uintptr(a.addr))%uintptr(elemSize) != 0 { + panic("overlapping arrays aren't aligned") } - copy(s[i+1:], s[i+1+grown:]) - *r = s[:len(s)-grown] -} -type region struct { - start unsafe.Pointer // inclusive - end unsafe.Pointer // inclusive - typ reflect.Type -} + // new element count of the array + newlen := int((uintptr(b.addr)-uintptr(a.addr))/elemSize) + b.typ.Len() + s[i].typ = reflect.ArrayOf(newlen, a.typ.Elem()) -func (r region) valid() bool { - return r.typ != nil + c.remove(i + 1) } -func (r region) size() int { - return int(uintptr(r.end)-uintptr(r.start)) + 1 +func (c *containers) remove(i int) { + before := len(*c) + s := *c + copy(s[i:], s[i+1:]) + *c = s[:len(s)-1] + after := len(*c) + if after >= before { + panic("did not remove anything") + } } -func (r region) offset(p unsafe.Pointer) int { - return int(uintptr(p) - uintptr(r.start)) +func (c *containers) insert(x container) int { + log.Print("inserting ", x) + *c = append(*c, container{}) + s := *c + // Find where to insert the new container. By start address first, then + // by decreasing size (so that the bigger container comes before). + i := sort.Search(len(s)-1, func(i int) bool { + if s[i].after(x) { + return true + } + if s[i].addr == x.addr { + return x.size() > s[i].size() + } + return false + }) + fmt.Println("i=", i) + copy(s[i+1:], s[i:]) + s[i] = x + + // Debug assertion. 
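Aside (illustrative only, not part of the patch): merge widens the first of two overlapping arrays so it also covers the second one; the new element count is the distance between their start addresses measured in elements, plus the length of the second array. Worked through with the layout the new test in serde_test.go creates (data[0:3:3] and data[2:8:8] over a shared []int):

package main

import (
	"fmt"
	"reflect"
	"unsafe"
)

func main() {
	// Two overlapping views of the same backing array: a [3]int at element 0
	// and a [6]int starting 2 elements further in.
	data := make([]int, 10)
	a := unsafe.Pointer(&data[0])
	b := unsafe.Pointer(&data[2])
	bType := reflect.ArrayOf(6, reflect.TypeOf(0))

	elemSize := reflect.TypeOf(0).Size()

	// Same sanity check as merge: the two arrays must start a whole number
	// of elements apart.
	if (uintptr(b)-uintptr(a))%elemSize != 0 {
		panic("overlapping arrays aren't aligned")
	}

	// New element count: distance between the starts, in elements, plus the
	// length of the second array.
	newlen := int((uintptr(b)-uintptr(a))/elemSize) + bType.Len()
	merged := reflect.ArrayOf(newlen, reflect.TypeOf(0))
	fmt.Println(merged) // [8]int
}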
+ for i, x := range s { + if i == 0 { + continue + } + if uintptr(x.addr) < uintptr(s[i-1].addr) { + panic("bad address order after insert") + } + if uintptr(x.addr) == uintptr(s[i-1].addr) { + if x.size() > s[i-1].size() { + panic("invalid size order after insert") + } + } + } + + return i } // scan the value of type t at address p recursively to build up the serializer @@ -222,25 +264,8 @@ func scan(s *Serializer, t reflect.Type, p unsafe.Pointer) { switch t.Kind() { case reflect.Invalid: panic("handling invalid reflect.Type") - case reflect.Bool, - reflect.Int, - reflect.Int8, - reflect.Int16, - reflect.Int32, - reflect.Int64, - reflect.Uint, - reflect.Uint8, - reflect.Uint16, - reflect.Uint32, - reflect.Uint64, - reflect.Uintptr, - reflect.Float32, - reflect.Float64, - reflect.Complex64, - reflect.Complex128: - s.regions.add(t, p) case reflect.Array: - s.regions.add(t, p) + s.containers.add(t, p) et := t.Elem() es := int(et.Size()) for i := 0; i < t.Len(); i++ { @@ -259,7 +284,7 @@ func scan(s *Serializer, t reflect.Type, p unsafe.Pointer) { // Create a new type for the backing array. xt := reflect.ArrayOf(sr.Cap(), t.Elem()) - s.regions.add(xt, ep) + s.containers.add(xt, ep) for i := 0; i < sr.Len(); i++ { ep := unsafe.Add(ep, es*i) scan(s, et, ep) @@ -278,7 +303,7 @@ func scan(s *Serializer, t reflect.Type, p unsafe.Pointer) { scan(s, et, eptr) case reflect.Struct: - s.regions.add(t, p) + s.containers.add(t, p) n := t.NumField() for i := 0; i < n; i++ { f := t.Field(i) @@ -293,7 +318,7 @@ func scan(s *Serializer, t reflect.Type, p unsafe.Pointer) { str := *(*string)(p) sp := unsafe.StringData(str) xt := reflect.ArrayOf(len(str), byteT) - s.regions.add(xt, unsafe.Pointer(sp)) + s.containers.add(xt, unsafe.Pointer(sp)) case reflect.Map: m := r.Elem() if m.IsNil() || m.Len() == 0 { @@ -311,7 +336,23 @@ func scan(s *Serializer, t reflect.Type, p unsafe.Pointer) { vp := (*iface)(unsafe.Pointer(&v)).ptr scan(s, vt, vp) } - + case reflect.Bool, + reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64, + reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64, + reflect.Uintptr, + reflect.Float32, + reflect.Float64, + reflect.Complex64, + reflect.Complex128: + // nothing to do default: // TODO: // Chan diff --git a/internal/serde/serde.go b/internal/serde/serde.go index 7ee47f2..76883f0 100644 --- a/internal/serde/serde.go +++ b/internal/serde/serde.go @@ -37,7 +37,7 @@ func Serialize(x any) []byte { // scan dirties s.scanptrs, so clean it up. clear(s.scanptrs) - // s.regions.Dump() + s.containers.dump() SerializeAny(s, t, p) return s.b @@ -80,7 +80,7 @@ func (d *Deserializer) readPtr() (unsafe.Pointer, sID) { func (d *Deserializer) store(i sID, p unsafe.Pointer) { if d.ptrs[i] != nil { - panic(fmt.Errorf("trying to overwirte known ID %d with %p", i, p)) + panic(fmt.Errorf("trying to overwrite known ID %d with %p", i, p)) } d.ptrs[i] = p } @@ -113,8 +113,8 @@ func (d *Deserializer) store(i sID, p unsafe.Pointer) { // shared memory. Only outermost containers are serialized. All pointers either // point to a container, or an offset into that container. type Serializer struct { - ptrs map[unsafe.Pointer]sID - regions regions + ptrs map[unsafe.Pointer]sID + containers containers // TODO: move out. 
just used temporarily by scan scanptrs map[reflect.Value]struct{} diff --git a/internal/serde/typemap.go b/internal/serde/typemap.go index ce8d5b8..bd4cb5c 100644 --- a/internal/serde/typemap.go +++ b/internal/serde/typemap.go @@ -96,6 +96,9 @@ func (m *TypeMap) Add(t reflect.Type) { for i := 0; i < t.NumField(); i++ { m.Add(t.Field(i).Type) } + case reflect.String: + // strings are presented as [X]byte + m.Add(byteT) } } diff --git a/serde/serde_test.go b/serde/serde_test.go index cf58894..b39913c 100644 --- a/serde/serde_test.go +++ b/serde/serde_test.go @@ -286,6 +286,51 @@ func TestReflectSharing(t *testing.T) { assertEqual(t, 11, out.s3[0]) }) + testReflect(t, "slice backing array with set capacities", func(t *testing.T) { + data := make([]int, 10) + for i := range data { + data[i] = i + } + + type X struct { + s1 []int + s2 []int + s3 []int + } + + orig := X{ + s1: data[0:3:3], + s2: data[2:8:8], + s3: data[7:10:10], + } + assertEqual(t, []int{0, 1, 2}, orig.s1) + assertEqual(t, []int{2, 3, 4, 5, 6, 7}, orig.s2) + assertEqual(t, []int{7, 8, 9}, orig.s3) + + assertEqual(t, 3, cap(orig.s1)) + assertEqual(t, 3, len(orig.s1)) + assertEqual(t, 6, cap(orig.s2)) + assertEqual(t, 6, len(orig.s2)) + assertEqual(t, 3, cap(orig.s3)) + assertEqual(t, 3, len(orig.s3)) + + serde.RegisterType[X]() + + out := assertRoundTrip(t, orig) + + // verify that the initial arrays were shared + orig.s1[2] = 42 + assertEqual(t, 42, orig.s2[0]) + orig.s2[5] = 11 + assertEqual(t, 11, orig.s3[0]) + + // verify the result's underlying array is shared + out.s1[2] = 42 + assertEqual(t, 42, out.s2[0]) + out.s2[5] = 11 + assertEqual(t, 11, out.s3[0]) + }) + testReflect(t, "struct fields extra pointers", func(t *testing.T) { type A struct { X, Y int