diff --git a/compiler/testdata/http/main_durable.go b/compiler/testdata/http/main_durable.go index 2a52f59..e75eb94 100644 --- a/compiler/testdata/http/main_durable.go +++ b/compiler/testdata/http/main_durable.go @@ -5,9 +5,9 @@ package main import ( - http "net/http" coroutine "github.com/stealthrocket/coroutine" fmt "fmt" + http "net/http" ) import _types "github.com/stealthrocket/coroutine/types" //go:noinline diff --git a/compiler/vendor.go b/compiler/vendor.go index a48cd02..249b1c5 100644 --- a/compiler/vendor.go +++ b/compiler/vendor.go @@ -19,7 +19,6 @@ const copyConcurrency = 16 func vendor(newRoot string, pkgs []*packages.Package) error { goroot := runtime.GOROOT() - var dirs []string var scanErr error packages.Visit(pkgs, func(p *packages.Package) bool { path := filepath.Dir(p.GoFiles[0]) @@ -34,52 +33,14 @@ func vendor(newRoot string, pkgs []*packages.Package) error { scanErr = fmt.Errorf("package %s (%s) is not in GOROOT (%s)", p.PkgPath, path, goroot) return false } - dirs = append(dirs, rel) return scanErr == nil }, nil) if scanErr != nil { return scanErr } - type copyOperation struct{ src, dst string } - ops := make(chan copyOperation, 256) - var group errgroup.Group - group.Go(func() error { - for _, rel := range dirs { - srcDir := filepath.Join(goroot, rel) - dstDir := filepath.Join(newRoot, rel) - if err := os.MkdirAll(dstDir, 0755); err != nil { - return err - } - entries, err := os.ReadDir(srcDir) - if err != nil { - return err - } - for _, entry := range entries { - if entry.IsDir() { - continue - } - name := entry.Name() - ops <- copyOperation{ - src: filepath.Join(srcDir, name), - dst: filepath.Join(dstDir, name), - } - } - } - close(ops) - return nil - }) - for i := 0; i < copyConcurrency; i++ { - group.Go(func() error { - for op := range ops { - if err := copyFile(op.dst, op.src); err != nil { - return err - } - } - return nil - }) - } - if err := group.Wait(); err != nil { + // Copy the entire GOROOT/src directory. + if err := copyDir(filepath.Join(newRoot, "src"), filepath.Join(goroot, "src")); err != nil { return err } @@ -101,6 +62,57 @@ func vendor(newRoot string, pkgs []*packages.Package) error { return err } +type copyOperation struct{ src, dst string } + +func copyDir(dst, src string) error { + ops := make(chan copyOperation, 256) + + var group errgroup.Group + group.Go(func() error { + err := scanDir(dst, src, ops) + close(ops) + return err + }) + for i := 0; i < copyConcurrency; i++ { + group.Go(func() error { + for op := range ops { + if err := copyFile(op.dst, op.src); err != nil { + return err + } + } + return nil + }) + } + return group.Wait() +} + +func scanDir(dst, src string, ops chan<- copyOperation) error { + if err := os.MkdirAll(dst, 0755); err != nil && !errors.Is(err, os.ErrExist) { + return err + } + entries, err := os.ReadDir(src) + if err != nil { + return err + } + for _, entry := range entries { + name := entry.Name() + fmt.Println(dst, src, name) + if entry.IsDir() { + dstChild := filepath.Join(dst, name) + srcChild := filepath.Join(src, name) + if err := scanDir(dstChild, srcChild, ops); err != nil { + return err + } + } else { + ops <- copyOperation{ + src: filepath.Join(src, name), + dst: filepath.Join(dst, name), + } + } + } + return nil +} + func copyFile(dst, src string) error { srcFile, err := os.Open(src) if err != nil {