diff --git a/major/major.go b/major/major.go index 9e12754..29db4e8 100644 --- a/major/major.go +++ b/major/major.go @@ -31,12 +31,13 @@ func Run(dir, op, modName string, tag int, buildFlags []string) error { client = false modName = modFile.Module.Mod.Path } + separator := getSeparator(modName) var newModPath string switch op { case "upgrade": - newModPath = getNext(tag, modName) + newModPath = getNext(tag, modName, separator) case "downgrade": - newModPath = getPrevious(modName) + newModPath = getPrevious(modName, separator) } c := &packages.Config{Mode: packages.LoadSyntax, Tests: true, Dir: dir, BuildFlags: buildFlags} pkgs, err := packages.Load(c, "./...") @@ -50,7 +51,7 @@ func Run(dir, op, modName string, tag int, buildFlags []string) error { continue } ids[p.ID] = struct{}{} - err = updateImportPath(p, modName, newModPath, files) + err = updateImportPath(p, modName, newModPath, separator, files) if err != nil { return err } @@ -72,6 +73,17 @@ func Run(dir, op, modName string, tag int, buildFlags []string) error { return nil } +func isGopkgin(s string) bool { + return strings.HasPrefix(s, "gopkg.in/") +} + +func getSeparator(s string) string { + if isGopkgin(s) { + return "." + } + return "/" +} + func getOperation() string { args := flag.Args() if len(args) != 1 { @@ -86,21 +98,21 @@ func getOperation() string { return op } -func getNext(tagNum int, s string) string { - ss := strings.Split(s, "/") +func getNext(tagNum int, s, sep string) string { + ss := strings.Split(s, sep) num, isMajor := versionSuffix(ss) if !isMajor { if tagNum != 0 { - return s + "/v" + strconv.Itoa(tagNum) + return s + sep + "v" + strconv.Itoa(tagNum) } - return s + "/v2" + return s + sep + "v2" } newV := num + 1 if tagNum != 0 { newV = tagNum } - return strings.Join(ss[:len(ss)-1], "/") + "/v" + strconv.Itoa(newV) + return strings.Join(ss[:len(ss)-1], sep) + sep + "v" + strconv.Itoa(newV) } func versionSuffix(ss []string) (int, bool) { @@ -118,22 +130,22 @@ func versionSuffix(ss []string) (int, bool) { return num, true } -func getPrevious(s string) string { - ss := strings.Split(s, "/") +func getPrevious(s, sep string) string { + ss := strings.Split(s, sep) num, isMajor := versionSuffix(ss) - if !isMajor { + if !isMajor || isGopkgin(s) && num == 0 { return s } - if num == 2 { - return strings.Join(ss[:len(ss)-1], "/") + if num == 2 && !isGopkgin(s) { + return strings.Join(ss[:len(ss)-1], sep) } newV := num - 1 - return strings.Join(ss[:len(ss)-1], "/") + "/v" + strconv.Itoa(newV) + return strings.Join(ss[:len(ss)-1], sep) + sep + "v" + strconv.Itoa(newV) } -func updateImportPath(p *packages.Package, old, new string, files map[string]struct{}) error { +func updateImportPath(p *packages.Package, old, new, sep string, files map[string]struct{}) error { for _, syn := range p.Syntax { goFileName := p.Fset.File(syn.Pos()).Name() if _, ok := files[goFileName]; ok { @@ -143,7 +155,7 @@ func updateImportPath(p *packages.Package, old, new string, files map[string]str var rewritten bool for _, i := range syn.Imports { imp := strings.Replace(i.Path.Value, `"`, ``, 2) - if strings.HasPrefix(imp, fmt.Sprintf("%s/", old)) || imp == old { + if strings.HasPrefix(imp, fmt.Sprintf("%s%s", old, sep)) || imp == old { newImp := strings.Replace(imp, old, new, 1) rewrote := astutil.RewriteImport(p.Fset, syn, imp, newImp) if rewrote { diff --git a/major/major_test.go b/major/major_test.go index 6096d22..b4448a8 100644 --- a/major/major_test.go +++ b/major/major_test.go @@ -1,6 +1,8 @@ package major -import "testing" +import ( + "testing" +) var upDownCases = []struct { name string @@ -32,16 +34,35 @@ var upDownCases = []struct { "mod/sub/v100", "mod/sub/v98", }, + { + "gopkg.in v0", + "gopkg.in/yaml.v0", + "gopkg.in/yaml.v1", + "gopkg.in/yaml.v0", + }, + { + "gopkg.in v1", + "gopkg.in/yaml.v1", + "gopkg.in/yaml.v2", + "gopkg.in/yaml.v0", + }, + { + "gopkg.in v2", + "gopkg.in/yaml.v2", + "gopkg.in/yaml.v3", + "gopkg.in/yaml.v1", + }, } func TestGetNext(t *testing.T) { for _, tc := range upDownCases { t.Run(tc.name, func(t *testing.T) { - next := getNext(0, tc.input) + sep := getSeparator(tc.input) + next := getNext(0, tc.input, sep) if next != tc.next { t.Fatalf("expected getNext to return %v but got %v", tc.next, next) } - prev := getPrevious(tc.input) + prev := getPrevious(tc.input, sep) if prev != tc.prev { t.Fatalf("expected getPrevious to return %v but got %v", tc.prev, prev) }