From 145d380bd199951f3f942085939f0c515307f8fc Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Wed, 20 Sep 2023 20:51:59 -0700 Subject: [PATCH] fix: compile non-coroutine functions Signed-off-by: Achille Roussel --- compiler/compile.go | 35 ++++++++++++------------ compiler/coroutine_test.go | 6 ++++ compiler/testdata/coroutine.go | 3 ++ compiler/testdata/coroutine_durable.go | 3 ++ compiler/testdata/coroutine_functypes.go | 1 + 5 files changed, 31 insertions(+), 17 deletions(-) diff --git a/compiler/compile.go b/compiler/compile.go index abde5a1..2f796ff 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -242,6 +242,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors, pr } color, ok := colorsByDecl[decl] if !ok { + gen.Decls = append(gen.Decls, decl) continue } // Reject certain language features for now. @@ -322,6 +323,23 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body yieldTypeExpr[0] = typeExpr(color.Params().At(0).Type()) yieldTypeExpr[1] = typeExpr(color.Results().At(0).Type()) + body = astutil.Apply(body, + func(cursor *astutil.Cursor) bool { + switch n := cursor.Node().(type) { + case *ast.FuncLit: + color, ok := scope.colors[n] + if ok { + cursor.Replace(scope.compileFuncLit(p, n, color)) + } + } + return true + }, + nil, + ).(*ast.BlockStmt) + + // Desugar statements in the tree. + body = desugar(body, p.TypesInfo).(*ast.BlockStmt) + // _c := coroutine.LoadContext[R, S]() gen.List = append(gen.List, &ast.AssignStmt{ Lhs: []ast.Expr{ctx}, @@ -350,23 +368,6 @@ func (scope *scope) compileFuncBody(p *packages.Package, typ *ast.FuncType, body }, }) - body = astutil.Apply(body, - func(cursor *astutil.Cursor) bool { - switch n := cursor.Node().(type) { - case *ast.FuncLit: - color, ok := scope.colors[n] - if ok { - cursor.Replace(scope.compileFuncLit(p, n, color)) - } - } - return true - }, - nil, - ).(*ast.BlockStmt) - - // Desugar statements in the tree. - body = desugar(body, p.TypesInfo).(*ast.BlockStmt) - // Handle declarations. // // Types, constants and variables can be defined within any scope in the diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index 08260d5..8fd3e55 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -9,6 +9,12 @@ import ( "github.com/stealthrocket/coroutine/types" ) +func init() { + // Breaks if the compiler did not retain simple top-level functions in the + // output file. + SomeFunctionThatShouldExistInTheCompiledFile() +} + func TestCoroutineYield(t *testing.T) { tests := []struct { name string diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index 7b45ae9..89114ab 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -11,6 +11,9 @@ import ( //go:generate coroc --output coroutine_durable.go --tags durable +func SomeFunctionThatShouldExistInTheCompiledFile() { +} + func Identity(n int) { coroutine.Yield[int, any](n) } diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index bace287..d946cdb 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -15,6 +15,9 @@ import ( unsafe "unsafe" ) +func SomeFunctionThatShouldExistInTheCompiledFile() { +} + func Identity(n int) { _c := coroutine.LoadContext[int, any]() _f, _fp := _c.Push() diff --git a/compiler/testdata/coroutine_functypes.go b/compiler/testdata/coroutine_functypes.go index 0208a7f..c67425e 100644 --- a/compiler/testdata/coroutine_functypes.go +++ b/compiler/testdata/coroutine_functypes.go @@ -56,6 +56,7 @@ func init() { _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeTripleFuncValue.func2") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.Select") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.Shadowing") + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.SomeFunctionThatShouldExistInTheCompiledFile") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGenerator") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwice") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwiceLoop")