diff --git a/compiler/compile.go b/compiler/compile.go index 6e83310..e58a0f6 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -232,18 +232,6 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr gen := &ast.File{ Name: ast.NewIdent(p.Name), } - gen.Decls = append(gen.Decls, &ast.GenDecl{ - Tok: token.IMPORT, - Specs: []ast.Spec{ - &ast.ImportSpec{ - Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(coroutinePackage)}, - }, - // Add unsafe for unsafe.Sizeof(). - &ast.ImportSpec{ - Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote("unsafe")}, - }, - }, - }) ssaFnsByDecl := map[ast.Node]*ssa.Function{} colorsByDecl := map[ast.Node]*types.Signature{} @@ -293,10 +281,8 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr } } - log.Print("building type register init function") - if err := generateTypesInit(c.fset, gen, p); err != nil { - return err - } + // Find all the required imports for this file. + gen = addImports(p, gen) packageDir := filepath.Dir(p.GoFiles[0]) outputPath := filepath.Join(packageDir, c.outputFilename) @@ -313,6 +299,50 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr return nil } +func addImports(p *packages.Package, gen *ast.File) *ast.File { + imports := map[string]string{} + + ast.Inspect(gen, func(n ast.Node) bool { + switch x := n.(type) { + case *ast.SelectorExpr: + ident, ok := x.X.(*ast.Ident) + if !ok { + break + } + obj := p.TypesInfo.ObjectOf(ident) + pkgname, ok := obj.(*types.PkgName) + if !ok { + break + } + + pkg := pkgname.Imported().Path() + + if existing, ok := imports[ident.Name]; ok && existing != pkg { + fmt.Println("existing:", ident.Name, existing) + fmt.Println("new:", pkg) + panic("conflicting imports") + } + imports[ident.Name] = pkg + } + return true + }) + + importspecs := make([]ast.Spec, 0, len(imports)) + for name, path := range imports { + importspecs = append(importspecs, &ast.ImportSpec{ + Name: ast.NewIdent(name), + Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(path)}, + }) + } + + gen.Decls = append([]ast.Decl{&ast.GenDecl{ + Tok: token.IMPORT, + Specs: importspecs, + }}, gen.Decls...) + + return gen +} + type scope struct { colors map[ast.Node]*types.Signature // Index used to generate unique object identifiers within the scope of a @@ -620,5 +650,6 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body gen.List = append(gen.List, &ast.ReturnStmt{}) } } + return gen } diff --git a/compiler/serde.go b/compiler/serde.go deleted file mode 100644 index 13393f1..0000000 --- a/compiler/serde.go +++ /dev/null @@ -1,346 +0,0 @@ -package compiler - -import ( - "fmt" - "go/ast" - "go/token" - "go/types" - "math" - "path" - "sort" - "strconv" - "strings" - - "golang.org/x/tools/go/ast/astutil" - "golang.org/x/tools/go/packages" -) - -const publicSerdePackage = "github.com/stealthrocket/coroutine/serde" - -// GenerateTypesInit searches pkg for types that serde can handle and append an -// init() function to the provided ast to register them on package load. -func generateTypesInit(fset *token.FileSet, gen *ast.File, pkg *packages.Package) error { - // Prepare the imports map with the imports already in the AST. - w := typesWalker{ - imports: makeImportsMap(fset, gen), - gen: gen, - seenTypes: make(map[string]struct{}), - seenPkgs: make(map[string]struct{}), - from: pkg.Types.Path(), - } - - w.Walk(pkg) - - // Found no type. - if len(w.types) == 0 { - return nil - } - - coropkg := ast.NewIdent(w.addImport(publicSerdePackage)) - - sort.Slice(w.newimports, func(i, j int) bool { - return w.newimports[i][0] < w.newimports[j][0] - }) - - for _, imp := range w.newimports { - added := astutil.AddNamedImport(fset, gen, imp[0], imp[1]) - if !added { - panic(fmt.Errorf(`import '%s' "%s" was supposed to be missing`, imp[0], imp[1])) - } - } - - fun := &ast.FuncDecl{ - Name: ast.NewIdent("init"), - Type: &ast.FuncType{}, - Body: &ast.BlockStmt{}, - } - - sort.Strings(w.types) - for _, t := range w.types { - var typeExpr ast.Expr - pkg, name, found := strings.Cut(t, ".") - if found { - // eg: coroutine.RegisterType[syscall.RtGenmsg]() - typeExpr = &ast.SelectorExpr{ - X: ast.NewIdent(pkg), - Sel: ast.NewIdent(name), - } - } else { - // eg: coroutine.RegisterType[MyStruct]() - typeExpr = ast.NewIdent(pkg) - } - - fun.Body.List = append(fun.Body.List, &ast.ExprStmt{ - X: &ast.CallExpr{ - Fun: &ast.IndexListExpr{ - X: &ast.SelectorExpr{ - X: coropkg, - Sel: ast.NewIdent("RegisterType"), - }, - Indices: []ast.Expr{ - typeExpr, - }, - }, - }, - }) - } - - gen.Decls = append(gen.Decls, fun) - - return nil -} - -type typesWalker struct { - imports importsmap - gen *ast.File - from string - seenTypes map[string]struct{} - seenPkgs map[string]struct{} - - // outputs - types []string - newimports [][2]string // name, path -} - -func (w *typesWalker) Walk(p *packages.Package) { - if _, ok := w.seenPkgs[p.ID]; ok { - return - } - w.seenPkgs[p.ID] = struct{}{} - - // Walk the type definitions in this package. - for _, o := range p.TypesInfo.Defs { - if o == nil { - continue - } - t := o.Type() - if !required(t) || !supported(w.from, t) { - continue - } - - w.add(t) - } - - // Walk type instances from generics in this package. - for _, i := range p.TypesInfo.Instances { - if !required(i.Type) || !supported(w.from, i.Type) { - continue - } - w.add(i.Type) - } - - // Recurse into imports. - for _, i := range p.Imports { - w.Walk(i) - } -} - -// Import the given package path, and return the actual name used in case of -// conflict. -func (w *typesWalker) addImport(path string) string { - new, importname := w.imports.Add(path) - if new { - w.newimports = append(w.newimports, [2]string{importname, path}) - } - return importname -} - -func (w *typesWalker) add(t types.Type) { - typename := types.TypeString(t, nil) - if _, ok := w.seenTypes[typename]; ok { - return - } - - path, name, found := cutLast(typename, ".") - if !found { - name = path - } else if path != w.from { - importname := w.addImport(path) - name = importname + "." + name - } - - w.types = append(w.types, name) - w.seenTypes[typename] = struct{}{} -} - -func makeImportsMap(fset *token.FileSet, gen *ast.File) importsmap { - m := importsmap{} - - importgroups := astutil.Imports(fset, gen) - for _, g := range importgroups { - for _, imp := range g { - path, err := strconv.Unquote(imp.Path.Value) - if err != nil { - panic(fmt.Errorf("package in import not quoted: %w", err)) - } - name := importSpecName(imp) - m.Ensure(name, path) - } - } - - return m -} - -func importSpecName(imp *ast.ImportSpec) string { - if imp.Name != nil { - return imp.Name.Name - } - // guess from the path - path, err := strconv.Unquote(imp.Path.Value) - if err != nil { - panic(fmt.Errorf("package in import not quoted: %w", err)) - } - lastSlash := strings.LastIndex(path, "/") - if lastSlash == -1 { - return path - } - return path[lastSlash+1:] -} - -func supported(from string, t types.Type) bool { - return supportedType(t) && supportedImport(from, t) -} - -func required(t types.Type) bool { - // we only require named types to be registered - _, ok := t.(*types.Named) - return ok -} - -func supportedImport(from string, t types.Type) bool { - typename := types.TypeString(t, nil) - path, name, found := cutLast(typename, ".") - if !found { - name = path - path = "" - } - - if !validPath(path) { - return false - } - if !found { - return true - } - if strings.Contains(path, "internal/") && !strings.HasPrefix(path, from+"/internal/") { - return false - } - if !public(name) { - return false - } - return true -} - -func validPath(s string) bool { - return !strings.ContainsAny(s, "[]<>{}* ") -} - -func supportedType(t types.Type) bool { - switch x := t.(type) { - case *types.Signature: - // don't know how to serialize functions - return false - case *types.Named: - // uninstantiated type parameter - if x.Origin() != t { - return false - } - - tp := x.TypeParams() - if tp != nil { - // Had type parameters at some point. need to check if - // they are instantiated. - if x.TypeArgs().Len() != tp.Len() { - return false - } - } - - return supportedType(t.Underlying()) - case *types.Interface: - // TODO: should this be relaxed? - return false - case *types.TypeParam: - return false - case *types.Chan: - return false - case *types.Pointer: - return supportedType(x.Elem()) - case *types.Array: - return supportedType(x.Elem()) - case *types.Slice: - return supportedType(x.Elem()) - case *types.Map: - return supportedType(x.Elem()) && supportedType(x.Key()) - case *types.Basic: - switch x.Kind() { - case types.UntypedBool, - types.UntypedInt, - types.UntypedRune, - types.UntypedFloat, - types.UntypedComplex, - types.UntypedString, - types.UntypedNil, - types.Invalid: - return false - } - } - return true -} - -type importsmap struct { - byName map[string]string - byPath map[string]string -} - -// Add a new import and assign it an imported name, avoiding clashes. Return -// true if a new import was added. -func (m *importsmap) Add(p string) (bool, string) { - name, ok := m.byPath[p] - if ok { - return false, name - } - if m.byName == nil { - m.byName = make(map[string]string) - m.byPath = make(map[string]string) - } - - original := path.Base(p) - for i := 0; i <= math.MaxInt; i++ { - name := original - if i > 0 { - name = fmt.Sprintf("%s_%d", original, i) - } - _, ok = m.byName[name] - if ok { // name clash - continue - } - m.byName[name] = p - m.byPath[p] = name - return true, name - } - - panic("exhausted suffixes") -} - -// Ensure path is imported, and its import name matches the one provided. Panics -// otherwise. -func (m *importsmap) Ensure(name, path string) { - // Since the goal is to panic, it's ok to potentially add an import with - // the wrong name. - _, importname := m.Add(path) - if importname != name { - panic(fmt.Errorf("import package '%s' is imported as '%s'; expected '%s'", path, importname, name)) - } -} - -func cutLast(s, sep string) (before, after string, found bool) { - i := strings.LastIndex(s, sep) - if i < 0 { - return s, "", false - } - return string(s[:i]), string(s[i+1:]), true -} - -func public(name string) bool { - c := name[0] // want to panic if len is 0 - return c >= 'A' && c <= 'Z' -} diff --git a/compiler/serde_test.go b/compiler/serde_test.go deleted file mode 100644 index adc924c..0000000 --- a/compiler/serde_test.go +++ /dev/null @@ -1,138 +0,0 @@ -package compiler_test - -import ( - "fmt" - "log/slog" - "os" - "testing" - - "github.com/google/go-cmp/cmp" - testdata "github.com/stealthrocket/coroutine/compiler/testdata/serde" - "github.com/stealthrocket/coroutine/internal/serde" -) - -func enableDebugLogs() { - removeTime := func(groups []string, a slog.Attr) slog.Attr { - if a.Key == slog.TimeKey && len(groups) == 0 { - return slog.Attr{} - } - return a - } - - var programLevel = new(slog.LevelVar) - h := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: programLevel, - ReplaceAttr: removeTime, - }) - slog.SetDefault(slog.New(h)) - programLevel.Set(slog.LevelDebug) -} - -func TestStruct1Empty(t *testing.T) { - enableDebugLogs() - - s := testdata.Struct1{} - - roundtripStruct1(t, s) -} - -func TestStruct1Iface(t *testing.T) { - enableDebugLogs() - - for i, s := range []testdata.Struct1{ - {Iface: int(42)}, - {Iface: true}, - {Iface: "hello"}, - {Iface: testdata.Inner{ - A: 111, - B: "test1", - }}, - } { - s := s - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - roundtripStruct1(t, s) - }) - } - -} - -func makes1() testdata.Struct1 { - str := "pointed at" - myint := 999 - myintptr := &myint - - bounce1 := &testdata.Bounce{ - Value: 1, - } - bounce2 := &testdata.Bounce{ - Value: 2, - } - bounce1.Other = bounce2 - bounce2.Other = bounce1 - - return testdata.Struct1{ - Str: "hello", - Int: 42, - Ints: []int64{1, 2, 3}, - - Bool: true, - Uint64: 11, - Uint32: 22, - Uint16: 33, - Uint8: 44, - Int64: -11, - Int32: -22, - Int16: -33, - Int8: -44, - Float32: 42.11, - Float64: 420.110, - Complex64: 42 + 11i, - Complex128: 420 + 110i, - - FooSer: testdata.NewFoo(), - StrPtr: &str, - IntPtr: &myint, - IntPtrPtr: &myintptr, - - InnerV: testdata.Inner{ - A: 53, - B: "test", - }, - - InnerP: &testdata.Inner{ - A: 99, - B: "hello", - }, - - Bounce1: bounce1, - - MapStrStr: map[string]string{"one": "un", "two": "deux", "three": "trois"}, - } -} - -func TestStruct1(t *testing.T) { - enableDebugLogs() - - s := makes1() - roundtripStruct1(t, s) -} - -func roundtripStruct1(t *testing.T, s testdata.Struct1) { - t.Helper() - - serde.RegisterType[testdata.Struct1]() - b := serde.Serialize(s) - s2, b := serde.Deserialize(b) - - opts := []cmp.Option{ - cmp.AllowUnexported(testdata.Foo{}), - } - - if diff := cmp.Diff(s, s2, opts...); diff != "" { - t.Fatalf("mismatch (-want +got):\n%s", diff) - } - - if len(b) > 0 { - t.Fatalf("leftover bytes: %d", len(b)) - } -} diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index c95824e..6d2c738 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -5,14 +5,9 @@ package testdata import ( - "github.com/stealthrocket/coroutine" - serde "github.com/stealthrocket/coroutine/serde" - runtime "runtime" - sync "sync" - atomic "sync/atomic" - syscall "syscall" + coroutine "github.com/stealthrocket/coroutine" time "time" - "unsafe" + unsafe "unsafe" ) func SomeFunctionThatShouldExistInTheCompiledFile() { @@ -3127,97 +3122,3 @@ func b(v int) (_ int) { } return } -func init() { - serde.RegisterType[atomic.Bool]() - serde.RegisterType[atomic.Int32]() - serde.RegisterType[atomic.Int64]() - serde.RegisterType[atomic.Uint32]() - serde.RegisterType[atomic.Uint64]() - serde.RegisterType[atomic.Uintptr]() - serde.RegisterType[atomic.Value]() - serde.RegisterType[runtime.BlockProfileRecord]() - serde.RegisterType[runtime.Frame]() - serde.RegisterType[runtime.Frames]() - serde.RegisterType[runtime.Func]() - serde.RegisterType[runtime.MemProfileRecord]() - serde.RegisterType[runtime.MemStats]() - serde.RegisterType[runtime.PanicNilError]() - serde.RegisterType[runtime.Pinner]() - serde.RegisterType[runtime.StackRecord]() - serde.RegisterType[runtime.TypeAssertionError]() - serde.RegisterType[sync.Cond]() - serde.RegisterType[sync.Map]() - serde.RegisterType[sync.Mutex]() - serde.RegisterType[sync.Once]() - serde.RegisterType[sync.Pool]() - serde.RegisterType[sync.RWMutex]() - serde.RegisterType[sync.WaitGroup]() - serde.RegisterType[syscall.BpfHdr]() - serde.RegisterType[syscall.BpfInsn]() - serde.RegisterType[syscall.BpfProgram]() - serde.RegisterType[syscall.BpfStat]() - serde.RegisterType[syscall.BpfVersion]() - serde.RegisterType[syscall.Cmsghdr]() - serde.RegisterType[syscall.Credential]() - serde.RegisterType[syscall.Dirent]() - serde.RegisterType[syscall.Errno]() - serde.RegisterType[syscall.Fbootstraptransfer_t]() - serde.RegisterType[syscall.FdSet]() - serde.RegisterType[syscall.Flock_t]() - serde.RegisterType[syscall.Fsid]() - serde.RegisterType[syscall.Fstore_t]() - serde.RegisterType[syscall.ICMPv6Filter]() - serde.RegisterType[syscall.IPMreq]() - serde.RegisterType[syscall.IPv6MTUInfo]() - serde.RegisterType[syscall.IPv6Mreq]() - serde.RegisterType[syscall.IfData]() - serde.RegisterType[syscall.IfMsghdr]() - serde.RegisterType[syscall.IfaMsghdr]() - serde.RegisterType[syscall.IfmaMsghdr]() - serde.RegisterType[syscall.IfmaMsghdr2]() - serde.RegisterType[syscall.Inet4Pktinfo]() - serde.RegisterType[syscall.Inet6Pktinfo]() - serde.RegisterType[syscall.InterfaceAddrMessage]() - serde.RegisterType[syscall.InterfaceMessage]() - serde.RegisterType[syscall.InterfaceMulticastAddrMessage]() - serde.RegisterType[syscall.Iovec]() - serde.RegisterType[syscall.Kevent_t]() - serde.RegisterType[syscall.Linger]() - serde.RegisterType[syscall.Log2phys_t]() - serde.RegisterType[syscall.Msghdr]() - serde.RegisterType[syscall.ProcAttr]() - serde.RegisterType[syscall.Radvisory_t]() - serde.RegisterType[syscall.RawSockaddr]() - serde.RegisterType[syscall.RawSockaddrAny]() - serde.RegisterType[syscall.RawSockaddrDatalink]() - serde.RegisterType[syscall.RawSockaddrInet4]() - serde.RegisterType[syscall.RawSockaddrInet6]() - serde.RegisterType[syscall.RawSockaddrUnix]() - serde.RegisterType[syscall.Rlimit]() - serde.RegisterType[syscall.RouteMessage]() - serde.RegisterType[syscall.RtMetrics]() - serde.RegisterType[syscall.RtMsghdr]() - serde.RegisterType[syscall.Rusage]() - serde.RegisterType[syscall.Signal]() - serde.RegisterType[syscall.SockaddrDatalink]() - serde.RegisterType[syscall.SockaddrInet4]() - serde.RegisterType[syscall.SockaddrInet6]() - serde.RegisterType[syscall.SockaddrUnix]() - serde.RegisterType[syscall.SocketControlMessage]() - serde.RegisterType[syscall.Stat_t]() - serde.RegisterType[syscall.Statfs_t]() - serde.RegisterType[syscall.SysProcAttr]() - serde.RegisterType[syscall.Termios]() - serde.RegisterType[syscall.Timespec]() - serde.RegisterType[syscall.Timeval]() - serde.RegisterType[syscall.Timeval32]() - serde.RegisterType[syscall.WaitStatus]() - serde.RegisterType[time.Duration]() - serde.RegisterType[time.Location]() - serde.RegisterType[time.Month]() - serde.RegisterType[time.ParseError]() - serde.RegisterType[time.Ticker]() - serde.RegisterType[time.Time]() - serde.RegisterType[time.Timer]() - serde.RegisterType[time.Weekday]() -} diff --git a/coroutine_durable.go b/coroutine_durable.go index 1310979..dac3f8c 100644 --- a/coroutine_durable.go +++ b/coroutine_durable.go @@ -12,10 +12,6 @@ type serializedCoroutine struct { resume bool } -func init() { - serde.RegisterType[serializedCoroutine]() -} - // Context is passed to a coroutine and flows through all // functions that Yield (or could yield). type Context[R, S any] struct { diff --git a/internal/serde/codec.go b/internal/serde/codec.go index c650bb6..0462e89 100644 --- a/internal/serde/codec.go +++ b/internal/serde/codec.go @@ -6,7 +6,7 @@ import ( ) func init() { - RegisterTypeWithSerde[time.Time](serializeTime, deserializeTime) + RegisterSerde[time.Time](serializeTime, deserializeTime) } func serializeTime(s *Serializer, x *time.Time) error { diff --git a/internal/serde/typemap.go b/internal/serde/typemap.go index 967fc42..98c4528 100644 --- a/internal/serde/typemap.go +++ b/internal/serde/typemap.go @@ -9,66 +9,18 @@ import ( // Global type register. var Types *TypeMap = NewTypeMap() -// RegisterType into the global register to make it known to the serialization -// system. It is only required to register named types. -// -// coroc usually generates calls to this function. It should be called in an -// init function so that types are always registered in the same order. -// -// Named types are recursively added. -func RegisterType[T any]() { - registerType[T](Types) -} - -// Scan T and add all named types to the type map. -func registerType[T any](tm *TypeMap) { - t := reflect.TypeOf((*T)(nil)).Elem() - tm.Add(t) - addNamedTypes(tm, make(set[reflect.Type]), t) -} - -func addNamedTypes(tm *TypeMap, seen set[reflect.Type], t reflect.Type) { - if seen.has(t) { - return - } - seen.add(t) - if named(t) { - tm.Add(t) - } - switch t.Kind() { - case reflect.Struct: - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - addNamedTypes(tm, seen, f.Type) - } - case reflect.Func: - for i := 0; i < t.NumIn(); i++ { - addNamedTypes(tm, seen, t.In(i)) - } - for i := 0; i < t.NumOut(); i++ { - addNamedTypes(tm, seen, t.Out(i)) - } - case reflect.Map: - addNamedTypes(tm, seen, t.Key()) - fallthrough - case reflect.Slice, reflect.Array, reflect.Pointer: - addNamedTypes(tm, seen, t.Elem()) - } -} - -// RegisterTypeWithSerde is the same as [RegisterType] but assigns serialization -// and deserialization for this type. -func RegisterTypeWithSerde[T any]( +// RegisterSerde assigns custom functions to serialize and deserialize a +// specific type. +func RegisterSerde[T any]( serializer func(*Serializer, *T) error, deserializer func(*Deserializer, *T) error) { - registerTypeWithSerde[T](Types, serializer, deserializer) + registerSerde[T](Types, serializer, deserializer) } -func registerTypeWithSerde[T any](tm *TypeMap, +func registerSerde[T any](tm *TypeMap, serializer func(*Serializer, *T) error, deserializer func(*Deserializer, *T) error) { - registerType[T](tm) t := reflect.TypeOf((*T)(nil)).Elem() s := func(s *Serializer, p unsafe.Pointer) { @@ -90,6 +42,7 @@ type SerializerFn func(*Serializer, unsafe.Pointer) type DeserializerFn func(d *Deserializer, p unsafe.Pointer) type serde struct { + id int ser SerializerFn des DeserializerFn } @@ -112,22 +65,15 @@ func (m *TypeMap) Attach(t reflect.Type, ser SerializerFn, des DeserializerFn) { panic("both serializer and deserializer need to be provided") } - _, ok := m.cache.GetK(t) - if !ok { - panic(fmt.Errorf("register type %s before attaching serde", t)) - } - - m.serdes[t] = serde{ser: ser, des: des} -} - -func (m *TypeMap) Add(t reflect.Type) { - if _, ok := m.cache.GetK(t); ok { - return + s, exists := m.serdes[t] + if !exists { + s.id = len(m.custom) + m.custom = append(m.custom, t) } + s.ser = ser + s.des = des - x := &typeinfo{kind: typeCustom, val: len(m.custom)} - m.custom = append(m.custom, t) - m.cache.Add(t, x) + m.serdes[t] = s } func (m *TypeMap) serdeOf(x reflect.Type) (serde, bool) { @@ -150,22 +96,12 @@ func (m *doublemap[K, V]) GetV(v V) (K, bool) { return k, ok } -func (m *doublemap[K, V]) Add(k K, v V) { +func (m *doublemap[K, V]) Add(k K, v V) V { if m.fromK == nil { m.fromK = make(map[K]V) m.fromV = make(map[V]K) } m.fromK[k] = v m.fromV[v] = k -} - -type set[T comparable] map[T]struct{} - -func (s set[T]) has(x T) bool { - _, ok := s[x] - return ok -} - -func (s set[T]) add(x T) { - s[x] = struct{}{} + return v } diff --git a/internal/serde/types.go b/internal/serde/types.go index f3067be..1b9d95f 100644 --- a/internal/serde/types.go +++ b/internal/serde/types.go @@ -25,6 +25,10 @@ const ( // to get right, and we will be revamping serde anyway. type typeinfo struct { kind typekind + + // Only present for named types. See documentation of [namedTypeOffset]. + offset namedTypeOffset + // - typeCustom uses this field to store the index in the typemap of the // custom type it represents. // - typeBasic uses it to store the reflect.Kind it represents. @@ -41,6 +45,10 @@ type typeinfo struct { } func (t *typeinfo) reflectType(tm *TypeMap) reflect.Type { + if t.offset != 0 { + return typeForOffset(t.offset) + } + switch t.kind { case typeNone: return nil @@ -142,13 +150,27 @@ func (m *TypeMap) ToType(t reflect.Type) *typeinfo { } if t == nil { - return &typeinfo{kind: typeNone} + return m.cache.Add(t, &typeinfo{kind: typeNone}) } + var offset namedTypeOffset if named(t) { - panic(fmt.Errorf("named type should be registered (%s)", t)) + offset = offsetForType(t) + // Technically types with an offset do not need more information + // than that. However for debugging purposes also generate the + // rest of the type information. } + if s, ok := m.serdes[t]; ok { + return m.cache.Add(t, &typeinfo{ + kind: typeCustom, + offset: offset, + val: s.id, + }) + } + + ti := &typeinfo{offset: offset} + m.cache.Add(t, ti) // add now for recursion switch t.Kind() { case reflect.Invalid: panic("can't handle reflect.Invalid") @@ -170,40 +192,29 @@ func (m *TypeMap) ToType(t reflect.Type) *typeinfo { reflect.Complex128, reflect.String, reflect.Interface: - return &typeinfo{ - kind: typeBasic, - val: int(t.Kind()), - } + ti.kind = typeBasic + ti.val = int(t.Kind()) case reflect.Array: - return &typeinfo{ - kind: typeArray, - elem: m.ToType(t.Elem()), - val: t.Len(), - } + ti.kind = typeArray + ti.val = t.Len() + ti.elem = m.ToType(t.Elem()) case reflect.Map: - return &typeinfo{ - kind: typeMap, - key: m.ToType(t.Key()), - elem: m.ToType(t.Elem()), - } + ti.kind = typeMap + ti.key = m.ToType(t.Key()) + ti.elem = m.ToType(t.Elem()) case reflect.Pointer: - return &typeinfo{ - kind: typePointer, - elem: m.ToType(t.Elem()), - } + ti.kind = typePointer + ti.elem = m.ToType(t.Elem()) case reflect.Slice: - return &typeinfo{ - kind: typeSlice, - elem: m.ToType(t.Elem()), - } + ti.kind = typeSlice + ti.elem = m.ToType(t.Elem()) case reflect.Struct: n := t.NumField() fields := make([]Field, n) for i := 0; i < n; i++ { f := t.Field(i) - // Unexported fields are not supported. - if !f.IsExported() { - panic(fmt.Errorf("struct with unexported fields should be registered (%s)", t)) + if !f.IsExported() && offset == 0 { + ti.offset = offsetForType(t) } fields[i].name = f.Name fields[i].anon = f.Anonymous @@ -212,10 +223,8 @@ func (m *TypeMap) ToType(t reflect.Type) *typeinfo { fields[i].tag = string(f.Tag) fields[i].typ = m.ToType(f.Type) } - return &typeinfo{ - kind: typeStruct, - fields: fields, - } + ti.kind = typeStruct + ti.fields = fields case reflect.Func: nin := t.NumIn() nout := t.NumOut() @@ -226,14 +235,13 @@ func (m *TypeMap) ToType(t reflect.Type) *typeinfo { for i := 0; i < nout; i++ { types[nin+i] = m.ToType(t.Out(i)) } - return &typeinfo{ - kind: typeFunc, - val: nin<<1 | boolint(t.IsVariadic()), - args: types, - } + ti.kind = typeFunc + ti.val = nin<<1 | boolint(t.IsVariadic()) + ti.args = types default: panic(fmt.Errorf("unsupported reflect.Kind (%s)", t.Kind())) } + return ti } func boolint(x bool) int { diff --git a/internal/serde/unsafe.go b/internal/serde/unsafe.go index c52fb21..d31c1cc 100644 --- a/internal/serde/unsafe.go +++ b/internal/serde/unsafe.go @@ -56,3 +56,23 @@ func staticOffset(p unsafe.Pointer) int { func staticPointer(offset int) unsafe.Pointer { return unsafe.Add(staticuint64s, offset) } + +// namedType offset is the number of bytes from the address of the 'byte' type +// value to the ptr field of a reflect.Type. It is used to roundtrip named types +// for a given version of the program. +type namedTypeOffset int + +func offsetForType(t reflect.Type) namedTypeOffset { + tptr := (*iface)(unsafe.Pointer(&t)).ptr + bptr := (*iface)(unsafe.Pointer(&byteT)).ptr + return namedTypeOffset(uintptr(tptr) - uintptr(bptr)) +} + +func typeForOffset(offset namedTypeOffset) reflect.Type { + biface := (*iface)(unsafe.Pointer(&byteT)) + tiface := &iface{ + typ: biface.typ, + ptr: unsafe.Add(biface.ptr, offset), + } + return *(*reflect.Type)(unsafe.Pointer(tiface)) +} diff --git a/serde/serde_test.go b/serde/serde_test.go index b4176fa..f1cb3d3 100644 --- a/serde/serde_test.go +++ b/serde/serde_test.go @@ -67,7 +67,6 @@ func TestReflect(t *testing.T) { for _, x := range cases { t := reflect.TypeOf(x) - serdeinternal.Types.Add(t) if t.Kind() == reflect.Func { a := types.FuncAddr(x) @@ -99,7 +98,6 @@ func TestInt257(t *testing.T) { true, one, } - serde.RegisterType[[]any]() assertRoundTrip(t, x) } @@ -137,7 +135,7 @@ func TestReflectCustom(t *testing.T) { } testReflect(t, "int wrapper", func(t *testing.T) { - serde.RegisterTypeWithSerde[int](ser, des) + serde.RegisterSerde[int](ser, des) x := 42 p := &x @@ -160,8 +158,7 @@ func TestReflectCustom(t *testing.T) { y Y } - serde.RegisterType[X]() - serde.RegisterTypeWithSerde[int](ser, des) + serde.RegisterSerde[int](ser, des) x := X{ foo: "test", @@ -186,8 +183,7 @@ func TestReflectCustom(t *testing.T) { y *Y } - serde.RegisterType[X]() - serde.RegisterTypeWithSerde[int](ser, des) + serde.RegisterSerde[int](ser, des) x := &X{y: &Y{}} x.y.foo = "test" @@ -202,8 +198,7 @@ func TestReflectCustom(t *testing.T) { }) testReflect(t, "custom type in slice", func(t *testing.T) { - serde.RegisterTypeWithSerde[int](ser, des) - serde.RegisterType[[]int]() + serde.RegisterSerde[int](ser, des) x := []int{1, 2, 3, 42, 5, 6} assertRoundTrip(t, x) b := serdeinternal.Serialize(x) @@ -226,7 +221,7 @@ func TestReflectCustom(t *testing.T) { return nil } - serde.RegisterTypeWithSerde[http.Client](ser, des) + serde.RegisterSerde[http.Client](ser, des) x := http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -268,8 +263,6 @@ func TestReflectSharing(t *testing.T) { x.a[5] = 6 assertEqual(t, 6, x.b[5]) - serde.RegisterType[X]() - out := assertRoundTrip(t, x) // check map is shared after @@ -305,8 +298,6 @@ func TestReflectSharing(t *testing.T) { assertEqual(t, 3, cap(orig.s3)) assertEqual(t, 3, len(orig.s3)) - serde.RegisterType[X]() - out := assertRoundTrip(t, orig) // verify that the initial arrays were shared @@ -350,8 +341,6 @@ func TestReflectSharing(t *testing.T) { assertEqual(t, 3, cap(orig.s3)) assertEqual(t, 3, len(orig.s3)) - serde.RegisterType[X]() - out := assertRoundTrip(t, orig) // verify that the initial arrays were shared @@ -394,7 +383,6 @@ func TestReflectSharing(t *testing.T) { x.A.Y = 42 assertEqual(t, 42, *x.B.P) - serde.RegisterType[X]() out := assertRoundTrip(t, x) // verify the resulting pointer is correct @@ -411,8 +399,6 @@ func TestReflectSharing(t *testing.T) { x.z = x assertEqual(t, x, x.z) - serde.RegisterType[X]() - out := assertRoundTrip(t, x) assertEqual(t, out, out.z) @@ -431,7 +417,6 @@ func TestReflectSharing(t *testing.T) { x := X{Y{Z{42}}} - serde.RegisterType[X]() assertRoundTrip(t, x) }) @@ -455,7 +440,6 @@ func TestReflectSharing(t *testing.T) { }, } - serde.RegisterType[X]() assertRoundTrip(t, x) }) @@ -477,7 +461,6 @@ func TestReflectSharing(t *testing.T) { assertEqual(t, unsafe.Pointer(x), unsafe.Pointer(x.y.z)) - serde.RegisterType[X]() out := assertRoundTrip(t, x) out.z.v = "test" @@ -506,8 +489,6 @@ func TestReflectSharing(t *testing.T) { assertEqual(t, 3, cap(x.s1)) assertEqual(t, 2, cap(x.s2)) - serde.RegisterType[X]() - out := assertRoundTrip(t, x) // check underlying arrays are not shared @@ -527,7 +508,6 @@ func TestReflectSharing(t *testing.T) { "trois": data[0:3], } - serde.RegisterType[map[string][]int]() out := assertRoundTrip(t, x) out["un"][0] = 100 diff --git a/serde/typemap.go b/serde/typemap.go index bd1df9b..c3bd73c 100644 --- a/serde/typemap.go +++ b/serde/typemap.go @@ -10,22 +10,8 @@ type SerializerFn[T any] func(*Serializer, *T) error // DeserializerFn is the signature of customer deserializer functions. type DeserializerFn[T any] func(*Deserializer, *T) error -// RegisterType adds T to the global type register, as well as *T and all -// types contained within. -// -// Types can be registered multiple times, but care should be taken to always -// register them in a deterministic order; init functions are a good place for -// that. Most of the time, coroc takes care of registering types. -func RegisterType[T any]() { - serde.RegisterType[T]() -} - -// RegisterTypeWithSerde adds T to the global type register in the same way -// [RegisterType] does, but also attaches custom serialization and -// deserialization functions to T. -// -// If T already has custom serialization and deserialization functions, -// [RegisterTypeWithSerde] panics. -func RegisterTypeWithSerde[T any](ser SerializerFn[T], des DeserializerFn[T]) { - serde.RegisterTypeWithSerde[T](ser, des) +// RegisterSerde attaches custom serialization and deserialization functions to +// type T. +func RegisterSerde[T any](ser SerializerFn[T], des DeserializerFn[T]) { + serde.RegisterSerde[T](ser, des) }