From b3c3c69cdeb597d8d3ce1d3bfad7e40eefd147a6 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 18 Sep 2023 14:54:49 -0700 Subject: [PATCH] generate serializable function Signed-off-by: Achille Roussel --- compiler/Makefile | 2 +- compiler/compile.go | 57 ++++--- compiler/coroutine_test.go | 14 +- compiler/function.go | 85 ++++++++++ compiler/testdata/coroutine.go | 10 ++ compiler/testdata/coroutine_durable.go | 200 +++++++++++++++-------- compiler/testdata/coroutine_functypes.go | 28 ++++ compiler/types.go | 25 +++ internal/serde/reflect.go | 53 ++++-- internal/serde/serde.go | 12 +- types/func.go | 15 ++ 11 files changed, 394 insertions(+), 107 deletions(-) create mode 100644 compiler/function.go create mode 100644 compiler/testdata/coroutine_functypes.go diff --git a/compiler/Makefile b/compiler/Makefile index a1bd37b..9fb85e8 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -1,6 +1,6 @@ GO ?= go -TARGET = testdata/coroutine_durable.go +TARGET = testdata/coroutine_durable.go testdata/coroutine_functypes.go test: clean generate $(GO) test ./... diff --git a/compiler/compile.go b/compiler/compile.go index 403d0b0..4718b11 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -151,17 +151,42 @@ func (c *compiler) compile(path string) error { } for p, colors := range colorsByPkg { - if err := c.compilePackage(p, colors); err != nil { + if err := c.compilePackage(p, colors, prog); err != nil { return err } } log.Printf("done") - return nil } -func (c *compiler) compilePackage(p *packages.Package, colors functionColors) error { +func (c *compiler) writeFile(path string, file *ast.File) error { + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + // Comments are awkward to attach to the tree (they rely on token.Pos, which + // is coupled to a token.FileSet). Instead, just write out the raw strings. + var b strings.Builder + b.WriteString(`// Code generated by coroc. DO NOT EDIT`) + b.WriteString("\n\n") + if c.buildTags != "" { + b.WriteString(`//go:build `) + b.WriteString(c.buildTags) + b.WriteString("\n\n") + } + if _, err := f.WriteString(b.String()); err != nil { + return err + } + // Format/write the remainder of the AST. + if err := format.Node(f, c.fset, file); err != nil { + return err + } + return f.Close() +} + +func (c *compiler) compilePackage(p *packages.Package, colors functionColors, prog *ssa.Program) error { log.Printf("compiling package %s", p.Name) // Generate the coroutine AST. @@ -210,35 +235,19 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er return err } - // Get ready to write. packageDir := filepath.Dir(p.GoFiles[0]) outputPath := filepath.Join(packageDir, c.outputFilename) - outputFile, err := os.Create(outputPath) - if err != nil { - return fmt.Errorf("os.Create %q: %w", outputPath, err) - } - defer outputFile.Close() - - // Comments are awkward to attach to the tree (they rely on token.Pos, which - // is coupled to a token.FileSet). Instead, just write out the raw strings. - var b strings.Builder - b.WriteString(`// Code generated by coroc. DO NOT EDIT`) - b.WriteString("\n\n") - if c.buildTags != "" { - b.WriteString(`//go:build `) - b.WriteString(c.buildTags) - b.WriteString("\n\n") - } - if _, err := outputFile.WriteString(b.String()); err != nil { + if err := c.writeFile(outputPath, gen); err != nil { return err } - // Format/write the remainder of the AST. - if err := format.Node(outputFile, c.fset, gen); err != nil { + functypesFile := generateFunctypes(prog.Package(p.Types)) + functypesPath := filepath.Join(packageDir, "coroutine_functypes.go") + if err := c.writeFile(functypesPath, functypesFile); err != nil { return err } - return outputFile.Close() + return nil } func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color *types.Signature) *ast.FuncDecl { diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index ee5c3d6..b5acfcd 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -1,7 +1,6 @@ package compiler import ( - "reflect" "slices" "testing" @@ -100,6 +99,12 @@ func TestCoroutineYield(t *testing.T) { yields: []int{0, 5, 5, 50, 5, 4, 3, 2, 1, 0}, }, + { + name: "range over function", + coro: func() { Range(10, Double) }, + yields: []int{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, + }, + { name: "select", coro: func() { Select(8) }, @@ -107,14 +112,13 @@ func TestCoroutineYield(t *testing.T) { }, } - // TODO: remove me - // // This emulates the installation of function type information by the - // compiler until we have codegen for it. + // compiler because we are not doing codegen for the test files in this + // package. for _, test := range tests { a := types.FuncAddr(test.coro) f := types.FuncByAddr(a) - f.Type = reflect.TypeOf(func() {}) + types.RegisterFunc[func()](f.Name) } for _, test := range tests { diff --git a/compiler/function.go b/compiler/function.go new file mode 100644 index 0000000..5340dbd --- /dev/null +++ b/compiler/function.go @@ -0,0 +1,85 @@ +package compiler + +import ( + "cmp" + "go/ast" + "go/token" + "slices" + "strconv" + + "golang.org/x/tools/go/ssa" +) + +func generateFunctypes(pkg *ssa.Package) *ast.File { + var names = make([]string, 0, len(pkg.Members)) + for name := range pkg.Members { + names = append(names, name) + } + slices.Sort(names) + + var init ast.BlockStmt + var path = pkg.Pkg.Path() + for _, name := range names { + if fn, ok := pkg.Members[name].(*ssa.Function); ok { + generateFunctypesInit(path, &init, fn) + } + } + + return &ast.File{ + Name: ast.NewIdent(pkg.Pkg.Name()), + Decls: []ast.Decl{ + &ast.GenDecl{ + Tok: token.IMPORT, + Specs: []ast.Spec{ + &ast.ImportSpec{ + Name: ast.NewIdent("_types"), + Path: &ast.BasicLit{ + Kind: token.STRING, + Value: `"github.com/stealthrocket/coroutine/types"`, + }, + }, + }, + }, + &ast.FuncDecl{ + Name: ast.NewIdent("init"), + Type: &ast.FuncType{}, + Body: &init, + }, + }, + } +} + +func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) { + if fn.TypeParams() != nil { + return // ignore non-instantiated generic functions + } + + init.List = append(init.List, &ast.ExprStmt{ + X: &ast.CallExpr{ + Fun: &ast.IndexListExpr{ + X: &ast.SelectorExpr{ + X: ast.NewIdent("_types"), + Sel: ast.NewIdent("RegisterFunc"), + }, + Indices: []ast.Expr{ + newFuncType(fn.Signature), + }, + }, + Args: []ast.Expr{ + &ast.BasicLit{ + Kind: token.STRING, + Value: strconv.Quote(path + "." + fn.Name()), + }, + }, + }, + }) + + anonFuncs := slices.Clone(fn.AnonFuncs) + slices.SortFunc(anonFuncs, func(f1, f2 *ssa.Function) int { + return cmp.Compare(f1.Name(), f2.Name()) + }) + + for _, anonFunc := range anonFuncs { + generateFunctypesInit(path, init, anonFunc) + } +} diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index 053a40f..2cae21e 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -256,6 +256,16 @@ func RangeOverMaps(n int) { } } +func Range(n int, do func(int)) { + for i := 0; i < n; i++ { + do(i) + } +} + +func Double(n int) { + coroutine.Yield[int, any](2 * n) +} + func Select(n int) { select { default: diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index bc0f703..fb45381 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -1317,6 +1317,54 @@ func RangeOverMaps(n int) { } } +func Range(n int, do func(int)) { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + var _o0 int + if _f.IP > 0 { + n = _f.Get(0).(int) + do = _f.Get(1).(func(int)) + _o0 = _f.Get(2).(int) + } + defer func() { + if _c.Unwinding() { + _f.Set(0, n) + _f.Set(1, do) + _f.Set(2, _o0) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 2: + _o0 = 0 + _f.IP = 2 + fallthrough + case _f.IP < 3: + for ; _o0 < n; _o0, _f.IP = _o0+1, 2 { + do(_o0) + } + } +} + +func Double(n int) { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + if _f.IP > 0 { + n = _f.Get(0).(int) + } + defer func() { + if _c.Unwinding() { + _f.Set(0, n) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + coroutine.Yield[int, any](2 * n) +} + func Select(n int) { _c := coroutine.LoadContext[int, any]() _f, _fp := _c.Push() @@ -1495,14 +1543,11 @@ func Select(n int) { func init() { serde.RegisterType[**byte]() serde.RegisterType[*[100000]uintptr]() - serde.RegisterType[*[1125899906842623]byte]() serde.RegisterType[*[131072]uint16]() serde.RegisterType[*[140737488355327]byte]() serde.RegisterType[*[16]byte]() serde.RegisterType[*[171]uint8]() serde.RegisterType[*[1]uintptr]() - serde.RegisterType[*[268435456]uintptr]() - serde.RegisterType[*[281474976710655]uint32]() serde.RegisterType[*[2]byte]() serde.RegisterType[*[2]float32]() serde.RegisterType[*[2]float64]() @@ -1512,12 +1557,11 @@ func init() { serde.RegisterType[*[32]rune]() serde.RegisterType[*[32]uintptr]() serde.RegisterType[*[4]byte]() - serde.RegisterType[*[562949953421311]uint16]() + serde.RegisterType[*[512]uintptr]() serde.RegisterType[*[65536]uintptr]() serde.RegisterType[*[70368744177663]uint16]() serde.RegisterType[*[8]byte]() serde.RegisterType[*[8]uint8]() - serde.RegisterType[*[]byte]() serde.RegisterType[*[]uint64]() serde.RegisterType[*bool]() serde.RegisterType[*byte]() @@ -1530,80 +1574,77 @@ func init() { serde.RegisterType[*uint64]() serde.RegisterType[*uint8]() serde.RegisterType[*uintptr]() - serde.RegisterType[[0]byte]() - serde.RegisterType[[0]uint8]() serde.RegisterType[[0]uintptr]() serde.RegisterType[[1000]uintptr]() serde.RegisterType[[100]byte]() serde.RegisterType[[1024]bool]() serde.RegisterType[[1024]byte]() + serde.RegisterType[[1024]int8]() serde.RegisterType[[1024]uint8]() serde.RegisterType[[1048576]uint8]() serde.RegisterType[[104]byte]() + serde.RegisterType[[104]int8]() + serde.RegisterType[[107]string]() serde.RegisterType[[108]byte]() - serde.RegisterType[[108]int8]() serde.RegisterType[[10]byte]() serde.RegisterType[[10]string]() serde.RegisterType[[128]byte]() serde.RegisterType[[128]uint64]() serde.RegisterType[[128]uintptr]() serde.RegisterType[[129]uint8]() + serde.RegisterType[[12]int8]() serde.RegisterType[[131072]uintptr]() - serde.RegisterType[[133]string]() serde.RegisterType[[13]int32]() serde.RegisterType[[14]byte]() serde.RegisterType[[14]int8]() - serde.RegisterType[[15]uint64]() serde.RegisterType[[16384]byte]() serde.RegisterType[[16384]uint8]() serde.RegisterType[[16]byte]() - serde.RegisterType[[16]int64]() + serde.RegisterType[[16]int8]() serde.RegisterType[[16]uint64]() + serde.RegisterType[[16]uintptr]() serde.RegisterType[[17]string]() serde.RegisterType[[1]byte]() + serde.RegisterType[[1]uint32]() serde.RegisterType[[1]uint64]() serde.RegisterType[[1]uint8]() serde.RegisterType[[1]uintptr]() serde.RegisterType[[20]byte]() + serde.RegisterType[[20]uint8]() serde.RegisterType[[21]byte]() - serde.RegisterType[[23]uint64]() serde.RegisterType[[249]uint8]() serde.RegisterType[[24]byte]() - serde.RegisterType[[24]uint32]() serde.RegisterType[[252]uintptr]() serde.RegisterType[[253]uintptr]() - serde.RegisterType[[256]int8]() serde.RegisterType[[256]uint64]() + serde.RegisterType[[29]uint64]() serde.RegisterType[[2]byte]() serde.RegisterType[[2]int]() serde.RegisterType[[2]int32]() + serde.RegisterType[[2]int64]() + serde.RegisterType[[2]uint32]() serde.RegisterType[[2]uint64]() serde.RegisterType[[2]uintptr]() serde.RegisterType[[32]byte]() + serde.RegisterType[[32]int32]() serde.RegisterType[[32]string]() - serde.RegisterType[[32]uint8]() + serde.RegisterType[[32]uint32]() serde.RegisterType[[32]uintptr]() serde.RegisterType[[33]float64]() serde.RegisterType[[3]byte]() serde.RegisterType[[3]int]() - serde.RegisterType[[3]int64]() - serde.RegisterType[[3]uint16]() - serde.RegisterType[[3]uint32]() - serde.RegisterType[[3]uint64]() serde.RegisterType[[4096]byte]() - serde.RegisterType[[40]byte]() - serde.RegisterType[[44]byte]() + serde.RegisterType[[40]int8]() serde.RegisterType[[4]byte]() serde.RegisterType[[4]float64]() - serde.RegisterType[[4]int64]() serde.RegisterType[[4]string]() - serde.RegisterType[[4]uint16]() serde.RegisterType[[4]uint32]() serde.RegisterType[[4]uint64]() serde.RegisterType[[4]uintptr]() serde.RegisterType[[50]uintptr]() serde.RegisterType[[512]byte]() serde.RegisterType[[512]uintptr]() + serde.RegisterType[[56]int8]() serde.RegisterType[[5]byte]() serde.RegisterType[[5]uint]() serde.RegisterType[[61]struct { @@ -1613,11 +1654,9 @@ func init() { }]() serde.RegisterType[[64488]byte]() serde.RegisterType[[64]byte]() + serde.RegisterType[[64]uint64]() serde.RegisterType[[64]uintptr]() serde.RegisterType[[65528]byte]() - serde.RegisterType[[65]int8]() - serde.RegisterType[[65]uint32]() - serde.RegisterType[[65]uintptr]() serde.RegisterType[[68]struct { Size uint32 Mallocs uint64 @@ -1627,20 +1666,18 @@ func init() { serde.RegisterType[[68]uint32]() serde.RegisterType[[68]uint64]() serde.RegisterType[[68]uint8]() - serde.RegisterType[[6]byte]() serde.RegisterType[[6]int]() - serde.RegisterType[[6]int8]() serde.RegisterType[[6]uintptr]() - serde.RegisterType[[8192]byte]() + serde.RegisterType[[7]uint64]() + serde.RegisterType[[88]byte]() serde.RegisterType[[8]byte]() + serde.RegisterType[[8]int8]() serde.RegisterType[[8]string]() serde.RegisterType[[8]uint32]() - serde.RegisterType[[8]uint64]() serde.RegisterType[[8]uint8]() + serde.RegisterType[[92]int8]() serde.RegisterType[[96]byte]() - serde.RegisterType[[96]int8]() serde.RegisterType[[9]string]() - serde.RegisterType[[9]uintptr]() serde.RegisterType[[]*byte]() serde.RegisterType[[][]int32]() serde.RegisterType[[]byte]() @@ -1706,10 +1743,53 @@ func init() { needed bool alignme uint64 }]() + serde.RegisterType[struct { + fd int32 + cmd int32 + arg int32 + ret int32 + errno int32 + }]() serde.RegisterType[struct { fill uint64 capacity uint64 }]() + serde.RegisterType[struct { + fn uintptr + a1 uintptr + a2 uintptr + a3 uintptr + a4 uintptr + a5 uintptr + a6 uintptr + r1 uintptr + r2 uintptr + err uintptr + }]() + serde.RegisterType[struct { + fn uintptr + a1 uintptr + a2 uintptr + a3 uintptr + a4 uintptr + a5 uintptr + f1 float64 + r1 uintptr + }]() + serde.RegisterType[struct { + fn uintptr + a1 uintptr + a2 uintptr + a3 uintptr + r1 uintptr + r2 uintptr + err uintptr + }]() + serde.RegisterType[struct { + t int64 + numer uint32 + denom uint32 + }]() serde.RegisterType[struct { tick uint64 i int @@ -1722,73 +1802,65 @@ func init() { serde.RegisterType[sync.Pool]() serde.RegisterType[sync.RWMutex]() serde.RegisterType[sync.WaitGroup]() + serde.RegisterType[syscall.BpfHdr]() + serde.RegisterType[syscall.BpfInsn]() + serde.RegisterType[syscall.BpfProgram]() + serde.RegisterType[syscall.BpfStat]() + serde.RegisterType[syscall.BpfVersion]() serde.RegisterType[syscall.Cmsghdr]() serde.RegisterType[syscall.Credential]() serde.RegisterType[syscall.Dirent]() - serde.RegisterType[syscall.EpollEvent]() serde.RegisterType[syscall.Errno]() + serde.RegisterType[syscall.Fbootstraptransfer_t]() serde.RegisterType[syscall.FdSet]() serde.RegisterType[syscall.Flock_t]() serde.RegisterType[syscall.Fsid]() + serde.RegisterType[syscall.Fstore_t]() serde.RegisterType[syscall.ICMPv6Filter]() serde.RegisterType[syscall.IPMreq]() - serde.RegisterType[syscall.IPMreqn]() serde.RegisterType[syscall.IPv6MTUInfo]() serde.RegisterType[syscall.IPv6Mreq]() - serde.RegisterType[syscall.IfAddrmsg]() - serde.RegisterType[syscall.IfInfomsg]() + serde.RegisterType[syscall.IfData]() + serde.RegisterType[syscall.IfMsghdr]() + serde.RegisterType[syscall.IfaMsghdr]() + serde.RegisterType[syscall.IfmaMsghdr]() + serde.RegisterType[syscall.IfmaMsghdr2]() serde.RegisterType[syscall.Inet4Pktinfo]() serde.RegisterType[syscall.Inet6Pktinfo]() - serde.RegisterType[syscall.InotifyEvent]() + serde.RegisterType[syscall.InterfaceAddrMessage]() + serde.RegisterType[syscall.InterfaceMessage]() + serde.RegisterType[syscall.InterfaceMulticastAddrMessage]() serde.RegisterType[syscall.Iovec]() + serde.RegisterType[syscall.Kevent_t]() serde.RegisterType[syscall.Linger]() + serde.RegisterType[syscall.Log2phys_t]() serde.RegisterType[syscall.Msghdr]() - serde.RegisterType[syscall.NetlinkMessage]() - serde.RegisterType[syscall.NetlinkRouteAttr]() - serde.RegisterType[syscall.NetlinkRouteRequest]() - serde.RegisterType[syscall.NlAttr]() - serde.RegisterType[syscall.NlMsgerr]() - serde.RegisterType[syscall.NlMsghdr]() serde.RegisterType[syscall.ProcAttr]() - serde.RegisterType[syscall.PtraceRegs]() + serde.RegisterType[syscall.Radvisory_t]() serde.RegisterType[syscall.RawSockaddr]() serde.RegisterType[syscall.RawSockaddrAny]() + serde.RegisterType[syscall.RawSockaddrDatalink]() serde.RegisterType[syscall.RawSockaddrInet4]() serde.RegisterType[syscall.RawSockaddrInet6]() - serde.RegisterType[syscall.RawSockaddrLinklayer]() - serde.RegisterType[syscall.RawSockaddrNetlink]() serde.RegisterType[syscall.RawSockaddrUnix]() serde.RegisterType[syscall.Rlimit]() - serde.RegisterType[syscall.RtAttr]() - serde.RegisterType[syscall.RtGenmsg]() - serde.RegisterType[syscall.RtMsg]() - serde.RegisterType[syscall.RtNexthop]() + serde.RegisterType[syscall.RouteMessage]() + serde.RegisterType[syscall.RtMetrics]() + serde.RegisterType[syscall.RtMsghdr]() serde.RegisterType[syscall.Rusage]() serde.RegisterType[syscall.Signal]() - serde.RegisterType[syscall.SockFilter]() - serde.RegisterType[syscall.SockFprog]() + serde.RegisterType[syscall.SockaddrDatalink]() serde.RegisterType[syscall.SockaddrInet4]() serde.RegisterType[syscall.SockaddrInet6]() - serde.RegisterType[syscall.SockaddrLinklayer]() - serde.RegisterType[syscall.SockaddrNetlink]() serde.RegisterType[syscall.SockaddrUnix]() serde.RegisterType[syscall.SocketControlMessage]() serde.RegisterType[syscall.Stat_t]() serde.RegisterType[syscall.Statfs_t]() serde.RegisterType[syscall.SysProcAttr]() - serde.RegisterType[syscall.SysProcIDMap]() - serde.RegisterType[syscall.Sysinfo_t]() - serde.RegisterType[syscall.TCPInfo]() serde.RegisterType[syscall.Termios]() - serde.RegisterType[syscall.Time_t]() serde.RegisterType[syscall.Timespec]() serde.RegisterType[syscall.Timeval]() - serde.RegisterType[syscall.Timex]() - serde.RegisterType[syscall.Tms]() - serde.RegisterType[syscall.Ucred]() - serde.RegisterType[syscall.Ustat_t]() - serde.RegisterType[syscall.Utimbuf]() - serde.RegisterType[syscall.Utsname]() + serde.RegisterType[syscall.Timeval32]() serde.RegisterType[syscall.WaitStatus]() serde.RegisterType[time.Duration]() serde.RegisterType[time.Location]() diff --git a/compiler/testdata/coroutine_functypes.go b/compiler/testdata/coroutine_functypes.go new file mode 100644 index 0000000..ed2760b --- /dev/null +++ b/compiler/testdata/coroutine_functypes.go @@ -0,0 +1,28 @@ +// Code generated by coroc. DO NOT EDIT + +//go:build durable + +package testdata + +import _types "github.com/stealthrocket/coroutine/types" + +func init() { + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzIfGenerator") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzSwitchGenerator") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.Identity") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops") + _types.RegisterFunc[func(int, func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.Range") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeArrayIndexValueGenerator") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeOverMaps") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeSliceIndexGenerator") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.Select") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.Shadowing") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGenerator") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwice") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwiceLoop") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.TypeSwitchingGenerator") + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.init") +} diff --git a/compiler/types.go b/compiler/types.go index de448c3..91b42a1 100644 --- a/compiler/types.go +++ b/compiler/types.go @@ -43,6 +43,31 @@ func typeExpr(typ types.Type) ast.Expr { if t.Empty() { return ast.NewIdent("any") } + case *types.Signature: + return newFuncType(t) } panic(fmt.Sprintf("not implemented: %T", typ)) } + +func newFuncType(signature *types.Signature) *ast.FuncType { + return &ast.FuncType{ + Params: newFieldList(signature.Params()), + Results: newFieldList(signature.Results()), + } +} + +func newFieldList(tuple *types.Tuple) *ast.FieldList { + return &ast.FieldList{ + List: newFields(tuple), + } +} + +func newFields(tuple *types.Tuple) []*ast.Field { + fields := make([]*ast.Field, tuple.Len()) + for i := range fields { + fields[i] = &ast.Field{ + Type: typeExpr(tuple.At(i).Type()), + } + } + return fields +} diff --git a/internal/serde/reflect.go b/internal/serde/reflect.go index 0080a71..846868b 100644 --- a/internal/serde/reflect.go +++ b/internal/serde/reflect.go @@ -24,14 +24,28 @@ func serializeType(s *Serializer, t reflect.Type) { return } - if t.Kind() != reflect.Array { + switch t.Kind() { + case reflect.Array: + serializeVarint(s, -1) + serializeVarint(s, t.Len()) + serializeType(s, t.Elem()) + + case reflect.Func: // TODO: generics + numIn, numOut, variadic := t.NumIn(), t.NumOut(), t.IsVariadic() + serializeVarint(s, -2) + serializeVarint(s, numIn) + serializeVarint(s, numOut) + serializeBool(s, variadic) + for i := 0; i < numIn; i++ { + serializeType(s, t.In(i)) + } + for i := 0; i < numOut; i++ { + serializeType(s, t.Out(i)) + } + + default: serializeVarint(s, int(Types.idOf(t))) - return } - - serializeVarint(s, -1) - serializeVarint(s, t.Len()) - serializeType(s, t.Elem()) } func deserializeType(d *Deserializer) reflect.Type { @@ -44,13 +58,32 @@ func deserializeType(d *Deserializer) reflect.Type { return Types.typeOf(sID(n)) } - if n != -1 { + switch n { + case -1: + l := deserializeVarint(d) + et := deserializeType(d) + return reflect.ArrayOf(l, et) + + case -2: + numIn := deserializeVarint(d) + numOut := deserializeVarint(d) + variadic := deserializeBool(d) + in := deserializeTypes(d, numIn) + out := deserializeTypes(d, numOut) + return reflect.FuncOf(in, out, variadic) + + default: panic(fmt.Errorf("unknown type first int: %d", n)) } - l := deserializeVarint(d) - et := deserializeType(d) - return reflect.ArrayOf(l, et) +} + +func deserializeTypes(d *Deserializer, n int) []reflect.Type { + types := make([]reflect.Type, n) + for i := range types { + types[i] = deserializeType(d) + } + return types } func SerializeAny(s *Serializer, t reflect.Type, p unsafe.Pointer) { diff --git a/internal/serde/serde.go b/internal/serde/serde.go index 913f1fd..7ee47f2 100644 --- a/internal/serde/serde.go +++ b/internal/serde/serde.go @@ -8,7 +8,6 @@ package serde import ( "encoding/binary" "fmt" - "log/slog" "reflect" "unsafe" ) @@ -76,8 +75,6 @@ func (d *Deserializer) readPtr() (unsafe.Pointer, sID) { d.b = d.b[n:] i := sID(x) p := d.ptrs[i] - - slog.Debug("Deserializer ReadPtr", "i", i, "p", p, "n", n) return p, i } @@ -152,3 +149,12 @@ func deserializeVarint(d *Deserializer) int { d.b = d.b[n:] return int(l) } + +func serializeBool(s *Serializer, v bool) { + SerializeBool(s, v) +} + +func deserializeBool(d *Deserializer) (v bool) { + DeserializeBool(d, &v) + return +} diff --git a/types/func.go b/types/func.go index 86791d0..d1441a4 100644 --- a/types/func.go +++ b/types/func.go @@ -44,6 +44,21 @@ type Func struct { Closure reflect.Type } +// RegisterFunc is a helper function used to register function types. The type +// parameter must be a function type, but no compile nor runtime checks are used +// to enforce it; passing anything other than a function type will likely result +// in panics later on when the program attempts to serialize the function value. +// +// The name argument is a unique identifier of the Go symbol that represents the +// function, which has the package path as prefix, and the dot-separated sequence +// identifying the function in the package. +func RegisterFunc[Type any](name string) { + if f := FuncByName(name); f != nil { + var signature Type + f.Type = reflect.TypeOf(signature) + } +} + // Go function values are pointers to an object starting with the function // address, whether they are referencing top-level functions or closures. //