Skip to content

Commit

Permalink
Fix empty return with named results
Browse files Browse the repository at this point in the history
When a function has named results and uses empty return, the named result would
correctly be renamed and hoisted into the frame struct. In that scenario, using
an explicit return (such as "return myvalue") works, because the renaming pass
goes over the expression in the return statement and replaces the named result
with a selector into the frame struct.

However, when using an empty return (just "return"), the named result is not in
the AST, and the renaming pass would not see it.

This patch add and final pass to the renaming procedure to change empty returns
in functions with named results into explicit returns using selectors into the
frame struct. For example:

    func example() (out int) {
        out = 42
        return
    }

gets rewritten into:

    func example() (_fn0 int) {
        var _f0 *struct {
            IP int
            X0 int
        }
	_f0.X0 = 42
	return _f0.X0
    }

(other transformations omitted for this example)
  • Loading branch information
pelletier committed Nov 23, 2023
1 parent 7c96207 commit 5412894
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 5 deletions.
4 changes: 2 additions & 2 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
// assignments that use := to assignments that use =. Constant decls are
// hoisted and also have their value assigned in the function prologue.
decls, frameType, frameInit := extractDecls(p, typ, body, recv, defers, p.TypesInfo)
renameObjects(body, p.TypesInfo, decls, frameName, frameType, frameInit, scope)
renameObjects(typ, body, p.TypesInfo, decls, frameName, frameType, frameInit, scope)

// var _f{n} F = coroutine.Push[F](&_c.Stack)
gen.List = append(gen.List, &ast.DeclStmt{Decl: &ast.GenDecl{
Expand Down Expand Up @@ -633,7 +633,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
gen.List = append(gen.List, compiledBody.List...)

// If the function returns one or more values, it must end with a return statement;
// we inject it if the function body does not already has one.
// we inject it if the function body does not already have one.
if typ.Results != nil && len(typ.Results.List) > 0 {
needsReturn := len(gen.List) == 0
if !needsReturn {
Expand Down
7 changes: 7 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ func TestCoroutineYield(t *testing.T) {
yields: []int{1, 2, 3, 2, 4, 6, 3, 6, 9, 2, 4, 6, 4, 8, 12, 6, 12, 18, 3, 6, 9, 6, 12, 18, 9, 18, 27},
result: 27,
},

{
name: "return named values",
coroR: func() int { return ReturnNamedValue() },
yields: []int{11},
result: 42,
},
}

// This emulates the installation of function type information by the
Expand Down
55 changes: 52 additions & 3 deletions compiler/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, r
// renameObjects renames types, constants and variables declared within
// a function. Each is given a unique name, so that declarations are safe
// to hoist into the function prologue.
func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameName *ast.Ident, frameType *ast.StructType, frameInit *ast.CompositeLit, scope *scope) {
func renameObjects(fntype *ast.FuncType, tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameName *ast.Ident, frameType *ast.StructType, frameInit *ast.CompositeLit, scope *scope) {
// Scan decls to find objects, giving each new object a unique name.
names := make(map[types.Object]*ast.Ident, len(decls))
selectors := make(map[types.Object]*ast.SelectorExpr, len(frameType.Fields.List))
Expand Down Expand Up @@ -238,7 +238,7 @@ func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameN
// replacing if they are removed from the tree too early.
//
// Note that replacing identifiers is a recursive operation which traverses
// function literls.
// function literals.

astutil.Apply(tree,
func(cursor *astutil.Cursor) bool {
Expand Down Expand Up @@ -326,9 +326,55 @@ func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameN
},
nil,
)

// Perform a last pass to assigned named results before unnamed. It cannot
// be done in the renaming pass because it should not recurse into function
// literals, which the renaming pass does.
if hasNamedResults(fntype) {
astutil.Apply(tree,
func(cursor *astutil.Cursor) bool {
switch n := cursor.Node().(type) {
case *ast.FuncLit:
return false
case *ast.ReturnStmt:
if len(n.Results) > 0 {
return true
}

// Transform
// return
// into
// return (selector1), (selector2)...
for _, t := range fntype.Results.List {
ident := t.Names[0]
obj := info.ObjectOf(ident)
n.Results = append(n.Results, selectors[obj])
}
}

return true
}, nil)
}
}

func renameFuncRecvParamsResults(typ *ast.FuncType, recv *ast.FieldList, body *ast.BlockStmt, info *types.Info) {
func hasNamedResults(t *ast.FuncType) bool {
if t.Results == nil || len(t.Results.List) == 0 {
return false
}

for _, result := range t.Results.List {
for _, name := range result.Names {
if name == nil || name.Name == "" || name.Name == "_" {
continue
}
return true
}
}
return false
}

func renameFuncRecvParamsResults(typ *ast.FuncType, recv *ast.FieldList, body *ast.BlockStmt, info *types.Info) []*ast.Ident {
var namedResults []*ast.Ident
names := map[types.Object]*ast.Ident{}

fieldLists := []*ast.FieldList{recv, typ.Params, typ.Results}
Expand All @@ -345,6 +391,7 @@ func renameFuncRecvParamsResults(typ *ast.FuncType, recv *ast.FieldList, body *a
newIdent := ast.NewIdent("_fn" + strconv.Itoa(len(names)))
names[obj] = newIdent
info.Defs[newIdent] = obj
namedResults = append(namedResults, newIdent)
}
}
}
Expand All @@ -366,4 +413,6 @@ func renameFuncRecvParamsResults(typ *ast.FuncType, recv *ast.FieldList, body *a
}
return true
}, nil)

return namedResults
}
7 changes: 7 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,10 @@ func varArgs(args ...int) {
coroutine.Yield[int, any](arg)
}
}

func ReturnNamedValue() (out int) {
out = 5
coroutine.Yield[int, any](11)
out = 42
return
}
41 changes: 41 additions & 0 deletions compiler/testdata/coroutine_durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -3196,6 +3196,46 @@ func varArgs(_fn0 ...int) {
}
}
}

//go:noinline
func ReturnNamedValue() (_fn0 int) {
_c := coroutine.LoadContext[int, any]()
var _f0 *struct {
IP int
X0 int
} = coroutine.Push[struct {
IP int
X0 int
}](&_c.Stack)
if _f0.IP == 0 {
*_f0 = struct {
IP int
X0 int
}{}
}
defer func() {
if !_c.Unwinding() {
coroutine.Pop(&_c.Stack)
}
}()
switch {
case _f0.IP < 2:
_f0.X0 = 5
_f0.IP = 2
fallthrough
case _f0.IP < 3:
coroutine.Yield[int, any](11)
_f0.IP = 3
fallthrough
case _f0.IP < 4:
_f0.X0 = 42
_f0.IP = 4
fallthrough
case _f0.IP < 5:
return _f0.X0
}
return
}
func init() {
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator")
Expand Down Expand Up @@ -3292,6 +3332,7 @@ func init() {
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeTripleFuncValue")
_types.RegisterFunc[func(i int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeTripleFuncValue.func2")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeYieldAndDeferAssign")
_types.RegisterFunc[func() (_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.ReturnNamedValue")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.Select")
_types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.Shadowing")
_types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.SomeFunctionThatShouldExistInTheCompiledFile")
Expand Down

0 comments on commit 5412894

Please sign in to comment.