Skip to content

Commit

Permalink
generate serializable function
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Sep 19, 2023
1 parent f5ad6ab commit 643f2d5
Show file tree
Hide file tree
Showing 11 changed files with 361 additions and 61 deletions.
2 changes: 1 addition & 1 deletion compiler/Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
GO ?= go

TARGET = testdata/coroutine_durable.go
TARGET = testdata/coroutine_durable.go testdata/coroutine_functypes.go

test: clean generate
$(GO) test ./...
Expand Down
57 changes: 33 additions & 24 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,42 @@ func (c *compiler) compile(path string) error {
}

for p, colors := range colorsByPkg {
if err := c.compilePackage(p, colors); err != nil {
if err := c.compilePackage(p, colors, prog); err != nil {
return err
}
}

log.Printf("done")

return nil
}

func (c *compiler) compilePackage(p *packages.Package, colors functionColors) error {
func (c *compiler) writeFile(path string, file *ast.File) error {
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
// Comments are awkward to attach to the tree (they rely on token.Pos, which
// is coupled to a token.FileSet). Instead, just write out the raw strings.
var b strings.Builder
b.WriteString(`// Code generated by coroc. DO NOT EDIT`)
b.WriteString("\n\n")
if c.buildTags != "" {
b.WriteString(`//go:build `)
b.WriteString(c.buildTags)
b.WriteString("\n\n")
}
if _, err := f.WriteString(b.String()); err != nil {
return err
}
// Format/write the remainder of the AST.
if err := format.Node(f, c.fset, file); err != nil {
return err
}
return f.Close()
}

func (c *compiler) compilePackage(p *packages.Package, colors functionColors, prog *ssa.Program) error {
log.Printf("compiling package %s", p.Name)

// Generate the coroutine AST.
Expand Down Expand Up @@ -210,35 +235,19 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er
return err
}

// Get ready to write.
packageDir := filepath.Dir(p.GoFiles[0])
outputPath := filepath.Join(packageDir, c.outputFilename)
outputFile, err := os.Create(outputPath)
if err != nil {
return fmt.Errorf("os.Create %q: %w", outputPath, err)
}
defer outputFile.Close()

// Comments are awkward to attach to the tree (they rely on token.Pos, which
// is coupled to a token.FileSet). Instead, just write out the raw strings.
var b strings.Builder
b.WriteString(`// Code generated by coroc. DO NOT EDIT`)
b.WriteString("\n\n")
if c.buildTags != "" {
b.WriteString(`//go:build `)
b.WriteString(c.buildTags)
b.WriteString("\n\n")
}
if _, err := outputFile.WriteString(b.String()); err != nil {
if err := c.writeFile(outputPath, gen); err != nil {
return err
}

// Format/write the remainder of the AST.
if err := format.Node(outputFile, c.fset, gen); err != nil {
functypesFile := generateFunctypes(prog.Package(p.Types))
functypesPath := filepath.Join(packageDir, "coroutine_functypes.go")
if err := c.writeFile(functypesPath, functypesFile); err != nil {
return err
}

return outputFile.Close()
return nil
}

func (c *compiler) compileFunction(p *packages.Package, fn *ast.FuncDecl, color *types.Signature) *ast.FuncDecl {
Expand Down
14 changes: 9 additions & 5 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package compiler

import (
"reflect"
"slices"
"testing"

Expand Down Expand Up @@ -99,16 +98,21 @@ func TestCoroutineYield(t *testing.T) {
coro: func() { RangeOverMaps(5) },
yields: []int{0, 5, 5, 50, 5, 4, 3, 2, 1, 0},
},

{
name: "range over function",
coro: func() { Range(10, Double) },
yields: []int{0, 2, 4, 6, 8, 10, 12, 14, 16, 18},
},
}

// TODO: remove me
//
// This emulates the installation of function type information by the
// compiler until we have codegen for it.
// compiler because we are not doing codegen for the test files in this
// package.
for _, test := range tests {
a := types.FuncAddr(test.coro)
f := types.FuncByAddr(a)
f.Type = reflect.TypeOf(func() {})
types.RegisterFunc[func()](f.Name)
}

for _, test := range tests {
Expand Down
85 changes: 85 additions & 0 deletions compiler/function.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package compiler

import (
"cmp"
"go/ast"
"go/token"
"slices"
"strconv"

"golang.org/x/tools/go/ssa"
)

func generateFunctypes(pkg *ssa.Package) *ast.File {
var names = make([]string, 0, len(pkg.Members))
for name := range pkg.Members {
names = append(names, name)
}
slices.Sort(names)

var init ast.BlockStmt
var path = pkg.Pkg.Path()
for _, name := range names {
if fn, ok := pkg.Members[name].(*ssa.Function); ok {
generateFunctypesInit(path, &init, fn)
}
}

return &ast.File{
Name: ast.NewIdent(pkg.Pkg.Name()),
Decls: []ast.Decl{
&ast.GenDecl{
Tok: token.IMPORT,
Specs: []ast.Spec{
&ast.ImportSpec{
Name: ast.NewIdent("_types"),
Path: &ast.BasicLit{
Kind: token.STRING,
Value: `"github.com/stealthrocket/coroutine/types"`,
},
},
},
},
&ast.FuncDecl{
Name: ast.NewIdent("init"),
Type: &ast.FuncType{},
Body: &init,
},
},
}
}

func generateFunctypesInit(path string, init *ast.BlockStmt, fn *ssa.Function) {
if fn.TypeParams() != nil {
return // ignore non-instantiated generic functions
}

init.List = append(init.List, &ast.ExprStmt{
X: &ast.CallExpr{
Fun: &ast.IndexListExpr{
X: &ast.SelectorExpr{
X: ast.NewIdent("_types"),
Sel: ast.NewIdent("RegisterFunc"),
},
Indices: []ast.Expr{
newFuncType(fn.Signature),
},
},
Args: []ast.Expr{
&ast.BasicLit{
Kind: token.STRING,
Value: strconv.Quote(path + "." + fn.Name()),
},
},
},
})

anonFuncs := slices.Clone(fn.AnonFuncs)
slices.SortFunc(anonFuncs, func(f1, f2 *ssa.Function) int {
return cmp.Compare(f1.Name(), f2.Name())
})

for _, anonFunc := range anonFuncs {
generateFunctypesInit(path, init, anonFunc)
}
}
10 changes: 10 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,13 @@ func RangeOverMaps(n int) {
coroutine.Yield[int, any](len(m2))
}
}

func Range(n int, do func(int)) {
for i := 0; i < n; i++ {
do(i)
}
}

func Double(n int) {
coroutine.Yield[int, any](2 * n)
}
Loading

0 comments on commit 643f2d5

Please sign in to comment.