From c38eb114cee916b09848ac44144c732932e6741d Mon Sep 17 00:00:00 2001 From: Paul Cody Johnston Date: Fri, 26 Jul 2024 18:45:44 -0600 Subject: [PATCH] Feature: bazel-assisted removal of wildcard imports (#119) * Initial wildcard fixer * Add gazelle:scala_fix_wildcard_imports directive * Initial implementation of incremental wildcard fixing * Add TestScalaConfigParseFixWildcardImports cases --- cmd/wildcardimportfixer/BUILD.bazel | 15 +++ cmd/wildcardimportfixer/main.go | 77 +++++++++++++++ language/scala/BUILD.bazel | 1 + language/scala/language_test.go | 1 + language/scala/scala_rule.go | 64 +++++++++---- pkg/scalaconfig/config.go | 84 +++++++++++++--- pkg/scalaconfig/config_test.go | 91 ++++++++++++++++-- pkg/wildcardimport/BUILD.bazel | 45 +++++++++ pkg/wildcardimport/bazel.go | 38 ++++++++ pkg/wildcardimport/fixer.go | 144 ++++++++++++++++++++++++++++ pkg/wildcardimport/scanner.go | 48 ++++++++++ pkg/wildcardimport/scanner_test.go | 121 +++++++++++++++++++++++ pkg/wildcardimport/text_file.go | 86 +++++++++++++++++ 13 files changed, 771 insertions(+), 44 deletions(-) create mode 100644 cmd/wildcardimportfixer/BUILD.bazel create mode 100644 cmd/wildcardimportfixer/main.go create mode 100644 pkg/wildcardimport/BUILD.bazel create mode 100644 pkg/wildcardimport/bazel.go create mode 100644 pkg/wildcardimport/fixer.go create mode 100644 pkg/wildcardimport/scanner.go create mode 100644 pkg/wildcardimport/scanner_test.go create mode 100644 pkg/wildcardimport/text_file.go diff --git a/cmd/wildcardimportfixer/BUILD.bazel b/cmd/wildcardimportfixer/BUILD.bazel new file mode 100644 index 0000000..db53416 --- /dev/null +++ b/cmd/wildcardimportfixer/BUILD.bazel @@ -0,0 +1,15 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library") + +go_library( + name = "wildcardimportfixer_lib", + srcs = ["main.go"], + importpath = "github.com/stackb/scala-gazelle/cmd/wildcardimportfixer", + visibility = ["//visibility:private"], + deps = ["//pkg/wildcardimport"], +) + +go_binary( + name = "wildcardimportfixer", + embed = [":wildcardimportfixer_lib"], + visibility = ["//visibility:public"], +) diff --git a/cmd/wildcardimportfixer/main.go b/cmd/wildcardimportfixer/main.go new file mode 100644 index 0000000..6bed3a2 --- /dev/null +++ b/cmd/wildcardimportfixer/main.go @@ -0,0 +1,77 @@ +package main + +import ( + "flag" + "fmt" + "log" + "os" + + "github.com/stackb/scala-gazelle/pkg/wildcardimport" +) + +const ( + executableName = "wildcardimportfixer" +) + +type config struct { + ruleLabel string + targetFilename string + importPrefix string + bazelExe string +} + +func main() { + log.SetPrefix(executableName + ": ") + log.SetFlags(0) // don't print timestamps + + cfg, err := parseFlags(os.Args[1:]) + if err != nil { + log.Fatal(err) + } + + if err := run(cfg); err != nil { + log.Fatalln("ERROR:", err) + } + +} + +func parseFlags(args []string) (*config, error) { + cfg := new(config) + + fs := flag.NewFlagSet(executableName, flag.ExitOnError) + fs.StringVar(&cfg.ruleLabel, "rule_label", "", "the rule label to iteratively build") + fs.StringVar(&cfg.targetFilename, "target_filename", "", "the scala file to fix") + fs.StringVar(&cfg.importPrefix, "import_prefix", "", "the scala import prefix to set") + fs.StringVar(&cfg.bazelExe, "bazel_executable", "bazel", "the path to the bazel executable") + fs.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), "usage: %s OPTIONS", executableName) + fs.PrintDefaults() + } + if err := fs.Parse(args); err != nil { + return nil, err + } + + if cfg.ruleLabel == "" { + log.Fatal("-rule_label is required") + } + + return cfg, nil +} + +func run(cfg *config) error { + + var err error + + fixer := wildcardimport.NewFixer(&wildcardimport.FixerOptions{ + BazelExecutable: cfg.bazelExe, + }) + + symbols, err := fixer.Fix(cfg.ruleLabel, cfg.targetFilename, cfg.importPrefix) + if err != nil { + return err + } + + log.Println("FIXED", cfg.targetFilename, symbols) + + return nil +} diff --git a/language/scala/BUILD.bazel b/language/scala/BUILD.bazel index 019d85b..8b6a4c5 100644 --- a/language/scala/BUILD.bazel +++ b/language/scala/BUILD.bazel @@ -43,6 +43,7 @@ go_library( "//pkg/resolver", "//pkg/scalaconfig", "//pkg/scalarule", + "//pkg/wildcardimport", "@bazel_gazelle//config:go_default_library", "@bazel_gazelle//label:go_default_library", "@bazel_gazelle//language:go_default_library", diff --git a/language/scala/language_test.go b/language/scala/language_test.go index 06e2beb..d59531b 100644 --- a/language/scala/language_test.go +++ b/language/scala/language_test.go @@ -16,6 +16,7 @@ func ExampleLanguage_KnownDirectives() { } // output: // scala_debug + // scala_fix_wildcard_imports // scala_rule // resolve_glob // resolve_conflicts diff --git a/language/scala/scala_rule.go b/language/scala/scala_rule.go index 89f7bce..e26657a 100644 --- a/language/scala/scala_rule.go +++ b/language/scala/scala_rule.go @@ -3,6 +3,8 @@ package scala import ( "fmt" "log" + "os" + "path/filepath" "sort" "strings" @@ -16,6 +18,7 @@ import ( "github.com/stackb/scala-gazelle/pkg/resolver" "github.com/stackb/scala-gazelle/pkg/scalaconfig" "github.com/stackb/scala-gazelle/pkg/scalarule" + "github.com/stackb/scala-gazelle/pkg/wildcardimport" ) const ( @@ -48,6 +51,14 @@ type scalaRule struct { exports map[string]resolve.ImportSpec } +var bazel = "bazel" + +func init() { + if bazelExe, ok := os.LookupEnv("SCALA_GAZELLE_BAZEL_EXECUTABLE"); ok { + bazel = bazelExe + } +} + func newScalaRule( ctx *scalaRuleContext, rule *sppb.Rule, @@ -135,14 +146,6 @@ func (r *scalaRule) ResolveImports(rctx *scalarule.ResolveContext) resolver.Impo if resolved, ok := sc.ResolveConflict(rctx.Rule, imports, item.imp, item.sym); ok { item.imp.Symbol = resolved } else { - if r.ctx.scalaConfig.ShouldAnnotateWildcardImports() && item.sym.Type == sppb.ImportType_PROTO_PACKAGE { - if scope, ok := r.ctx.scope.GetScope(item.imp.Imp); ok { - wildcardImport := item.imp.Src // original symbol name having underscore suffix - r.handleWildcardImport(item.imp.Source, wildcardImport, scope) - } else { - - } - } fmt.Println(resolver.SymbolConfictMessage(item.sym, item.imp, rctx.From)) } } @@ -285,6 +288,23 @@ func (r *scalaRule) fileImports(imports resolver.ImportMap, file *sppb.File) { // gather direct imports and import scopes for _, name := range file.Imports { if wimp, ok := resolver.IsWildcardImport(name); ok { + filename := filepath.Join(r.ctx.scalaConfig.Rel(), file.Filename) + if r.ctx.scalaConfig.ShouldFixWildcardImport(filename, name) { + symbolNames, err := r.fixWildcardImport(filename, wimp) + if err != nil { + log.Fatalf("fixing wildcard imports for %s (%s): %v", file.Filename, wimp, err) + } + for _, symName := range symbolNames { + fqn := wimp + "." + symName + if sym, ok := r.ctx.scope.GetSymbol(fqn); ok { + putImport(resolver.NewResolvedNameImport(sym.Name, file, fqn, sym)) + } else { + if debugUnresolved { + log.Printf("warning: unresolved fix wildcard import: symbol %q: was not found' (%s)", name, file.Filename) + } + } + } + } // collect the (package) symbol for import if sym, ok := r.ctx.scope.GetSymbol(name); ok { putImport(resolver.NewResolvedNameImport(sym.Name, file, name, sym)) @@ -377,19 +397,6 @@ func (r *scalaRule) fileImports(imports resolver.ImportMap, file *sppb.File) { } } -func (r *scalaRule) handleWildcardImport(file *sppb.File, imp string, scope resolver.Scope) { - names := make([]string, 0) - for _, name := range file.Names { - if _, ok := scope.GetSymbol(name); ok { - names = append(names, name) - } - } - if len(names) > 0 { - sort.Strings(names) - log.Printf("[%s]: import %s.{%s}", file.Filename, strings.TrimSuffix(imp, "._"), strings.Join(names, ", ")) - } -} - // Provides implements part of the scalarule.Rule interface. func (r *scalaRule) Provides() []resolve.ImportSpec { exports := make([]resolve.ImportSpec, 0, len(r.exports)) @@ -428,6 +435,21 @@ func (r *scalaRule) putExport(imp string) { r.exports[imp] = resolve.ImportSpec{Imp: imp, Lang: scalaLangName} } +func (r *scalaRule) fixWildcardImport(filename, wimp string) ([]string, error) { + fixer := wildcardimport.NewFixer(&wildcardimport.FixerOptions{ + BazelExecutable: bazel, + }) + + absFilename := filepath.Join(wildcardimport.GetBuildWorkspaceDirectory(), filename) + ruleLabel := label.New("", r.ctx.scalaConfig.Rel(), r.ctx.rule.Name()).String() + symbols, err := fixer.Fix(ruleLabel, absFilename, wimp) + if err != nil { + return nil, err + } + + return symbols, nil +} + func isBinaryRule(kind string) bool { return strings.Contains(kind, "binary") || strings.Contains(kind, "test") } diff --git a/pkg/scalaconfig/config.go b/pkg/scalaconfig/config.go index 9b4a80e..0e66bda 100644 --- a/pkg/scalaconfig/config.go +++ b/pkg/scalaconfig/config.go @@ -21,15 +21,15 @@ import ( type debugAnnotation int const ( - DebugUnknown debugAnnotation = 0 - DebugImports debugAnnotation = 1 - DebugExports debugAnnotation = 2 - DebugWildcardImports debugAnnotation = 3 - scalaLangName = "scala" + DebugUnknown debugAnnotation = 0 + DebugImports debugAnnotation = 1 + DebugExports debugAnnotation = 2 + scalaLangName = "scala" ) const ( scalaDebugDirective = "scala_debug" + scalaFixWildcardImportDirective = "scala_fix_wildcard_imports" scalaRuleDirective = "scala_rule" resolveGlobDirective = "resolve_glob" resolveConflictsDirective = "resolve_conflicts" @@ -41,6 +41,7 @@ const ( func DirectiveNames() []string { return []string{ scalaDebugDirective, + scalaFixWildcardImportDirective, scalaRuleDirective, resolveGlobDirective, resolveConflictsDirective, @@ -58,6 +59,7 @@ type Config struct { overrides []*overrideSpec implicitImports []*implicitImportSpec resolveFileSymbolNames []*resolveFileSymbolNameSpec + fixWildcardImportSpecs []*fixWildcardImportSpec rules map[string]*scalarule.Config labelNameRewrites map[string]resolver.LabelNameRewriteSpec annotations map[debugAnnotation]interface{} @@ -122,6 +124,9 @@ func (c *Config) clone(config *config.Config, rel string) *Config { if c.resolveFileSymbolNames != nil { clone.resolveFileSymbolNames = c.resolveFileSymbolNames[:] } + if c.fixWildcardImportSpecs != nil { + clone.fixWildcardImportSpecs = c.fixWildcardImportSpecs[:] + } return clone } @@ -175,6 +180,8 @@ func (c *Config) ParseDirectives(directives []rule.Directive) (err error) { if err != nil { return fmt.Errorf(`invalid directive: "gazelle:%s %s": %w`, d.Key, d.Value, err) } + case scalaFixWildcardImportDirective: + c.parseFixWildcardImport(d) case resolveGlobDirective: c.parseResolveGlobDirective(d) case resolveWithDirective: @@ -202,7 +209,7 @@ func (c *Config) parseScalaRuleDirective(d rule.Directive) error { return fmt.Errorf("expected three or more fields, got %d", len(fields)) } name, param, value := fields[0], fields[1], strings.Join(fields[2:], " ") - r, err := c.getOrCreateScalaRuleConfig(c.config, name) + r, err := c.getOrCreateScalaRuleConfig(name) if err != nil { return err } @@ -247,6 +254,24 @@ func (c *Config) parseResolveWithDirective(d rule.Directive) { }) } +func (c *Config) parseFixWildcardImport(d rule.Directive) { + parts := strings.Fields(d.Value) + if len(parts) < 2 { + log.Fatalf("invalid gazelle:%s directive: expected [FILENAME_PATTERN [+|-]IMPORT_PATTERN...], got %v", scalaFixWildcardImportDirective, parts) + return + } + filenamePattern := parts[0] + + for _, part := range parts[1:] { + intent := collections.ParseIntent(part) + c.fixWildcardImportSpecs = append(c.fixWildcardImportSpecs, &fixWildcardImportSpec{ + filenamePattern: filenamePattern, + importPattern: *intent, + }) + } + +} + func (c *Config) parseResolveFileSymbolNames(d rule.Directive) { parts := strings.Fields(d.Value) if len(parts) < 2 { @@ -319,10 +344,10 @@ func (c *Config) parseScalaAnnotation(d rule.Directive) error { return nil } -func (c *Config) getOrCreateScalaRuleConfig(config *config.Config, name string) (*scalarule.Config, error) { +func (c *Config) getOrCreateScalaRuleConfig(name string) (*scalarule.Config, error) { r, ok := c.rules[name] if !ok { - r = scalarule.NewConfig(config, name) + r = scalarule.NewConfig(c.config, name) r.Implementation = name c.rules[name] = r } @@ -357,11 +382,6 @@ func (c *Config) ConfiguredRules() []*scalarule.Config { return rules } -func (c *Config) ShouldAnnotateWildcardImports() bool { - _, ok := c.annotations[DebugWildcardImports] - return ok -} - func (c *Config) ShouldAnnotateImports() bool { _, ok := c.annotations[DebugImports] return ok @@ -372,6 +392,36 @@ func (c *Config) ShouldAnnotateExports() bool { return ok } +// ShouldFixWildcardImport tests whether the given symbol name pattern +// should be resolved within the scope of the given filename pattern. +// resolveFileSymbolNameSpecs represent a whitelist; if no patterns match, false +// is returned. +func (c *Config) ShouldFixWildcardImport(filename, wimp string) bool { + if len(c.fixWildcardImportSpecs) > 0 { + log.Println("should fix wildcard import?", filename, wimp) + } + for _, spec := range c.fixWildcardImportSpecs { + hasStarChar := strings.Contains(spec.filenamePattern, "*") + if hasStarChar { + if ok, _ := doublestar.Match(spec.filenamePattern, filename); !ok { + log.Println("should fix wildcard import? FILENAME GLOB MATCH FAILED", filename, spec.filenamePattern) + continue + } + } else { + if !strings.HasSuffix(filename, spec.filenamePattern) { + log.Println("should fix wildcard import? FILENAME SUFFIX MATCH FAILED", filename, spec.filenamePattern) + continue + } + } + if ok, _ := doublestar.Match(spec.importPattern.Value, wimp); !ok { + log.Println("should fix wildcard import? IMPORT PATTERN MATCH FAILED", filename, spec.importPattern.Value, wimp) + continue + } + return spec.importPattern.Want + } + return false +} + // ShouldResolveFileSymbolName tests whether the given symbol name pattern // should be resolved within the scope of the given filename pattern. // resolveFileSymbolNameSpecs represent a whitelist; if no patterns match, false @@ -565,10 +615,14 @@ type resolveFileSymbolNameSpec struct { symbolName collections.Intent } +type fixWildcardImportSpec struct { + // filenamePattern is the filename glob wildcard filenamePattern to test + filenamePattern string + importPattern collections.Intent +} + func parseAnnotation(val string) debugAnnotation { switch val { - case "wildcardimports": - return DebugWildcardImports case "imports": return DebugImports case "exports": diff --git a/pkg/scalaconfig/config_test.go b/pkg/scalaconfig/config_test.go index 41848a7..c4dc832 100644 --- a/pkg/scalaconfig/config_test.go +++ b/pkg/scalaconfig/config_test.go @@ -290,6 +290,89 @@ func TestScalaConfigParseResolveFileSymbolName(t *testing.T) { }) } } +func TestScalaConfigParseFixWildcardImports(t *testing.T) { + for name, tc := range map[string]struct { + directives []rule.Directive + filename string + imports []string + want []bool + wantErr error + }{ + "degenerate": { + want: []bool{}, + }, + "exact matches": { + directives: []rule.Directive{ + {Key: scalaFixWildcardImportDirective, Value: "filename.scala omnistac.core.entity._"}, + }, + filename: "filename.scala", + imports: []string{"omnistac.core.entity._"}, + want: []bool{true}, + }, + "glob matches": { + directives: []rule.Directive{ + {Key: scalaFixWildcardImportDirective, Value: "*.scala omnistac.core.entity._"}, + }, + filename: "filename.scala", + imports: []string{"omnistac.core.entity._"}, + want: []bool{true}, + }, + "recursive glob matches non-recursive path": { + directives: []rule.Directive{ + {Key: scalaFixWildcardImportDirective, Value: "**/*.scala omnistac.core.entity._"}, + }, + filename: "filename.scala", + imports: []string{"omnistac.core.entity._"}, + want: []bool{true}, + }, + "recursive glob matches": { + directives: []rule.Directive{ + {Key: scalaFixWildcardImportDirective, Value: "**/*.scala omnistac.core.entity._"}, + }, + filename: "path/to/filename.scala", + imports: []string{"omnistac.core.entity._"}, + want: []bool{true}, + }, + "recursive glob matches only absolute path (absolute version)": { + directives: []rule.Directive{ + {Key: scalaFixWildcardImportDirective, Value: "/**/*.scala omnistac.core.entity._"}, + }, + filename: "path/to/filename.scala", + imports: []string{"omnistac.core.entity._"}, + want: []bool{false}, + }, + "recursive glob matches absolute path (absolute version)": { + directives: []rule.Directive{ + {Key: scalaFixWildcardImportDirective, Value: "/**/*.scala omnistac.core.entity._"}, + }, + filename: "/path/to/filename.scala", + imports: []string{"omnistac.core.entity._"}, + want: []bool{true}, + }, + "no match": { + directives: []rule.Directive{ + {Key: scalaFixWildcardImportDirective, Value: "*.scala -omnistac.core.entity._"}, + }, + filename: "filename.scala", + imports: []string{"omnistac.core.entity._"}, + want: []bool{false}, + }, + } { + t.Run(name, func(t *testing.T) { + sc, err := NewTestScalaConfig(t, mocks.NewUniverse(t), "", tc.directives...) + if testutil.ExpectError(t, tc.wantErr, err) { + return + } + got := []bool{} + for _, imp := range tc.imports { + got = append(got, sc.ShouldFixWildcardImport(tc.filename, imp)) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("(-want +got):\n%s", diff) + } + }) + } +} func TestScalaConfigParseScalaAnnotate(t *testing.T) { for name, tc := range map[string]struct { @@ -316,14 +399,6 @@ func TestScalaConfigParseScalaAnnotate(t *testing.T) { DebugExports: nil, }, }, - "wildcards": { - directives: []rule.Directive{ - {Key: scalaDebugDirective, Value: "wildcardimports"}, - }, - want: map[debugAnnotation]interface{}{ - DebugWildcardImports: nil, - }, - }, } { t.Run(name, func(t *testing.T) { sc, err := NewTestScalaConfig(t, mocks.NewUniverse(t), "", tc.directives...) diff --git a/pkg/wildcardimport/BUILD.bazel b/pkg/wildcardimport/BUILD.bazel new file mode 100644 index 0000000..be916de --- /dev/null +++ b/pkg/wildcardimport/BUILD.bazel @@ -0,0 +1,45 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "wildcardimport", + srcs = [ + "bazel.go", + "fixer.go", + "scanner.go", + "text_file.go", + ], + importpath = "github.com/stackb/scala-gazelle/pkg/wildcardimport", + visibility = ["//visibility:public"], +) + +go_test( + name = "wildcardimport_test", + srcs = ["scanner_test.go"], + embed = [":wildcardimport"], + deps = ["@com_github_google_go_cmp//cmp"], +) + +# Maybe put this back in after finding conflicts? +# +# +# // if r.ctx.scalaConfig.ShouldAnnotateWildcardImports() && item.sym.Type == sppb.ImportType_PROTO_PACKAGE { +# // if scope, ok := r.ctx.scope.GetScope(item.imp.Imp); ok { +# // wildcardImport := item.imp.Src // original symbol name having underscore suffix +# // r.handleWildcardImport(item.imp.Source, wildcardImport, scope) +# // } else { + +# // } +# // } + +# func (r *scalaRule) handleWildcardImport(file *sppb.File, imp string, scope resolver.Scope) { +# names := make([]string, 0) +# for _, name := range file.Names { +# if _, ok := scope.GetSymbol(name); ok { +# names = append(names, name) +# } +# } +# if len(names) > 0 { +# sort.Strings(names) +# log.Printf("[%s]: import %s.{%s}", file.Filename, strings.TrimSuffix(imp, "._"), strings.Join(names, ", ")) +# } +# } diff --git a/pkg/wildcardimport/bazel.go b/pkg/wildcardimport/bazel.go new file mode 100644 index 0000000..9ff81d6 --- /dev/null +++ b/pkg/wildcardimport/bazel.go @@ -0,0 +1,38 @@ +package wildcardimport + +import ( + "log" + "os" + "os/exec" + "syscall" +) + +func execBazelBuild(bazelExe string, label string) ([]byte, int, error) { + args := []string{"build", label} + + command := exec.Command(bazelExe, args...) + command.Dir = GetBuildWorkspaceDirectory() + + log.Println("!!!", command.String()) + output, err := command.CombinedOutput() + if err != nil { + log.Println("cmdErr:", err) + // Check for exit errors specifically + if exitError, ok := err.(*exec.ExitError); ok { + waitStatus := exitError.Sys().(syscall.WaitStatus) + exitCode := waitStatus.ExitStatus() + return output, exitCode, err + } else { + return output, -1, err + } + } + return output, 0, nil +} + +func GetBuildWorkspaceDirectory() string { + if bwd, ok := os.LookupEnv("BUILD_WORKSPACE_DIRECTORY"); ok { + return bwd + } else { + return "." + } +} diff --git a/pkg/wildcardimport/fixer.go b/pkg/wildcardimport/fixer.go new file mode 100644 index 0000000..0ff5530 --- /dev/null +++ b/pkg/wildcardimport/fixer.go @@ -0,0 +1,144 @@ +package wildcardimport + +import ( + "fmt" + "log" + "sort" + "strings" +) + +const debug = true + +type FixerOptions struct { + BazelExecutable string +} + +type Fixer struct { + bazelExe string +} + +func NewFixer(options *FixerOptions) *Fixer { + bazelExe := options.BazelExecutable + if bazelExe == "" { + bazelExe = "bazel" + } + + return &Fixer{ + bazelExe: bazelExe, + } +} + +// Fix uses iterative bazel builds to remove wildcard imports and returns a list +// of unqualified symbols that were used to complete the import. +func (w *Fixer) Fix(ruleLabel, filename, importPrefix string) ([]string, error) { + targetLine := fmt.Sprintf("import %s._", importPrefix) + + tf, err := NewTextFileFromFilename(filename, targetLine) + if err != nil { + return nil, err + } + + symbols, err := w.fixFile(ruleLabel, tf, importPrefix) + if err != nil { + return nil, err + } + + return symbols, nil +} + +func (w *Fixer) fixFile(ruleLabel string, tf *TextFile, importPrefix string) ([]string, error) { + + // the complete list of not found symbols + completion := map[string]bool{} + + // initialize the scanner + scanner := &outputScanner{} + + var iteration int + for { + if iteration == 0 { + // rewrite the file clean on the first iteration, in case the + // previous run edited it. + if err := tf.WriteOriginal(); err != nil { + return nil, err + } + } else if iteration == 1 { + // comment out the target line on the 2nd iteration + if err := tf.WriteCommented(); err != nil { + return nil, err + } + } + + // execute the build and gather output + output, exitCode, cmdErr := execBazelBuild(w.bazelExe, ruleLabel) + + // must build clean first time + if iteration == 0 { + if exitCode != 0 { + return nil, fmt.Errorf("%v: target must build first time: %v (%v)", ruleLabel, string(output), cmdErr) + } else { + iteration++ + continue + } + } + + // on subsequent iterations if the exitCode is 0, the process is + // successful. + if exitCode == 0 { + keys := mapKeys(completion) + return keys, nil + } + + if debug { + log.Printf(">>> fixing %s [%s] (iteration %d)\n", tf.filename, importPrefix, iteration) + log.Println(">>>", string(output), cmdErr) + } + + // scan the output for symbols that were not found + symbols, err := scanner.scan(output) + if err != nil { + return nil, fmt.Errorf("scanning output: %w", err) + } + + if debug { + log.Printf("iteration %d symbols: %v", iteration, symbols) + } + + var hasNewResult bool + for _, sym := range symbols { + if _, ok := completion[sym]; !ok { + completion[sym] = true + hasNewResult = true + } + } + + // if no notFound symbols were found, the process failed, but we have + // nothing actionable. + if !hasNewResult { + return nil, fmt.Errorf("expand wildcard failed: final set of notFound symbols: %v", mapKeys(completion)) + } + + // rewrite the file with the updated import (and continue) + if err := tf.Write(makeImportLine(importPrefix, mapKeys(completion))); err != nil { + return nil, fmt.Errorf("failed to write split file: %v", err) + } + + iteration++ + } +} + +func makeImportLine(importPrefix string, symbols []string) string { + return fmt.Sprintf("import %s.{%s}", importPrefix, strings.Join(symbols, ", ")) +} + +// mapKeys sorts the list of map keys +func mapKeys(in map[string]bool) (out []string) { + if len(in) == 0 { + return nil + } + for k := range in { + out = append(out, k) + } + sort.Strings(out) + return +} diff --git a/pkg/wildcardimport/scanner.go b/pkg/wildcardimport/scanner.go new file mode 100644 index 0000000..a7297f1 --- /dev/null +++ b/pkg/wildcardimport/scanner.go @@ -0,0 +1,48 @@ +package wildcardimport + +import ( + "bufio" + "bytes" + "log" + "regexp" + "sort" + "strings" +) + +// omnistac/gum/testutils/DbDataInitUtils.scala:98: error: [rewritten by -quickfix] not found: value FixSessionDao +var notFoundLine = regexp.MustCompile(`^(.*):\d+: error: .*not found: (value|type) ([A-Z].*)$`) + +type outputScanner struct { + debug bool +} + +func (s *outputScanner) scan(output []byte) ([]string, error) { + notFound := make(map[string]bool) + + scanner := bufio.NewScanner(bytes.NewReader(output)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + if s.debug { + log.Println("line:", line) + } + if match := notFoundLine.FindStringSubmatch(line); match != nil { + typeOrValue := match[3] + notFound[typeOrValue] = true + continue + } + } + if err := scanner.Err(); err != nil { + return nil, err + } + + list := make([]string, 0, len(notFound)) + for k := range notFound { + list = append(list, k) + } + sort.Strings(list) + + return list, nil +} diff --git a/pkg/wildcardimport/scanner_test.go b/pkg/wildcardimport/scanner_test.go new file mode 100644 index 0000000..7b20920 --- /dev/null +++ b/pkg/wildcardimport/scanner_test.go @@ -0,0 +1,121 @@ +package wildcardimport + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestOutputScanner(t *testing.T) { + for name, tc := range map[string]struct { + output string + wantErr string + want []string + }{ + "degenerate": { + want: []string{}, + }, + "example 1": { + output: ` +ERROR: /Users/pcj/go/src/github.com/Omnistac/unity/omnistac/gum/dao/BUILD.bazel:48:21: scala @//omnistac/gum/dao:auth_dao_scala failed: (Exit 1): scalac failed: error executing command (from target //omnistac/gum/dao:auth_dao_scala) bazel-out/darwin_arm64-opt-exec-2B5CBBC6/bin/external/io_bazel_rules_scala/src/java/io/bazel/rulesscala/scalac/scalac '--jvm_flag=-Xss32M' ... (remaining 1 argument skipped) +omnistac/gum/dao/AuthDao.scala:15: error: [rewritten by -quickfix] not found: type ZonedDateTime + passwordLastUpdatedTimestamp: ZonedDateTime = ZonedDateTime.now(DateUtils.SYSTEM_TZ), + ^ +omnistac/gum/dao/AuthDao.scala:15: error: [rewritten by -quickfix] not found: value ZonedDateTime + passwordLastUpdatedTimestamp: ZonedDateTime = ZonedDateTime.now(DateUtils.SYSTEM_TZ), + ^ +omnistac/gum/dao/AuthDao.scala:32: error: [rewritten by -quickfix] not found: type ZonedDateTime + def putNewPasswordToken(userId: ActorId, token: String, expirationTs: ZonedDateTime): Future[ResponseStatus] + ^ +3 errors +Build failed +java.lang.RuntimeException: Build failed + at io.bazel.rulesscala.scalac.ScalacWorker.compileScalaSources(ScalacWorker.java:324) + at io.bazel.rulesscala.scalac.ScalacWorker.work(ScalacWorker.java:72) + at io.bazel.rulesscala.worker.Worker.persistentWorkerMain(Worker.java:86) + at io.bazel.rulesscala.worker.Worker.workerMain(Worker.java:39) + at io.bazel.rulesscala.scalac.ScalacWorker.main(ScalacWorker.java:36) +Target //omnistac/gum/dao:auth_dao_scala failed to build +Use --verbose_failures to see the command lines of failed build steps. +INFO: Elapsed time: 2.354s, Critical Path: 1.40s +INFO: 2 processes: 2 internal. +FAILED: Build did NOT complete successfully +`, + want: []string{ + "ZonedDateTime", + }, + }, + "example 2": { + output: ` +ERROR: /Users/pcj/go/src/github.com/Omnistac/unity/omnistac/euds/common/masking/BUILD.bazel:3:21: scala @//omnistac/euds/common/masking:scala failed: (Exit 1): scalac failed: error executing command (from target //omnistac/euds/common/masking:scala) bazel-out/darwin_arm64-opt-exec-2B5CBBC6/bin/external/io_bazel_rules_scala/src/java/io/bazel/rulesscala/scalac/scalac '--jvm_flag=-Xss32M' ... (remaining 1 argument skipped) +omnistac/euds/common/masking/MaskingFacade.scala:11: error: [rewritten by -quickfix] not found: type UserContext + def maskBlotterOrderEvent(event: BlotterOrderEvent, userContext: Option[UserContext]): Option[BlotterOrderEvent] = { + ^ +omnistac/euds/common/masking/MaskingFacade.scala:70: error: [rewritten by -quickfix] not found: type UserContext + userContext: Option[UserContext] + ^ +omnistac/euds/common/masking/MaskingFacade.scala:23: error: [rewritten by -quickfix] not found: type UserContext + userContext: Option[UserContext] + ^ +omnistac/euds/common/masking/MaskingFacade.scala:26: error: [rewritten by -quickfix] not found: type TrumidUserContext + def ignoreTrumidUser(trumidUser: TrumidUserContext): Boolean = + ^ +omnistac/euds/common/masking/MaskingFacade.scala:33: error: [rewritten by -quickfix] not found: type TrumidUserContext + case trumidUser: TrumidUserContext if !ignoreTrumidUser(trumidUser) => + ^ +omnistac/euds/common/masking/MaskingFacade.scala:61: error: [rewritten by -quickfix] not found: type CounterpartyUserContext + case cptyUser: CounterpartyUserContext if orderEvent.getCptyId == cptyUser.counterpartyId => orderEvent + ^ +omnistac/euds/common/masking/MaskingFacade.scala:62: error: [rewritten by -quickfix] not found: type TradingFirmUserContext + case tradingFirmUser: TradingFirmUserContext if orderEvent.getFirmId == tradingFirmUser.firmId => orderEvent + ^ +omnistac/euds/common/masking/MaskingFacade.scala:63: error: [rewritten by -quickfix] not found: type TradingAccountUserContext + case tradingAccountUser: TradingAccountUserContext if orderEvent.getAccountId == tradingAccountUser.accountId => + ^ +omnistac/euds/common/masking/MaskingFacade.scala:73: error: [rewritten by -quickfix] not found: type TrumidUserContext + case trumidUser: TrumidUserContext => + ^ +omnistac/euds/common/masking/MaskingFacade.scala:107: error: [rewritten by -quickfix] not found: type CounterpartyUserContext + case cptyUser: CounterpartyUserContext if stagedIoiEvent.getCptyId == cptyUser.counterpartyId => stagedIoiEvent + ^ +omnistac/euds/common/masking/MaskingFacade.scala:109: error: [rewritten by -quickfix] not found: type TradingFirmUserContext + case tradingFirmUser: TradingFirmUserContext if stagedIoiEvent.getFirmId == tradingFirmUser.firmId => + ^ +omnistac/euds/common/masking/MaskingFacade.scala:112: error: [rewritten by -quickfix] not found: type TradingAccountUserContext + case tradingAccountUser: TradingAccountUserContext + ^ +12 errors +Build failed +java.lang.RuntimeException: Build failed + at io.bazel.rulesscala.scalac.ScalacInvoker.invokeCompiler(ScalacInvoker.java:55) + at io.bazel.rulesscala.scalac.ScalacWorker.compileScalaSources(ScalacWorker.java:253) + at io.bazel.rulesscala.scalac.ScalacWorker.work(ScalacWorker.java:69) + at io.bazel.rulesscala.worker.Worker.persistentWorkerMain(Worker.java:86) + at io.bazel.rulesscala.worker.Worker.workerMain(Worker.java:39) + at io.bazel.rulesscala.scalac.ScalacWorker.main(ScalacWorker.java:33) +`, + want: []string{ + "CounterpartyUserContext", + "TradingAccountUserContext", + "TradingFirmUserContext", + "TrumidUserContext", + "UserContext", + }, + }, + } { + t.Run(name, func(t *testing.T) { + scanner := &outputScanner{} + got, err := scanner.scan([]byte(tc.output)) + var gotErr string + if err != nil { + gotErr = err.Error() + } + if diff := cmp.Diff(tc.wantErr, gotErr); diff != "" { + t.Errorf("error (-want +got):\n%s", diff) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/wildcardimport/text_file.go b/pkg/wildcardimport/text_file.go new file mode 100644 index 0000000..6ff191c --- /dev/null +++ b/pkg/wildcardimport/text_file.go @@ -0,0 +1,86 @@ +package wildcardimport + +import ( + "bufio" + "fmt" + "io" + "io/fs" + "os" + "strings" +) + +type TextFile struct { + filename string + info fs.FileInfo + + beforeLines []string + targetLine string + afterLines []string +} + +// NewTextFileFromFilename constructs a new text file split on the target line. +func NewTextFileFromFilename(filename string, targetLine string) (*TextFile, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + info, err := f.Stat() + if err != nil { + return nil, err + } + return NewTextFileFromReader(filename, info, f, targetLine) +} + +// NewTextFileFromFilename constructs a new text file split on the target line. +func NewTextFileFromReader(filename string, info fs.FileInfo, in io.Reader, targetLine string) (*TextFile, error) { + + file := new(TextFile) + file.filename = filename + file.info = info + + scanner := bufio.NewScanner(in) + for scanner.Scan() { + line := scanner.Text() + if line == targetLine { + file.targetLine = targetLine + continue + } + if line == "// "+targetLine { // already commented out (subsequent run) + file.targetLine = targetLine + continue + } + if file.targetLine == "" { + file.beforeLines = append(file.beforeLines, line) + } else { + file.afterLines = append(file.afterLines, line) + } + } + if file.targetLine == "" { + return nil, fmt.Errorf("%s: import target line not found: %q", filename, targetLine) + } + + // add a final entry to afterLines so that the file ends with a single newline + file.afterLines = append(file.afterLines, "") + + return file, nil +} + +func (f *TextFile) WriteOriginal() error { + return f.Write(f.targetLine) +} + +func (f *TextFile) WriteCommented() error { + return f.Write("// " + f.targetLine) +} + +func (f *TextFile) Write(targetLine string) error { + lines := append(f.beforeLines, targetLine) + lines = append(lines, f.afterLines...) + content := strings.Join(lines, "\n") + data := []byte(content) + if err := os.WriteFile(f.filename, data, f.info.Mode()); err != nil { + return err + } + return nil +}