Skip to content

Commit

Permalink
serializable anonymous functions
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Sep 19, 2023
1 parent a84c23f commit 31582b0
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 54 deletions.
14 changes: 9 additions & 5 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,14 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr
},
})

colorsByDecl := map[*ast.FuncDecl]*types.Signature{}
colorsByDecl := map[ast.Node]*types.Signature{}
for fn, color := range colors {
decl, ok := fn.Syntax().(*ast.FuncDecl)
if !ok {
return fmt.Errorf("unsupported yield function %s (Syntax is %T, not *ast.FuncDecl)", fn, fn.Syntax())
decl := fn.Syntax()
switch decl.(type) {
case *ast.FuncDecl:
case *ast.FuncLit:
default:
return fmt.Errorf("unsupported yield function %s (Syntax is %T, not *ast.FuncDecl or *ast.FuncLit)", fn, decl)
}
colorsByDecl[decl] = color
}
Expand All @@ -216,6 +219,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr
if !ok {
continue
}

color, ok := colorsByDecl[decl]
if !ok {
continue
Expand All @@ -241,7 +245,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr
return err
}

functypesFile := generateFunctypes(prog.Package(p.Types))
functypesFile := generateFunctypes(prog.Package(p.Types), colors)
functypesPath := filepath.Join(packageDir, "coroutine_functypes.go")
if err := c.writeFile(functypesPath, functypesFile); err != nil {
return err
Expand Down
6 changes: 6 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ func TestCoroutineYield(t *testing.T) {
yields: []int{0, 2, 4, 6, 8, 10, 12, 14, 16, 18},
},

{
name: "range over anonymous function",
coro: func() { RangeTriple(4) },
yields: []int{0, 3, 6, 9},
},

{
name: "select",
coro: func() { Select(8) },
Expand Down
32 changes: 25 additions & 7 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package compiler

import (
"cmp"
"fmt"
"go/ast"
"go/token"
"slices"
Expand All @@ -10,18 +11,18 @@ import (
"golang.org/x/tools/go/ssa"
)

func generateFunctypes(pkg *ssa.Package) *ast.File {
func generateFunctypes(pkg *ssa.Package, colors functionColors) *ast.File {
var names = make([]string, 0, len(pkg.Members))
for name := range pkg.Members {
names = append(names, name)
}
slices.Sort(names)

var init ast.BlockStmt
var path = pkg.Pkg.Path()
for _, name := range names {
if fn, ok := pkg.Members[name].(*ssa.Function); ok {
generateFunctypesInit(path, &init, fn)
name := pkg.Pkg.Path() + "." + fn.Name()
generateFunctypesInit(pkg, fn, &init, name, colors)
}
}

Expand Down Expand Up @@ -49,7 +50,7 @@ func generateFunctypes(pkg *ssa.Package) *ast.File {
}
}

func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) {
func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockStmt, name string, colors functionColors) {
if fn.TypeParams() != nil {
return // ignore non-instantiated generic functions
}
Expand All @@ -68,7 +69,7 @@ func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) {
Args: []ast.Expr{
&ast.BasicLit{
Kind: token.STRING,
Value: strconv.Quote(path + "." + fn.Name()),
Value: strconv.Quote(name),
},
},
},
Expand All @@ -79,7 +80,24 @@ func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) {
return cmp.Compare(f1.Name(), f2.Name())
})

for _, anonFunc := range anonFuncs {
generateFunctypesInit(path, init, anonFunc)
for index, anonFunc := range anonFuncs {
_, colored := colors[anonFunc]
if colored {
// Colored functions (those rewritten into coroutines) have a
// deferred anonymous function injected at the beginning to perform
// stack unwinding, which takes the ".func1" name.
index++
}
name = anonFuncLinkName(name, index)
generateFunctypesInit(pkg, anonFunc, init, name, colors)
}
}

// This function computes the name that the linker gives to anonymous functions,
// using the base name of their parent function and appending ".func<index>".
//
// The function works with multiple levels of nesting as each level adds another
// ".func<index>" suffix, with the index being local to the parent scope.
func anonFuncLinkName(base string, index int) string {
return fmt.Sprintf("%s.func%d", base, index+1)
}
6 changes: 6 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,12 @@ func Double(n int) {
coroutine.Yield[int, any](2 * n)
}

func RangeTriple(n int) {
Range(n, func(i int) {
coroutine.Yield[int, any](3 * i)
})
}

func Select(n int) {
select {
default:
Expand Down
19 changes: 19 additions & 0 deletions compiler/testdata/coroutine_durable.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions compiler/testdata/coroutine_functypes.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 1 addition & 42 deletions compiler/unsupported.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,9 @@ import (
)

// unsupported checks a function for unsupported language features.
func unsupported(decl *ast.FuncDecl, info *types.Info) (err error) {
func unsupported(decl ast.Node, info *types.Info) (err error) {
ast.Inspect(decl, func(node ast.Node) bool {
switch nn := node.(type) {
case ast.Expr:
switch nn.(type) {
case *ast.FuncLit:
err = fmt.Errorf("not implemented: func literals")
}
if countFunctionCalls(nn, info) > 1 {
err = fmt.Errorf("not implemented: multiple function calls in an expression")
}

case ast.Stmt:
switch n := nn.(type) {
// Not yet supported:
Expand Down Expand Up @@ -83,35 +74,3 @@ func unsupported(decl *ast.FuncDecl, info *types.Info) (err error) {
})
return
}

func countFunctionCalls(expr ast.Expr, info *types.Info) (count int) {
ast.Inspect(expr, func(node ast.Node) bool {
c, ok := node.(*ast.CallExpr)
if !ok {
return true
}
switch f := c.Fun.(type) {
case *ast.Ident:
if obj := info.ObjectOf(f); types.Universe.Lookup(f.Name) == obj {
return true // skip builtins
} else if _, ok := obj.(*types.TypeName); ok {
return true // skip type casts
}
case *ast.SelectorExpr:
if x, ok := f.X.(*ast.Ident); ok {
if obj := info.ObjectOf(x); obj != nil {
if pkg, ok := obj.(*types.PkgName); ok {
pkgPath := pkg.Imported().Path()
switch {
case pkgPath == "unsafe":
return true // skip unsafe intrinsics
}
}
}
}
}
count++
return true
})
return
}

0 comments on commit 31582b0

Please sign in to comment.