From 43998335ffe07a9d7ee3e201e40046099d5b8b8a Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Tue, 19 Sep 2023 14:39:03 +1000 Subject: [PATCH 1/2] Add HTTP mod for testing --- compiler/testdata/http/go.mod | 7 +++++++ compiler/testdata/http/main.go | 37 ++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 compiler/testdata/http/go.mod create mode 100644 compiler/testdata/http/main.go diff --git a/compiler/testdata/http/go.mod b/compiler/testdata/http/go.mod new file mode 100644 index 0000000..5c8c4c1 --- /dev/null +++ b/compiler/testdata/http/go.mod @@ -0,0 +1,7 @@ +module github.com/stealthrocket/coroutine/compiler/testdata/http + +go 1.21.0 + +require github.com/stealthrocket/coroutine v0.0.0-20230906012022-7474cda88ddc + +replace github.com/stealthrocket/coroutine => ../../../ diff --git a/compiler/testdata/http/main.go b/compiler/testdata/http/main.go new file mode 100644 index 0000000..2987b88 --- /dev/null +++ b/compiler/testdata/http/main.go @@ -0,0 +1,37 @@ +package main + +import ( + "fmt" + "net/http" + + "github.com/stealthrocket/coroutine" +) + +type yieldingRoundTripper struct{} + +func (*yieldingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + res := coroutine.Yield[*http.Request, *http.Response](req) + return res, nil +} + +func work() { + res, err := http.Get("http://example.com") + if err != nil { + panic(err) + } + fmt.Println(res.StatusCode) +} + +func main() { + http.DefaultTransport = &yieldingRoundTripper{} + + c := coroutine.New[*http.Request, *http.Response](work) + + for c.Next() { + req := c.Recv() + fmt.Println("Requesting", req.URL.String()) + c.Send(&http.Response{ + StatusCode: 200, + }) + } +} From dbd1e4480855c5a6460954705909e906ba644ff2 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Tue, 19 Sep 2023 14:39:32 +1000 Subject: [PATCH 2/2] Improve detection of enclosing module --- compiler/compile.go | 67 +++++++++++++++++++++++++++++---------------- 1 file changed, 43 insertions(+), 24 deletions(-) diff --git a/compiler/compile.go b/compiler/compile.go index 4718b11..2512c8b 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -20,13 +20,14 @@ import ( const coroutinePackage = "github.com/stealthrocket/coroutine" -// Compile compiles coroutines in one or more packages. +// Compile compiles coroutines in a module. // -// The path argument can either be a path to a package, a -// path to a file within a package, or a pattern that matches -// multiple packages (for example, /path/to/package/...). -// The path can be absolute or relative (to the current working -// directory). +// The path argument can either be a path to a package within +// the module, or a pattern that matches multiple packages in the +// module (for example, /path/to/module/...). In both cases, the +// nearest module is located and compiled as a whole. +// +// The path can be absolute, or relative to the current working directory. func Compile(path string, options ...Option) error { c := &compiler{ outputFilename: "coroc_generated.go", @@ -64,34 +65,49 @@ type compiler struct { func (c *compiler) compile(path string) error { log.SetFlags(log.LstdFlags | log.Lmicroseconds) - if path != "" && !strings.HasSuffix(path, "...") { - s, err := os.Stat(path) - if err != nil { - return err - } else if !s.IsDir() { - // Make sure we're loading whole packages. - path = filepath.Dir(path) - } + absPath, err := filepath.Abs(path) + if err != nil { + return err + } + var dotdotdot bool + absPath, dotdotdot = strings.CutSuffix(absPath, "...") + if s, err := os.Stat(absPath); err != nil { + return err + } else if !s.IsDir() { + // Make sure we're loading whole packages. + absPath = filepath.Dir(absPath) } - path = filepath.Clean(path) - if len(path) > 0 && path[0] != filepath.Separator && path[0] != '.' { - // Go interprets patterns without a leading dot as part of the - // stdlib (i.e. part of $GOROOT/src) rather than relative to - // the working dir. Note that filepath.Join(".", path) does not - // give the desired result here, hence the manual concat. - path = "." + string(filepath.Separator) + path + var pattern string + if dotdotdot { + pattern = "./..." + } else { + pattern = "." } log.Printf("reading, parsing and type-checking") conf := &packages.Config{ - Mode: packages.NeedName | packages.NeedFiles | packages.NeedSyntax | packages.NeedTypes | packages.NeedImports | packages.NeedDeps | packages.NeedTypesInfo, + Mode: packages.NeedName | packages.NeedModule | + packages.NeedImports | packages.NeedDeps | + packages.NeedFiles | packages.NeedSyntax | + packages.NeedTypes | packages.NeedTypesInfo, Fset: c.fset, + Dir: absPath, } - - pkgs, err := packages.Load(conf, path) + pkgs, err := packages.Load(conf, pattern) if err != nil { return fmt.Errorf("packages.Load %q: %w", path, err) } + var moduleDir string + for _, p := range pkgs { + if p.Module == nil { + return fmt.Errorf("package %s is not part of a module", p.PkgPath) + } + if moduleDir == "" { + moduleDir = p.Module.Dir + } else if moduleDir != p.Module.Dir { + return fmt.Errorf("pattern more than one module (%s + %s)", moduleDir, p.Module.Dir) + } + } flatpkgs := flattenPackages(pkgs) for _, p := range flatpkgs { for _, err := range p.Errors { @@ -151,6 +167,9 @@ func (c *compiler) compile(path string) error { } for p, colors := range colorsByPkg { + if p.Module == nil || p.Module.Dir != moduleDir { + return fmt.Errorf("not implemented: compilation for packages outside module (need to compile %s)", p.PkgPath) + } if err := c.compilePackage(p, colors, prog); err != nil { return err }