Skip to content

Commit

Permalink
simplify single-expression coroutines
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Sep 21, 2023
1 parent dcc5018 commit acb9842
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 41 deletions.
90 changes: 73 additions & 17 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr
},
})

ssaFnsByDecl := map[ast.Node]*ssa.Function{}
colorsByDecl := map[ast.Node]*types.Signature{}
for fn, color := range colors {
decl := fn.Syntax()
Expand All @@ -231,6 +232,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr
default:
return fmt.Errorf("unsupported yield function %s (Syntax is %T, not *ast.FuncDecl or *ast.FuncLit)", fn, decl)
}
ssaFnsByDecl[decl] = fn
colorsByDecl[decl] = color
}

Expand All @@ -249,11 +251,22 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr
if err := unsupported(decl, p.TypesInfo); err != nil {
return err
}

scope := &scope{
colors: colorsByDecl,
objectIdent: 0,
}
gen.Decls = append(gen.Decls, scope.compileFuncDecl(p, decl, color))

compiledFunction := scope.compileFuncDecl(p, decl, color)
if functionBodyIsExpr(compiledFunction.Body) {
// If the function has a single expression it does not contain
// a deferred closure so we drop it from the list of colored
// functions so generateFunctypes does not mistakenly increment
// the local symbol counter when generating closure names.
delete(colors, ssaFnsByDecl[decl])
}

gen.Decls = append(gen.Decls, compiledFunction)
}
}

Expand Down Expand Up @@ -298,11 +311,56 @@ func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color
// Generate the coroutine function. At this stage, use the same name
// as the source function (and require that the caller use build tags
// to disambiguate function calls).
return &ast.FuncDecl{
gen := &ast.FuncDecl{
Doc: fn.Doc,
Name: fn.Name,
Type: funcTypeWithNamedResults(fn.Type),
Body: scope.compileFuncBody(p, fn.Type, fn.Body, color),
}

if functionBodyIsExpr(gen.Body) {
// If the function declaration contains function literals, we have to
// add the //go:noinline copmiler directive to prevent inlining or the
// resulting symbol name generated by the linker wouldn't match the
// predictions made in generateFunctypes.
//
// When functions are inlined, the linker creates a unique name
// combining the symbol name of the calling function and the symbol name
// of the closure. Knowing which functions will be inlined is difficult
// considering the score-base mechansim that Go uses and alterations
// like PGO, therefore we take the simple approach of disabling inlining
// instead.
//
// Note that we only need to do this for single-expression functions as
// otherwise the presence of a defer statement to unwind the coroutine
// already prevents inlining.
hasFuncLit := false

ast.Inspect(fn, func(n ast.Node) bool {
switch n.(type) {
case *ast.FuncLit:
hasFuncLit = true
}
return !hasFuncLit
})

if hasFuncLit {
gen.Doc = new(ast.CommentGroup)
if fn.Doc != nil {
gen.Doc.List = append(gen.Doc.List, fn.Doc.List...)
}
if len(gen.Doc.List) > 0 {
gen.Doc.List = append(gen.Doc.List, &ast.Comment{
Text: "//\n",
})
}
gen.Doc.List = append(gen.Doc.List, &ast.Comment{
Text: "//go:noinline\n",
})
}
}

return gen
}

func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *types.Signature) *ast.FuncLit {
Expand All @@ -314,19 +372,7 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *
}

func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, color *types.Signature) *ast.BlockStmt {
if functionBodyIsExpr(body) {
return body
}

gen := new(ast.BlockStmt)
ctx := ast.NewIdent("_c")
frame := ast.NewIdent("_f")
fp := ast.NewIdent("_fp")

yieldTypeExpr := make([]ast.Expr, 2)
yieldTypeExpr[0] = typeExpr(color.Params().At(0).Type())
yieldTypeExpr[1] = typeExpr(color.Results().At(0).Type())

body = desugar(body, p.TypesInfo).(*ast.BlockStmt)
body = astutil.Apply(body,
func(cursor *astutil.Cursor) bool {
switch n := cursor.Node().(type) {
Expand All @@ -341,8 +387,18 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
nil,
).(*ast.BlockStmt)

// Desugar statements in the tree.
body = desugar(body, p.TypesInfo).(*ast.BlockStmt)
if functionBodyIsExpr(body) {
return body
}

gen := new(ast.BlockStmt)
ctx := ast.NewIdent("_c")
frame := ast.NewIdent("_f")
fp := ast.NewIdent("_fp")

yieldTypeExpr := make([]ast.Expr, 2)
yieldTypeExpr[0] = typeExpr(color.Params().At(0).Type())
yieldTypeExpr[1] = typeExpr(color.Results().At(0).Type())

// _c := coroutine.LoadContext[R, S]()
gen.List = append(gen.List, &ast.AssignStmt{
Expand Down
13 changes: 1 addition & 12 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,7 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt
// stack unwinding, which takes the ".func1" name.
_, colored := colors[fn]
if colored {
// We skip functions that only have one expression because they
// are not transformed.
var body *ast.BlockStmt
switch d := fn.Syntax().(type) {
case *ast.FuncDecl:
body = d.Body
case *ast.FuncLit:
body = d.Body
}
if !functionBodyIsExpr(body) {
index++
}
index++
}

for _, anonFunc := range anonFuncs {
Expand Down
18 changes: 6 additions & 12 deletions compiler/testdata/coroutine_durable.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit acb9842

Please sign in to comment.