Skip to content

Commit

Permalink
Add testdata/http (#37)
Browse files Browse the repository at this point in the history
This adds a HTTP module with a yielding `http.RoundTripper` for testing.

We can't compile it yet; the callgraph doesn't seem to pickup the fact
that `http.Get` calls or might call the custom `RoundTripper`
implementation. I'll look at this in a follow-up.

I've improved the loading process to handle modules that don't enclose
the current working directory (like using `go {build,list,test} -C
<dir>`).
  • Loading branch information
chriso authored Sep 19, 2023
2 parents a84c23f + dbd1e44 commit 261eb4e
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 24 deletions.
67 changes: 43 additions & 24 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ import (

const coroutinePackage = "github.com/stealthrocket/coroutine"

// Compile compiles coroutines in one or more packages.
// Compile compiles coroutines in a module.
//
// The path argument can either be a path to a package, a
// path to a file within a package, or a pattern that matches
// multiple packages (for example, /path/to/package/...).
// The path can be absolute or relative (to the current working
// directory).
// The path argument can either be a path to a package within
// the module, or a pattern that matches multiple packages in the
// module (for example, /path/to/module/...). In both cases, the
// nearest module is located and compiled as a whole.
//
// The path can be absolute, or relative to the current working directory.
func Compile(path string, options ...Option) error {
c := &compiler{
outputFilename: "coroc_generated.go",
Expand Down Expand Up @@ -64,34 +65,49 @@ type compiler struct {
func (c *compiler) compile(path string) error {
log.SetFlags(log.LstdFlags | log.Lmicroseconds)

if path != "" && !strings.HasSuffix(path, "...") {
s, err := os.Stat(path)
if err != nil {
return err
} else if !s.IsDir() {
// Make sure we're loading whole packages.
path = filepath.Dir(path)
}
absPath, err := filepath.Abs(path)
if err != nil {
return err
}
var dotdotdot bool
absPath, dotdotdot = strings.CutSuffix(absPath, "...")
if s, err := os.Stat(absPath); err != nil {
return err
} else if !s.IsDir() {
// Make sure we're loading whole packages.
absPath = filepath.Dir(absPath)
}
path = filepath.Clean(path)
if len(path) > 0 && path[0] != filepath.Separator && path[0] != '.' {
// Go interprets patterns without a leading dot as part of the
// stdlib (i.e. part of $GOROOT/src) rather than relative to
// the working dir. Note that filepath.Join(".", path) does not
// give the desired result here, hence the manual concat.
path = "." + string(filepath.Separator) + path
var pattern string
if dotdotdot {
pattern = "./..."
} else {
pattern = "."
}

log.Printf("reading, parsing and type-checking")
conf := &packages.Config{
Mode: packages.NeedName | packages.NeedFiles | packages.NeedSyntax | packages.NeedTypes | packages.NeedImports | packages.NeedDeps | packages.NeedTypesInfo,
Mode: packages.NeedName | packages.NeedModule |
packages.NeedImports | packages.NeedDeps |
packages.NeedFiles | packages.NeedSyntax |
packages.NeedTypes | packages.NeedTypesInfo,
Fset: c.fset,
Dir: absPath,
}

pkgs, err := packages.Load(conf, path)
pkgs, err := packages.Load(conf, pattern)
if err != nil {
return fmt.Errorf("packages.Load %q: %w", path, err)
}
var moduleDir string
for _, p := range pkgs {
if p.Module == nil {
return fmt.Errorf("package %s is not part of a module", p.PkgPath)
}
if moduleDir == "" {
moduleDir = p.Module.Dir
} else if moduleDir != p.Module.Dir {
return fmt.Errorf("pattern more than one module (%s + %s)", moduleDir, p.Module.Dir)
}
}
flatpkgs := flattenPackages(pkgs)
for _, p := range flatpkgs {
for _, err := range p.Errors {
Expand Down Expand Up @@ -151,6 +167,9 @@ func (c *compiler) compile(path string) error {
}

for p, colors := range colorsByPkg {
if p.Module == nil || p.Module.Dir != moduleDir {
return fmt.Errorf("not implemented: compilation for packages outside module (need to compile %s)", p.PkgPath)
}
if err := c.compilePackage(p, colors, prog); err != nil {
return err
}
Expand Down
7 changes: 7 additions & 0 deletions compiler/testdata/http/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module github.com/stealthrocket/coroutine/compiler/testdata/http

go 1.21.0

require github.com/stealthrocket/coroutine v0.0.0-20230906012022-7474cda88ddc

replace github.com/stealthrocket/coroutine => ../../../
37 changes: 37 additions & 0 deletions compiler/testdata/http/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package main

import (
"fmt"
"net/http"

"github.com/stealthrocket/coroutine"
)

type yieldingRoundTripper struct{}

func (*yieldingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
res := coroutine.Yield[*http.Request, *http.Response](req)
return res, nil
}

func work() {
res, err := http.Get("http://example.com")
if err != nil {
panic(err)
}
fmt.Println(res.StatusCode)
}

func main() {
http.DefaultTransport = &yieldingRoundTripper{}

c := coroutine.New[*http.Request, *http.Response](work)

for c.Next() {
req := c.Recv()
fmt.Println("Requesting", req.URL.String())
c.Send(&http.Response{
StatusCode: 200,
})
}
}

0 comments on commit 261eb4e

Please sign in to comment.