Skip to content

Commit

Permalink
Merge branch 'main' into generics2
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Dec 14, 2023
2 parents c3b7be3 + b379c8b commit 1f6b17d
Show file tree
Hide file tree
Showing 7 changed files with 288 additions and 12 deletions.
57 changes: 52 additions & 5 deletions compiler/cmd/coroc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package main
import (
"flag"
"fmt"
"io"
"log"
"os"
"runtime/debug"
"runtime/pprof"

"github.com/stealthrocket/coroutine/compiler"
)
Expand All @@ -17,9 +20,26 @@ USAGE:
OPTIONS:
-h, --help Show this help information
-l, --list List all files that would be compiled
-v, --version Show the compiler version
ADVANCED OPTIONS:
-cpuprofile Write CPU profile to file
-memprofile Write memory profile to file
`

var (
showVersion bool
onlyListFiles bool
cpuProfile string
memProfile string
)

func boolFlag(ptr *bool, short, long string) {
flag.BoolVar(ptr, short, false, "")
flag.BoolVar(ptr, long, false, "")
}

func main() {
if err := run(); err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
Expand All @@ -30,17 +50,38 @@ func main() {
func run() error {
flag.Usage = func() { println(usage[1:]) }

var showVersion bool
flag.BoolVar(&showVersion, "v", false, "")
flag.BoolVar(&showVersion, "version", false, "")

boolFlag(&showVersion, "v", "version")
boolFlag(&onlyListFiles, "l", "list")
flag.StringVar(&cpuProfile, "cpuprofile", "", "")
flag.StringVar(&memProfile, "memprofile", "", "")
flag.Parse()

if showVersion {
fmt.Println(version())
return nil
}

if memProfile != "" {
f, err := os.Create(memProfile)
if err != nil {
return err
}
defer f.Close()
defer pprof.WriteHeapProfile(f)
}

if cpuProfile != "" {
f, err := os.Create(cpuProfile)
if err != nil {
return err
}
defer f.Close()
if err := pprof.StartCPUProfile(f); err != nil {
return err
}
defer pprof.StopCPUProfile()
}

path := flag.Arg(0)
if path == "" {
// If the compiler was invoked via go generate, the GOFILE
Expand All @@ -55,7 +96,13 @@ func run() error {
}
}

return compiler.Compile(path)
if onlyListFiles {
log.SetOutput(io.Discard)
}

return compiler.Compile(path,
compiler.OnlyListFiles(onlyListFiles),
)
}

func version() (version string) {
Expand Down
22 changes: 17 additions & 5 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,9 @@ func Compile(path string, options ...Option) error {
return c.compile(path)
}

// Option configures the compiler.
type Option func(*compiler)

type compiler struct {
onlyListFiles bool

prog *ssa.Program
generics map[*ssa.Function][]*ssa.Function
coroutinePkg *packages.Package
Expand Down Expand Up @@ -118,11 +117,12 @@ func (c *compiler) compile(path string) error {
c.prog.Build()

log.Printf("building call graph")
cg := vta.CallGraph(ssautil.AllFunctions(c.prog), cha.CallGraph(c.prog))
functions := ssautil.AllFunctions(c.prog)
cg := vta.CallGraph(functions, cha.CallGraph(c.prog))

log.Printf("collecting generic instances")
c.generics = map[*ssa.Function][]*ssa.Function{}
for fn := range ssautil.AllFunctions(c.prog) {
for fn := range functions {
if fn.Signature.TypeParams() != nil {
if _, ok := c.generics[fn]; !ok {
c.generics[fn] = nil
Expand Down Expand Up @@ -183,6 +183,17 @@ func (c *compiler) compile(path string) error {
pkgColors[fn] = color
}

if c.onlyListFiles {
cwd, _ := os.Getwd()
for pkg := range colorsByPkg {
for _, filePath := range pkg.GoFiles {
relPath, _ := filepath.Rel(cwd, filePath)
fmt.Println(relPath)
}
}
return nil
}

// Before mutating packages, we need to ensure that packages exist in a
// location where mutations can be made safely (without affecting other
// builds).
Expand Down Expand Up @@ -219,6 +230,7 @@ func (c *compiler) compile(path string) error {
// Reject packages outside ./vendor.
return fmt.Errorf("cannot mutate package %s (%s) safely. Please vendor dependencies: go mod vendor", p.PkgPath, dir)
}

if len(needVendoring) > 0 {
log.Printf("vendoring GOROOT packages")
newRoot := filepath.Join(moduleDir, "goroot")
Expand Down
5 changes: 5 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ func TestCoroutineYield(t *testing.T) {
result: 42,
},

{
name: "closure capturing receiver and param",
coro: func() { StructClosure(3) },
yields: []int{-1, 10, 100, 1000, 11, 101, 1000, 12, 102, 1000},
},
{
name: "generic function",
coro: func() { IdentityGenericInt(11) },
Expand Down
28 changes: 27 additions & 1 deletion compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,33 @@ func packagePath(p *packages.Package) string {
}

func functionPath(p *packages.Package, f *ast.FuncDecl) string {
return packagePath(p) + "." + f.Name.Name
var b strings.Builder
b.WriteString(packagePath(p))
if f.Recv != nil {
signature := p.TypesInfo.Defs[f.Name].Type().(*types.Signature)
recvType := signature.Recv().Type()
isptr := false
if ptr, ok := recvType.(*types.Pointer); ok {
recvType = ptr.Elem()
isptr = true
}
b.WriteByte('.')
if isptr {
b.WriteString("(*")
}
switch t := recvType.(type) {
case *types.Named:
b.WriteString(t.Obj().Name())
default:
panic(fmt.Sprintf("not implemented: %T", t))
}
if isptr {
b.WriteByte(')')
}
}
b.WriteByte('.')
b.WriteString(f.Name.Name)
return b.String()
}

func (c *compiler) generateFunctypes(p *packages.Package, f *ast.File, colors map[ast.Node]*types.Signature) {
Expand Down
10 changes: 10 additions & 0 deletions compiler/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package compiler

// Option configures the compiler.
type Option func(*compiler)

func OnlyListFiles(enabled bool) Option {
return func(c *compiler) {
c.onlyListFiles = enabled
}
}
27 changes: 27 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,33 @@ func ReturnNamedValue() (out int) {
return
}

type Box struct {
x int
}

func (b *Box) Closure(y int) func(int) {
// Force compilation of this method and the closure within.
// Remove once #84 is fixed.
coroutine.Yield[int, any](-1)

return func(z int) {
coroutine.Yield[int, any](b.x)
coroutine.Yield[int, any](y)
coroutine.Yield[int, any](z)
b.x++
y++
z++ // mutation is lost
}
}

func StructClosure(n int) {
box := Box{10}
fn := box.Closure(100)
for i := 0; i < n; i++ {
fn(1000)
}
}

func IdentityGeneric[T any](n T) {
coroutine.Yield[T, any](n)
}
Expand Down
Loading

0 comments on commit 1f6b17d

Please sign in to comment.