Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix empty return with named results #120

Merged
merged 2 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
// assignments that use := to assignments that use =. Constant decls are
// hoisted and also have their value assigned in the function prologue.
decls, frameType, frameInit := extractDecls(p, typ, body, recv, defers, p.TypesInfo)
renameObjects(body, p.TypesInfo, decls, frameName, frameType, frameInit, scope)
renameObjects(typ, body, p.TypesInfo, decls, frameName, frameType, frameInit, scope)

// var _f{n} F = coroutine.Push[F](&_c.Stack)
gen.List = append(gen.List, &ast.DeclStmt{Decl: &ast.GenDecl{
Expand Down Expand Up @@ -633,7 +633,7 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body
gen.List = append(gen.List, compiledBody.List...)

// If the function returns one or more values, it must end with a return statement;
// we inject it if the function body does not already has one.
// we inject it if the function body does not already have one.
if typ.Results != nil && len(typ.Results.List) > 0 {
needsReturn := len(gen.List) == 0
if !needsReturn {
Expand Down
7 changes: 7 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ func TestCoroutineYield(t *testing.T) {
yields: []int{1, 2, 3, 2, 4, 6, 3, 6, 9, 2, 4, 6, 4, 8, 12, 6, 12, 18, 3, 6, 9, 6, 12, 18, 9, 18, 27},
result: 27,
},

{
name: "return named values",
coroR: func() int { return ReturnNamedValue() },
yields: []int{11},
result: 42,
},
}

// This emulates the installation of function type information by the
Expand Down
55 changes: 52 additions & 3 deletions compiler/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func extractDecls(p *packages.Package, typ *ast.FuncType, body *ast.BlockStmt, r
// renameObjects renames types, constants and variables declared within
// a function. Each is given a unique name, so that declarations are safe
// to hoist into the function prologue.
func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameName *ast.Ident, frameType *ast.StructType, frameInit *ast.CompositeLit, scope *scope) {
func renameObjects(fntype *ast.FuncType, tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameName *ast.Ident, frameType *ast.StructType, frameInit *ast.CompositeLit, scope *scope) {
// Scan decls to find objects, giving each new object a unique name.
names := make(map[types.Object]*ast.Ident, len(decls))
selectors := make(map[types.Object]*ast.SelectorExpr, len(frameType.Fields.List))
Expand Down Expand Up @@ -238,7 +238,7 @@ func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameN
// replacing if they are removed from the tree too early.
//
// Note that replacing identifiers is a recursive operation which traverses
// function literls.
// function literals.

astutil.Apply(tree,
func(cursor *astutil.Cursor) bool {
Expand Down Expand Up @@ -326,9 +326,55 @@ func renameObjects(tree ast.Node, info *types.Info, decls []*ast.GenDecl, frameN
},
nil,
)

// Perform a last pass to assigned named results before unnamed. It cannot
// be done in the renaming pass because it should not recurse into function
// literals, which the renaming pass does.
if hasNamedResults(fntype) {
astutil.Apply(tree,
func(cursor *astutil.Cursor) bool {
switch n := cursor.Node().(type) {
case *ast.FuncLit:
return false
case *ast.ReturnStmt:
if len(n.Results) > 0 {
return true
}

// Transform
// return
// into
// return (selector1), (selector2)...
for _, t := range fntype.Results.List {
ident := t.Names[0]
obj := info.ObjectOf(ident)
n.Results = append(n.Results, selectors[obj])
}
}

return true
}, nil)
}
}

func renameFuncRecvParamsResults(typ *ast.FuncType, recv *ast.FieldList, body *ast.BlockStmt, info *types.Info) {
func hasNamedResults(t *ast.FuncType) bool {
if t.Results == nil || len(t.Results.List) == 0 {
return false
}

for _, result := range t.Results.List {
for _, name := range result.Names {
if name == nil || name.Name == "" || name.Name == "_" {
continue
}
return true
}
}
return false
}

func renameFuncRecvParamsResults(typ *ast.FuncType, recv *ast.FieldList, body *ast.BlockStmt, info *types.Info) []*ast.Ident {
var namedResults []*ast.Ident
names := map[types.Object]*ast.Ident{}

fieldLists := []*ast.FieldList{recv, typ.Params, typ.Results}
Expand All @@ -345,6 +391,7 @@ func renameFuncRecvParamsResults(typ *ast.FuncType, recv *ast.FieldList, body *a
newIdent := ast.NewIdent("_fn" + strconv.Itoa(len(names)))
names[obj] = newIdent
info.Defs[newIdent] = obj
namedResults = append(namedResults, newIdent)
}
}
}
Expand All @@ -366,4 +413,6 @@ func renameFuncRecvParamsResults(typ *ast.FuncType, recv *ast.FieldList, body *a
}
return true
}, nil)

return namedResults
}
7 changes: 7 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,10 @@ func varArgs(args ...int) {
coroutine.Yield[int, any](arg)
}
}

func ReturnNamedValue() (out int) {
out = 5
coroutine.Yield[int, any](11)
out = 42
return
}
41 changes: 41 additions & 0 deletions compiler/testdata/coroutine_durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -3196,6 +3196,46 @@ func varArgs(_fn0 ...int) {
}
}
}

//go:noinline
func ReturnNamedValue() (_fn0 int) {
_c := coroutine.LoadContext[int, any]()
var _f0 *struct {
IP int
X0 int
} = coroutine.Push[struct {
IP int
X0 int
}](&_c.Stack)
if _f0.IP == 0 {
*_f0 = struct {
IP int
X0 int
}{}
}
defer func() {
if !_c.Unwinding() {
coroutine.Pop(&_c.Stack)
}
}()
switch {
case _f0.IP < 2:
_f0.X0 = 5
_f0.IP = 2
fallthrough
case _f0.IP < 3:
coroutine.Yield[int, any](11)
_f0.IP = 3
fallthrough
case _f0.IP < 4:
_f0.X0 = 42
_f0.IP = 4
fallthrough
case _f0.IP < 5:
return _f0.X0
}
return
pelletier marked this conversation as resolved.
Show resolved Hide resolved
}
func init() {
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator")
Expand Down Expand Up @@ -3292,6 +3332,7 @@ func init() {
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeTripleFuncValue")
_types.RegisterFunc[func(i int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeTripleFuncValue.func2")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeYieldAndDeferAssign")
_types.RegisterFunc[func() (_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.ReturnNamedValue")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.Select")
_types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.Shadowing")
_types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.SomeFunctionThatShouldExistInTheCompiledFile")
Expand Down
Loading