From acb9842df26583467daf241c123b00d250e8c508 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Wed, 20 Sep 2023 21:41:17 -0700 Subject: [PATCH] simplify single-expression coroutines Signed-off-by: Achille Roussel --- compiler/compile.go | 90 +++++++++++++++++++++----- compiler/function.go | 13 +--- compiler/testdata/coroutine_durable.go | 18 ++---- 3 files changed, 80 insertions(+), 41 deletions(-) diff --git a/compiler/compile.go b/compiler/compile.go index 9affc7b..1b0f4ec 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -222,6 +222,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr }, }) + ssaFnsByDecl := map[ast.Node]*ssa.Function{} colorsByDecl := map[ast.Node]*types.Signature{} for fn, color := range colors { decl := fn.Syntax() @@ -231,6 +232,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr default: return fmt.Errorf("unsupported yield function %s (Syntax is %T, not *ast.FuncDecl or *ast.FuncLit)", fn, decl) } + ssaFnsByDecl[decl] = fn colorsByDecl[decl] = color } @@ -249,11 +251,22 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr if err := unsupported(decl, p.TypesInfo); err != nil { return err } + scope := &scope{ colors: colorsByDecl, objectIdent: 0, } - gen.Decls = append(gen.Decls, scope.compileFuncDecl(p, decl, color)) + + compiledFunction := scope.compileFuncDecl(p, decl, color) + if functionBodyIsExpr(compiledFunction.Body) { + // If the function has a single expression it does not contain + // a deferred closure so we drop it from the list of colored + // functions so generateFunctypes does not mistakenly increment + // the local symbol counter when generating closure names. + delete(colors, ssaFnsByDecl[decl]) + } + + gen.Decls = append(gen.Decls, compiledFunction) } } @@ -298,11 +311,56 @@ func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color // Generate the coroutine function. At this stage, use the same name // as the source function (and require that the caller use build tags // to disambiguate function calls). - return &ast.FuncDecl{ + gen := &ast.FuncDecl{ + Doc: fn.Doc, Name: fn.Name, Type: funcTypeWithNamedResults(fn.Type), Body: scope.compileFuncBody(p, fn.Type, fn.Body, color), } + + if functionBodyIsExpr(gen.Body) { + // If the function declaration contains function literals, we have to + // add the //go:noinline copmiler directive to prevent inlining or the + // resulting symbol name generated by the linker wouldn't match the + // predictions made in generateFunctypes. + // + // When functions are inlined, the linker creates a unique name + // combining the symbol name of the calling function and the symbol name + // of the closure. Knowing which functions will be inlined is difficult + // considering the score-base mechansim that Go uses and alterations + // like PGO, therefore we take the simple approach of disabling inlining + // instead. + // + // Note that we only need to do this for single-expression functions as + // otherwise the presence of a defer statement to unwind the coroutine + // already prevents inlining. + hasFuncLit := false + + ast.Inspect(fn, func(n ast.Node) bool { + switch n.(type) { + case *ast.FuncLit: + hasFuncLit = true + } + return !hasFuncLit + }) + + if hasFuncLit { + gen.Doc = new(ast.CommentGroup) + if fn.Doc != nil { + gen.Doc.List = append(gen.Doc.List, fn.Doc.List...) + } + if len(gen.Doc.List) > 0 { + gen.Doc.List = append(gen.Doc.List, &ast.Comment{ + Text: "//\n", + }) + } + gen.Doc.List = append(gen.Doc.List, &ast.Comment{ + Text: "//go:noinline\n", + }) + } + } + + return gen } func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *types.Signature) *ast.FuncLit { @@ -314,19 +372,7 @@ func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color * } func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, color *types.Signature) *ast.BlockStmt { - if functionBodyIsExpr(body) { - return body - } - - gen := new(ast.BlockStmt) - ctx := ast.NewIdent("_c") - frame := ast.NewIdent("_f") - fp := ast.NewIdent("_fp") - - yieldTypeExpr := make([]ast.Expr, 2) - yieldTypeExpr[0] = typeExpr(color.Params().At(0).Type()) - yieldTypeExpr[1] = typeExpr(color.Results().At(0).Type()) - + body = desugar(body, p.TypesInfo).(*ast.BlockStmt) body = astutil.Apply(body, func(cursor *astutil.Cursor) bool { switch n := cursor.Node().(type) { @@ -341,8 +387,18 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body nil, ).(*ast.BlockStmt) - // Desugar statements in the tree. - body = desugar(body, p.TypesInfo).(*ast.BlockStmt) + if functionBodyIsExpr(body) { + return body + } + + gen := new(ast.BlockStmt) + ctx := ast.NewIdent("_c") + frame := ast.NewIdent("_f") + fp := ast.NewIdent("_fp") + + yieldTypeExpr := make([]ast.Expr, 2) + yieldTypeExpr[0] = typeExpr(color.Params().At(0).Type()) + yieldTypeExpr[1] = typeExpr(color.Results().At(0).Type()) // _c := coroutine.LoadContext[R, S]() gen.List = append(gen.List, &ast.AssignStmt{ diff --git a/compiler/function.go b/compiler/function.go index e49d489..2279dba 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -120,18 +120,7 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt // stack unwinding, which takes the ".func1" name. _, colored := colors[fn] if colored { - // We skip functions that only have one expression because they - // are not transformed. - var body *ast.BlockStmt - switch d := fn.Syntax().(type) { - case *ast.FuncDecl: - body = d.Body - case *ast.FuncLit: - body = d.Body - } - if !functionBodyIsExpr(body) { - index++ - } + index++ } for _, anonFunc := range anonFuncs { diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 04156ed..309969f 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -18,9 +18,7 @@ import ( func SomeFunctionThatShouldExistInTheCompiledFile() { } -func Identity(n int) { - coroutine.Yield[int, any](n) -} +func Identity(n int) { coroutine.Yield[int, any](n) } func SquareGenerator(n int) { _c := coroutine.LoadContext[int, any]() @@ -2236,14 +2234,11 @@ func Range(n int, do func(int)) { } } -func Double(n int) { - coroutine.Yield[int, any](2 * n) -} +func Double(n int) { coroutine.Yield[int, any](2 * n) } +//go:noinline func RangeTriple(n int) { - Range(n, func(i int) { - coroutine.Yield[int, any](3 * i) - }) + Range(n, func(i int) { coroutine.Yield[int, any](3 * i) }) } func RangeTripleFuncValue(n int) { @@ -2269,12 +2264,11 @@ func RangeTripleFuncValue(n int) { }() switch { case _f.IP < 2: - _o0 = func(i int) { - coroutine.Yield[int, any](3 * i) - } + _o0 = func(i int) { coroutine.Yield[int, any](3 * i) } _f.IP = 2 fallthrough case _f.IP < 3: + Range(n, _o0) } }