Skip to content

Commit

Permalink
Feature: bazel-assisted removal of wildcard imports (#119)
Browse files Browse the repository at this point in the history
* Initial wildcard fixer
* Add gazelle:scala_fix_wildcard_imports directive
* Initial implementation of incremental wildcard fixing
* Add TestScalaConfigParseFixWildcardImports cases
  • Loading branch information
pcj authored Jul 27, 2024
1 parent fb1909a commit c38eb11
Show file tree
Hide file tree
Showing 13 changed files with 771 additions and 44 deletions.
15 changes: 15 additions & 0 deletions cmd/wildcardimportfixer/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")

go_library(
name = "wildcardimportfixer_lib",
srcs = ["main.go"],
importpath = "github.com/stackb/scala-gazelle/cmd/wildcardimportfixer",
visibility = ["//visibility:private"],
deps = ["//pkg/wildcardimport"],
)

go_binary(
name = "wildcardimportfixer",
embed = [":wildcardimportfixer_lib"],
visibility = ["//visibility:public"],
)
77 changes: 77 additions & 0 deletions cmd/wildcardimportfixer/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package main

import (
"flag"
"fmt"
"log"
"os"

"github.com/stackb/scala-gazelle/pkg/wildcardimport"
)

const (
executableName = "wildcardimportfixer"
)

type config struct {
ruleLabel string
targetFilename string
importPrefix string
bazelExe string
}

func main() {
log.SetPrefix(executableName + ": ")
log.SetFlags(0) // don't print timestamps

cfg, err := parseFlags(os.Args[1:])
if err != nil {
log.Fatal(err)
}

if err := run(cfg); err != nil {
log.Fatalln("ERROR:", err)
}

}

func parseFlags(args []string) (*config, error) {
cfg := new(config)

fs := flag.NewFlagSet(executableName, flag.ExitOnError)
fs.StringVar(&cfg.ruleLabel, "rule_label", "", "the rule label to iteratively build")
fs.StringVar(&cfg.targetFilename, "target_filename", "", "the scala file to fix")
fs.StringVar(&cfg.importPrefix, "import_prefix", "", "the scala import prefix to set")
fs.StringVar(&cfg.bazelExe, "bazel_executable", "bazel", "the path to the bazel executable")
fs.Usage = func() {
fmt.Fprintf(flag.CommandLine.Output(), "usage: %s OPTIONS", executableName)
fs.PrintDefaults()
}
if err := fs.Parse(args); err != nil {
return nil, err
}

if cfg.ruleLabel == "" {
log.Fatal("-rule_label is required")
}

return cfg, nil
}

func run(cfg *config) error {

var err error

fixer := wildcardimport.NewFixer(&wildcardimport.FixerOptions{
BazelExecutable: cfg.bazelExe,
})

symbols, err := fixer.Fix(cfg.ruleLabel, cfg.targetFilename, cfg.importPrefix)
if err != nil {
return err
}

log.Println("FIXED", cfg.targetFilename, symbols)

return nil
}
1 change: 1 addition & 0 deletions language/scala/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ go_library(
"//pkg/resolver",
"//pkg/scalaconfig",
"//pkg/scalarule",
"//pkg/wildcardimport",
"@bazel_gazelle//config:go_default_library",
"@bazel_gazelle//label:go_default_library",
"@bazel_gazelle//language:go_default_library",
Expand Down
1 change: 1 addition & 0 deletions language/scala/language_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ func ExampleLanguage_KnownDirectives() {
}
// output:
// scala_debug
// scala_fix_wildcard_imports
// scala_rule
// resolve_glob
// resolve_conflicts
Expand Down
64 changes: 43 additions & 21 deletions language/scala/scala_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package scala
import (
"fmt"
"log"
"os"
"path/filepath"
"sort"
"strings"

Expand All @@ -16,6 +18,7 @@ import (
"github.com/stackb/scala-gazelle/pkg/resolver"
"github.com/stackb/scala-gazelle/pkg/scalaconfig"
"github.com/stackb/scala-gazelle/pkg/scalarule"
"github.com/stackb/scala-gazelle/pkg/wildcardimport"
)

const (
Expand Down Expand Up @@ -48,6 +51,14 @@ type scalaRule struct {
exports map[string]resolve.ImportSpec
}

var bazel = "bazel"

func init() {
if bazelExe, ok := os.LookupEnv("SCALA_GAZELLE_BAZEL_EXECUTABLE"); ok {
bazel = bazelExe
}
}

func newScalaRule(
ctx *scalaRuleContext,
rule *sppb.Rule,
Expand Down Expand Up @@ -135,14 +146,6 @@ func (r *scalaRule) ResolveImports(rctx *scalarule.ResolveContext) resolver.Impo
if resolved, ok := sc.ResolveConflict(rctx.Rule, imports, item.imp, item.sym); ok {
item.imp.Symbol = resolved
} else {
if r.ctx.scalaConfig.ShouldAnnotateWildcardImports() && item.sym.Type == sppb.ImportType_PROTO_PACKAGE {
if scope, ok := r.ctx.scope.GetScope(item.imp.Imp); ok {
wildcardImport := item.imp.Src // original symbol name having underscore suffix
r.handleWildcardImport(item.imp.Source, wildcardImport, scope)
} else {

}
}
fmt.Println(resolver.SymbolConfictMessage(item.sym, item.imp, rctx.From))
}
}
Expand Down Expand Up @@ -285,6 +288,23 @@ func (r *scalaRule) fileImports(imports resolver.ImportMap, file *sppb.File) {
// gather direct imports and import scopes
for _, name := range file.Imports {
if wimp, ok := resolver.IsWildcardImport(name); ok {
filename := filepath.Join(r.ctx.scalaConfig.Rel(), file.Filename)
if r.ctx.scalaConfig.ShouldFixWildcardImport(filename, name) {
symbolNames, err := r.fixWildcardImport(filename, wimp)
if err != nil {
log.Fatalf("fixing wildcard imports for %s (%s): %v", file.Filename, wimp, err)
}
for _, symName := range symbolNames {
fqn := wimp + "." + symName
if sym, ok := r.ctx.scope.GetSymbol(fqn); ok {
putImport(resolver.NewResolvedNameImport(sym.Name, file, fqn, sym))
} else {
if debugUnresolved {
log.Printf("warning: unresolved fix wildcard import: symbol %q: was not found' (%s)", name, file.Filename)
}
}
}
}
// collect the (package) symbol for import
if sym, ok := r.ctx.scope.GetSymbol(name); ok {
putImport(resolver.NewResolvedNameImport(sym.Name, file, name, sym))
Expand Down Expand Up @@ -377,19 +397,6 @@ func (r *scalaRule) fileImports(imports resolver.ImportMap, file *sppb.File) {
}
}

func (r *scalaRule) handleWildcardImport(file *sppb.File, imp string, scope resolver.Scope) {
names := make([]string, 0)
for _, name := range file.Names {
if _, ok := scope.GetSymbol(name); ok {
names = append(names, name)
}
}
if len(names) > 0 {
sort.Strings(names)
log.Printf("[%s]: import %s.{%s}", file.Filename, strings.TrimSuffix(imp, "._"), strings.Join(names, ", "))
}
}

// Provides implements part of the scalarule.Rule interface.
func (r *scalaRule) Provides() []resolve.ImportSpec {
exports := make([]resolve.ImportSpec, 0, len(r.exports))
Expand Down Expand Up @@ -428,6 +435,21 @@ func (r *scalaRule) putExport(imp string) {
r.exports[imp] = resolve.ImportSpec{Imp: imp, Lang: scalaLangName}
}

func (r *scalaRule) fixWildcardImport(filename, wimp string) ([]string, error) {
fixer := wildcardimport.NewFixer(&wildcardimport.FixerOptions{
BazelExecutable: bazel,
})

absFilename := filepath.Join(wildcardimport.GetBuildWorkspaceDirectory(), filename)
ruleLabel := label.New("", r.ctx.scalaConfig.Rel(), r.ctx.rule.Name()).String()
symbols, err := fixer.Fix(ruleLabel, absFilename, wimp)
if err != nil {
return nil, err
}

return symbols, nil
}

func isBinaryRule(kind string) bool {
return strings.Contains(kind, "binary") || strings.Contains(kind, "test")
}
Expand Down
84 changes: 69 additions & 15 deletions pkg/scalaconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ import (
type debugAnnotation int

const (
DebugUnknown debugAnnotation = 0
DebugImports debugAnnotation = 1
DebugExports debugAnnotation = 2
DebugWildcardImports debugAnnotation = 3
scalaLangName = "scala"
DebugUnknown debugAnnotation = 0
DebugImports debugAnnotation = 1
DebugExports debugAnnotation = 2
scalaLangName = "scala"
)

const (
scalaDebugDirective = "scala_debug"
scalaFixWildcardImportDirective = "scala_fix_wildcard_imports"
scalaRuleDirective = "scala_rule"
resolveGlobDirective = "resolve_glob"
resolveConflictsDirective = "resolve_conflicts"
Expand All @@ -41,6 +41,7 @@ const (
func DirectiveNames() []string {
return []string{
scalaDebugDirective,
scalaFixWildcardImportDirective,
scalaRuleDirective,
resolveGlobDirective,
resolveConflictsDirective,
Expand All @@ -58,6 +59,7 @@ type Config struct {
overrides []*overrideSpec
implicitImports []*implicitImportSpec
resolveFileSymbolNames []*resolveFileSymbolNameSpec
fixWildcardImportSpecs []*fixWildcardImportSpec
rules map[string]*scalarule.Config
labelNameRewrites map[string]resolver.LabelNameRewriteSpec
annotations map[debugAnnotation]interface{}
Expand Down Expand Up @@ -122,6 +124,9 @@ func (c *Config) clone(config *config.Config, rel string) *Config {
if c.resolveFileSymbolNames != nil {
clone.resolveFileSymbolNames = c.resolveFileSymbolNames[:]
}
if c.fixWildcardImportSpecs != nil {
clone.fixWildcardImportSpecs = c.fixWildcardImportSpecs[:]
}
return clone
}

Expand Down Expand Up @@ -175,6 +180,8 @@ func (c *Config) ParseDirectives(directives []rule.Directive) (err error) {
if err != nil {
return fmt.Errorf(`invalid directive: "gazelle:%s %s": %w`, d.Key, d.Value, err)
}
case scalaFixWildcardImportDirective:
c.parseFixWildcardImport(d)
case resolveGlobDirective:
c.parseResolveGlobDirective(d)
case resolveWithDirective:
Expand Down Expand Up @@ -202,7 +209,7 @@ func (c *Config) parseScalaRuleDirective(d rule.Directive) error {
return fmt.Errorf("expected three or more fields, got %d", len(fields))
}
name, param, value := fields[0], fields[1], strings.Join(fields[2:], " ")
r, err := c.getOrCreateScalaRuleConfig(c.config, name)
r, err := c.getOrCreateScalaRuleConfig(name)
if err != nil {
return err
}
Expand Down Expand Up @@ -247,6 +254,24 @@ func (c *Config) parseResolveWithDirective(d rule.Directive) {
})
}

func (c *Config) parseFixWildcardImport(d rule.Directive) {
parts := strings.Fields(d.Value)
if len(parts) < 2 {
log.Fatalf("invalid gazelle:%s directive: expected [FILENAME_PATTERN [+|-]IMPORT_PATTERN...], got %v", scalaFixWildcardImportDirective, parts)
return
}
filenamePattern := parts[0]

for _, part := range parts[1:] {
intent := collections.ParseIntent(part)
c.fixWildcardImportSpecs = append(c.fixWildcardImportSpecs, &fixWildcardImportSpec{
filenamePattern: filenamePattern,
importPattern: *intent,
})
}

}

func (c *Config) parseResolveFileSymbolNames(d rule.Directive) {
parts := strings.Fields(d.Value)
if len(parts) < 2 {
Expand Down Expand Up @@ -319,10 +344,10 @@ func (c *Config) parseScalaAnnotation(d rule.Directive) error {
return nil
}

func (c *Config) getOrCreateScalaRuleConfig(config *config.Config, name string) (*scalarule.Config, error) {
func (c *Config) getOrCreateScalaRuleConfig(name string) (*scalarule.Config, error) {
r, ok := c.rules[name]
if !ok {
r = scalarule.NewConfig(config, name)
r = scalarule.NewConfig(c.config, name)
r.Implementation = name
c.rules[name] = r
}
Expand Down Expand Up @@ -357,11 +382,6 @@ func (c *Config) ConfiguredRules() []*scalarule.Config {
return rules
}

func (c *Config) ShouldAnnotateWildcardImports() bool {
_, ok := c.annotations[DebugWildcardImports]
return ok
}

func (c *Config) ShouldAnnotateImports() bool {
_, ok := c.annotations[DebugImports]
return ok
Expand All @@ -372,6 +392,36 @@ func (c *Config) ShouldAnnotateExports() bool {
return ok
}

// ShouldFixWildcardImport tests whether the given symbol name pattern
// should be resolved within the scope of the given filename pattern.
// resolveFileSymbolNameSpecs represent a whitelist; if no patterns match, false
// is returned.
func (c *Config) ShouldFixWildcardImport(filename, wimp string) bool {
if len(c.fixWildcardImportSpecs) > 0 {
log.Println("should fix wildcard import?", filename, wimp)
}
for _, spec := range c.fixWildcardImportSpecs {
hasStarChar := strings.Contains(spec.filenamePattern, "*")
if hasStarChar {
if ok, _ := doublestar.Match(spec.filenamePattern, filename); !ok {
log.Println("should fix wildcard import? FILENAME GLOB MATCH FAILED", filename, spec.filenamePattern)
continue
}
} else {
if !strings.HasSuffix(filename, spec.filenamePattern) {
log.Println("should fix wildcard import? FILENAME SUFFIX MATCH FAILED", filename, spec.filenamePattern)
continue
}
}
if ok, _ := doublestar.Match(spec.importPattern.Value, wimp); !ok {
log.Println("should fix wildcard import? IMPORT PATTERN MATCH FAILED", filename, spec.importPattern.Value, wimp)
continue
}
return spec.importPattern.Want
}
return false
}

// ShouldResolveFileSymbolName tests whether the given symbol name pattern
// should be resolved within the scope of the given filename pattern.
// resolveFileSymbolNameSpecs represent a whitelist; if no patterns match, false
Expand Down Expand Up @@ -565,10 +615,14 @@ type resolveFileSymbolNameSpec struct {
symbolName collections.Intent
}

type fixWildcardImportSpec struct {
// filenamePattern is the filename glob wildcard filenamePattern to test
filenamePattern string
importPattern collections.Intent
}

func parseAnnotation(val string) debugAnnotation {
switch val {
case "wildcardimports":
return DebugWildcardImports
case "imports":
return DebugImports
case "exports":
Expand Down
Loading

0 comments on commit c38eb11

Please sign in to comment.