From 31582b0c2a3b079da175aa380d57d2b48481c1e3 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 18 Sep 2023 21:53:50 -0700 Subject: [PATCH] serializable anonymous functions Signed-off-by: Achille Roussel --- compiler/compile.go | 14 +++++--- compiler/coroutine_test.go | 6 ++++ compiler/function.go | 32 ++++++++++++++---- compiler/testdata/coroutine.go | 6 ++++ compiler/testdata/coroutine_durable.go | 19 +++++++++++ compiler/testdata/coroutine_functypes.go | 2 ++ compiler/unsupported.go | 43 +----------------------- 7 files changed, 68 insertions(+), 54 deletions(-) diff --git a/compiler/compile.go b/compiler/compile.go index 4718b11..01e6d41 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -202,11 +202,14 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr }, }) - colorsByDecl := map[*ast.FuncDecl]*types.Signature{} + colorsByDecl := map[ast.Node]*types.Signature{} for fn, color := range colors { - decl, ok := fn.Syntax().(*ast.FuncDecl) - if !ok { - return fmt.Errorf("unsupported yield function %s (Syntax is %T, not *ast.FuncDecl)", fn, fn.Syntax()) + decl := fn.Syntax() + switch decl.(type) { + case *ast.FuncDecl: + case *ast.FuncLit: + default: + return fmt.Errorf("unsupported yield function %s (Syntax is %T, not *ast.FuncDecl or *ast.FuncLit)", fn, decl) } colorsByDecl[decl] = color } @@ -216,6 +219,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr if !ok { continue } + color, ok := colorsByDecl[decl] if !ok { continue @@ -241,7 +245,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr return err } - functypesFile := generateFunctypes(prog.Package(p.Types)) + functypesFile := generateFunctypes(prog.Package(p.Types), colors) functypesPath := filepath.Join(packageDir, "coroutine_functypes.go") if err := c.writeFile(functypesPath, functypesFile); err != nil { return err diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index b5acfcd..b3e605e 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -105,6 +105,12 @@ func TestCoroutineYield(t *testing.T) { yields: []int{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, }, + { + name: "range over anonymous function", + coro: func() { RangeTriple(4) }, + yields: []int{0, 3, 6, 9}, + }, + { name: "select", coro: func() { Select(8) }, diff --git a/compiler/function.go b/compiler/function.go index 5340dbd..497c54c 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -2,6 +2,7 @@ package compiler import ( "cmp" + "fmt" "go/ast" "go/token" "slices" @@ -10,7 +11,7 @@ import ( "golang.org/x/tools/go/ssa" ) -func generateFunctypes(pkg *ssa.Package) *ast.File { +func generateFunctypes(pkg *ssa.Package, colors functionColors) *ast.File { var names = make([]string, 0, len(pkg.Members)) for name := range pkg.Members { names = append(names, name) @@ -18,10 +19,10 @@ func generateFunctypes(pkg *ssa.Package) *ast.File { slices.Sort(names) var init ast.BlockStmt - var path = pkg.Pkg.Path() for _, name := range names { if fn, ok := pkg.Members[name].(*ssa.Function); ok { - generateFunctypesInit(path, &init, fn) + name := pkg.Pkg.Path() + "." + fn.Name() + generateFunctypesInit(pkg, fn, &init, name, colors) } } @@ -49,7 +50,7 @@ func generateFunctypes(pkg *ssa.Package) *ast.File { } } -func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) { +func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockStmt, name string, colors functionColors) { if fn.TypeParams() != nil { return // ignore non-instantiated generic functions } @@ -68,7 +69,7 @@ func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) { Args: []ast.Expr{ &ast.BasicLit{ Kind: token.STRING, - Value: strconv.Quote(path + "." + fn.Name()), + Value: strconv.Quote(name), }, }, }, @@ -79,7 +80,24 @@ func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) { return cmp.Compare(f1.Name(), f2.Name()) }) - for _, anonFunc := range anonFuncs { - generateFunctypesInit(path, init, anonFunc) + for index, anonFunc := range anonFuncs { + _, colored := colors[anonFunc] + if colored { + // Colored functions (those rewritten into coroutines) have a + // deferred anonymous function injected at the beginning to perform + // stack unwinding, which takes the ".func1" name. + index++ + } + name = anonFuncLinkName(name, index) + generateFunctypesInit(pkg, anonFunc, init, name, colors) } } + +// This function computes the name that the linker gives to anonymous functions, +// using the base name of their parent function and appending ".func". +// +// The function works with multiple levels of nesting as each level adds another +// ".func" suffix, with the index being local to the parent scope. +func anonFuncLinkName(base string, index int) string { + return fmt.Sprintf("%s.func%d", base, index+1) +} diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index 2cae21e..b6fe42f 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -266,6 +266,12 @@ func Double(n int) { coroutine.Yield[int, any](2 * n) } +func RangeTriple(n int) { + Range(n, func(i int) { + coroutine.Yield[int, any](3 * i) + }) +} + func Select(n int) { select { default: diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index fb45381..dc14a88 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -1365,6 +1365,25 @@ func Double(n int) { coroutine.Yield[int, any](2 * n) } +func RangeTriple(n int) { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + if _f.IP > 0 { + n = _f.Get(0).(int) + } + defer func() { + if _c.Unwinding() { + _f.Set(0, n) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + Range(n, func(i int) { + coroutine.Yield[int, any](3 * i) + }) +} + func Select(n int) { _c := coroutine.LoadContext[int, any]() _f, _fp := _c.Push() diff --git a/compiler/testdata/coroutine_functypes.go b/compiler/testdata/coroutine_functypes.go index ed2760b..c478326 100644 --- a/compiler/testdata/coroutine_functypes.go +++ b/compiler/testdata/coroutine_functypes.go @@ -18,6 +18,8 @@ func init() { _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeArrayIndexValueGenerator") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeOverMaps") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeSliceIndexGenerator") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeTriple") + _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeTriple.func2") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.Select") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.Shadowing") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGenerator") diff --git a/compiler/unsupported.go b/compiler/unsupported.go index 739c25a..f8a8a6e 100644 --- a/compiler/unsupported.go +++ b/compiler/unsupported.go @@ -8,18 +8,9 @@ import ( ) // unsupported checks a function for unsupported language features. -func unsupported(decl *ast.FuncDecl, info *types.Info) (err error) { +func unsupported(decl ast.Node, info *types.Info) (err error) { ast.Inspect(decl, func(node ast.Node) bool { switch nn := node.(type) { - case ast.Expr: - switch nn.(type) { - case *ast.FuncLit: - err = fmt.Errorf("not implemented: func literals") - } - if countFunctionCalls(nn, info) > 1 { - err = fmt.Errorf("not implemented: multiple function calls in an expression") - } - case ast.Stmt: switch n := nn.(type) { // Not yet supported: @@ -83,35 +74,3 @@ func unsupported(decl *ast.FuncDecl, info *types.Info) (err error) { }) return } - -func countFunctionCalls(expr ast.Expr, info *types.Info) (count int) { - ast.Inspect(expr, func(node ast.Node) bool { - c, ok := node.(*ast.CallExpr) - if !ok { - return true - } - switch f := c.Fun.(type) { - case *ast.Ident: - if obj := info.ObjectOf(f); types.Universe.Lookup(f.Name) == obj { - return true // skip builtins - } else if _, ok := obj.(*types.TypeName); ok { - return true // skip type casts - } - case *ast.SelectorExpr: - if x, ok := f.X.(*ast.Ident); ok { - if obj := info.ObjectOf(x); obj != nil { - if pkg, ok := obj.(*types.PkgName); ok { - pkgPath := pkg.Imported().Path() - switch { - case pkgPath == "unsafe": - return true // skip unsafe intrinsics - } - } - } - } - } - count++ - return true - }) - return -}