Skip to content

Commit

Permalink
serializable anonymous functions (#38)
Browse files Browse the repository at this point in the history
This PR adds the ability to have anonymous functions on the call path of
coroutines, including being able to serialize them.

I didn't really do it on purpose but it seems that when the function
literal is injected directly as argument to a function call, it does not
get captured in the coroutine state and remains inlined in the code
(which is optimal I believe?).

I removed checks in `unsupported` which were getting in the way of
supporting function literals, tho I don't fully understand the reasons
why those checks were here in the first place so I could use a review on
that @chriso
  • Loading branch information
achille-roussel authored Sep 19, 2023
2 parents 261eb4e + a9545c3 commit 52df481
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 18 deletions.
14 changes: 9 additions & 5 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,14 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr
},
})

colorsByDecl := map[*ast.FuncDecl]*types.Signature{}
colorsByDecl := map[ast.Node]*types.Signature{}
for fn, color := range colors {
decl, ok := fn.Syntax().(*ast.FuncDecl)
if !ok {
return fmt.Errorf("unsupported yield function %s (Syntax is %T, not *ast.FuncDecl)", fn, fn.Syntax())
decl := fn.Syntax()
switch decl.(type) {
case *ast.FuncDecl:
case *ast.FuncLit:
default:
return fmt.Errorf("unsupported yield function %s (Syntax is %T, not *ast.FuncDecl or *ast.FuncLit)", fn, decl)
}
colorsByDecl[decl] = color
}
Expand All @@ -235,6 +238,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr
if !ok {
continue
}

color, ok := colorsByDecl[decl]
if !ok {
continue
Expand All @@ -260,7 +264,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr
return err
}

functypesFile := generateFunctypes(prog.Package(p.Types))
functypesFile := generateFunctypes(prog.Package(p.Types), colors)
functypesPath := filepath.Join(packageDir, "coroutine_functypes.go")
if err := c.writeFile(functypesPath, functypesFile); err != nil {
return err
Expand Down
14 changes: 14 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,20 @@ func TestCoroutineYield(t *testing.T) {
yields: []int{0, 2, 4, 6, 8, 10, 12, 14, 16, 18},
},

// TODO: desugar function call expressions to enable this test.
//
// {
// name: "range over anonymous function",
// coro: func() { RangeTriple(4) },
// yields: []int{0, 3, 6, 9},
// },

{
name: "range over anonymous function value",
coro: func() { RangeTripleFuncValue(4) },
yields: []int{0, 3, 6, 9},
},

{
name: "select",
coro: func() { Select(8) },
Expand Down
32 changes: 25 additions & 7 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package compiler

import (
"cmp"
"fmt"
"go/ast"
"go/token"
"slices"
Expand All @@ -10,18 +11,18 @@ import (
"golang.org/x/tools/go/ssa"
)

func generateFunctypes(pkg *ssa.Package) *ast.File {
func generateFunctypes(pkg *ssa.Package, colors functionColors) *ast.File {
var names = make([]string, 0, len(pkg.Members))
for name := range pkg.Members {
names = append(names, name)
}
slices.Sort(names)

var init ast.BlockStmt
var path = pkg.Pkg.Path()
for _, name := range names {
if fn, ok := pkg.Members[name].(*ssa.Function); ok {
generateFunctypesInit(path, &init, fn)
name := pkg.Pkg.Path() + "." + fn.Name()
generateFunctypesInit(pkg, fn, &init, name, colors)
}
}

Expand Down Expand Up @@ -49,7 +50,7 @@ func generateFunctypes(pkg *ssa.Package) *ast.File {
}
}

func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) {
func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockStmt, name string, colors functionColors) {
if fn.TypeParams() != nil {
return // ignore non-instantiated generic functions
}
Expand All @@ -68,7 +69,7 @@ func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) {
Args: []ast.Expr{
&ast.BasicLit{
Kind: token.STRING,
Value: strconv.Quote(path + "." + fn.Name()),
Value: strconv.Quote(name),
},
},
},
Expand All @@ -79,7 +80,24 @@ func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) {
return cmp.Compare(f1.Name(), f2.Name())
})

for _, anonFunc := range anonFuncs {
generateFunctypesInit(path, init, anonFunc)
for index, anonFunc := range anonFuncs {
_, colored := colors[anonFunc]
if colored {
// Colored functions (those rewritten into coroutines) have a
// deferred anonymous function injected at the beginning to perform
// stack unwinding, which takes the ".func1" name.
index++
}
name = anonFuncLinkName(name, index)
generateFunctypesInit(pkg, anonFunc, init, name, colors)
}
}

// This function computes the name that the linker gives to anonymous functions,
// using the base name of their parent function and appending ".func<index>".
//
// The function works with multiple levels of nesting as each level adds another
// ".func<index>" suffix, with the index being local to the parent scope.
func anonFuncLinkName(base string, index int) string {
return fmt.Sprintf("%s.func%d", base, index+1)
}
13 changes: 13 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,19 @@ func Double(n int) {
coroutine.Yield[int, any](2 * n)
}

// func RangeTriple(n int) {
// Range(n, func(i int) {
// coroutine.Yield[int, any](3 * i)
// })
// }

func RangeTripleFuncValue(n int) {
f := func(i int) {
coroutine.Yield[int, any](3 * i)
}
Range(n, f)
}

func Select(n int) {
select {
default:
Expand Down
29 changes: 29 additions & 0 deletions compiler/testdata/coroutine_durable.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions compiler/testdata/coroutine_functypes.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 5 additions & 6 deletions compiler/unsupported.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,17 @@ import (
)

// unsupported checks a function for unsupported language features.
func unsupported(decl *ast.FuncDecl, info *types.Info) (err error) {
func unsupported(decl ast.Node, info *types.Info) (err error) {
ast.Inspect(decl, func(node ast.Node) bool {
switch nn := node.(type) {
case ast.Expr:
switch nn.(type) {
case *ast.FuncLit:
err = fmt.Errorf("not implemented: func literals")
}
if countFunctionCalls(nn, info) > 1 {
err = fmt.Errorf("not implemented: multiple function calls in an expression")
default:
if countFunctionCalls(nn, info) > 1 {
err = fmt.Errorf("not implemented: multiple function calls in an expression")
}
}

case ast.Stmt:
switch n := nn.(type) {
// Not yet supported:
Expand Down

0 comments on commit 52df481

Please sign in to comment.