From ad13f95df9ea76486c22f6586e498f4e04a9861d Mon Sep 17 00:00:00 2001 From: Jeroen Vervaeke <9132134+jeroenvervaeke@users.noreply.github.com> Date: Wed, 25 Sep 2024 12:30:03 +0100 Subject: [PATCH] Add --rewrite flag to get command (#36) --- internal/importpaths/rewrite.go | 6 ++++-- main.go | 24 ++++++++++++++++++++---- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/internal/importpaths/rewrite.go b/internal/importpaths/rewrite.go index b7dd7b3..ea51043 100644 --- a/internal/importpaths/rewrite.go +++ b/internal/importpaths/rewrite.go @@ -165,7 +165,7 @@ type RewriteModuleOptions struct { NewVersion string NewPrefix string PkgDir string - OnRewrite func(pos token.Position, oldpath, newpath string) + OnRewrite func(pos token.Position, oldpath, newpath string) error } // RewriteModule rewrites imports of a specific module to a new version or prefix. @@ -188,7 +188,9 @@ func RewriteModule(dir string, opt RewriteModuleOptions) error { return "", ErrSkip } if opt.OnRewrite != nil { - opt.OnRewrite(pos, path, newpath) + if err := opt.OnRewrite(pos, path, newpath); err != nil { + return "", err + } } return newpath, nil }) diff --git a/main.go b/main.go index f6ed65f..8d64bac 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "go/token" "os" "os/exec" + "regexp" "runtime/debug" "golang.org/x/mod/modfile" @@ -117,13 +118,14 @@ func listcmd(args []string) error { } func getcmd(args []string) error { - var dir string + var dir, rewrite string var pre, cached, major bool fset := flag.NewFlagSet("get", flag.ExitOnError) fset.BoolVar(&pre, "pre", false, "allow non-v0 prerelease versions") fset.BoolVar(&major, "major", false, "only get newer major versions") fset.StringVar(&dir, "dir", ".", "working directory") fset.BoolVar(&cached, "cached", true, "only fetch cached content from the module proxy") + fset.StringVar(&rewrite, "rewrite", "", "exact package version to upgrade") fset.Usage = func() { fmt.Fprintln(os.Stderr, "Usage: gomajor get ") fset.PrintDefaults() @@ -162,8 +164,9 @@ func getcmd(args []string) error { err := importpaths.RewriteModule(dir, importpaths.RewriteModuleOptions{ Prefix: packages.ModPrefix(u.Module.Path), NewVersion: u.Latest.Version, - OnRewrite: func(pos token.Position, _, newpath string) { + OnRewrite: func(pos token.Position, _, newpath string) error { fmt.Printf("%s %s\n", pos, newpath) + return nil }, }) if err != nil { @@ -217,13 +220,25 @@ func getcmd(args []string) error { if err := cmd.Run(); err != nil { return err } + var rewriteRegex *regexp.Regexp + if rewrite != "" { + if rewriteRegex, err = regexp.Compile(rewrite); err != nil { + return err + } + } // rewrite imports err = importpaths.RewriteModule(dir, importpaths.RewriteModuleOptions{ PkgDir: pkgdir, Prefix: modprefix, NewVersion: version, - OnRewrite: func(pos token.Position, _, newpath string) { + OnRewrite: func(pos token.Position, oldpath, newpath string) error { + if rewriteRegex != nil && !rewriteRegex.MatchString(oldpath) { + return importpaths.ErrSkip + } + fmt.Printf("%s %s\n", pos, newpath) + + return nil }, }) if err != nil { @@ -297,8 +312,9 @@ func pathcmd(args []string) error { Prefix: oldmodprefix, NewVersion: version, NewPrefix: modprefix, - OnRewrite: func(pos token.Position, _, newpath string) { + OnRewrite: func(pos token.Position, _, newpath string) error { fmt.Printf("%s %s\n", pos, newpath) + return nil }, }) if err != nil {