diff --git a/cmd/fix.go b/cmd/fix.go index 86e4c76e..457fcf47 100644 --- a/cmd/fix.go +++ b/cmd/fix.go @@ -270,9 +270,15 @@ func fix(args []string, params *fixCommandParams) error { return fmt.Errorf("could not find potential roots: %w", err) } + versionsMap, err := config.AllRegoVersions(regalDir.Name(), &userConfig) + if err != nil { + return fmt.Errorf("failed to get all Rego versions: %w", err) + } + f := fixer.NewFixer() f.RegisterRoots(roots...) f.RegisterFixes(fixes.NewDefaultFixes()...) + f.SetRegoVersionsMap(versionsMap) if !slices.Contains([]string{"error", "rename"}, params.conflictMode) { return fmt.Errorf("invalid conflict mode: %s, expected 'error' or 'rename'", params.conflictMode) diff --git a/cmd/lint.go b/cmd/lint.go index 31ce4b62..dac93dc3 100644 --- a/cmd/lint.go +++ b/cmd/lint.go @@ -118,9 +118,10 @@ func init() { warningsFound := 0 for _, violation := range rep.Violations { - if violation.Level == "error" { + switch violation.Level { + case "error": errorsFound++ - } else if violation.Level == "warning" { + case "warning": warningsFound++ } } @@ -407,7 +408,8 @@ func getWriterForOutputFile(filename string) (io.Writer, error) { func formatError(format string, err error) error { // currently, JSON and SARIF will get the same generic JSON error format - if format == formatJSON || format == formatSarif { + switch format { + case formatJSON, formatSarif: bs, err := json.MarshalIndent(map[string]interface{}{ "errors": []string{err.Error()}, }, "", " ") @@ -416,7 +418,7 @@ func formatError(format string, err error) error { } return fmt.Errorf("%s", string(bs)) - } else if format == formatJunit { + case formatJunit: testSuites := junit.Testsuites{ Name: "regal", } diff --git a/e2e/cli_test.go b/e2e/cli_test.go index 52e8d7dc..7fffbfe3 100644 --- a/e2e/cli_test.go +++ b/e2e/cli_test.go @@ -865,7 +865,7 @@ project: - path: v0 rego-version: 0 - path: v1 - rego-version: 0 + rego-version: 1 `, "foo/main.rego": `package wow diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 0721a8dd..4f75ffa9 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -2223,9 +2223,7 @@ func (l *LanguageServer) handleTextDocumentFormatting( params.TextDocument.URI: oldContent, }) - input, err := memfp.ToInput(func(fileName string) ast.RegoVersion { - return l.determineVersionForFile(uri.FromPath(l.clientIdentifier, fileName)) - }) + input, err := memfp.ToInput(l.loadedConfigAllRegoVersions.Clone()) if err != nil { return nil, fmt.Errorf("failed to create fixer input: %w", err) } diff --git a/pkg/fixer/fileprovider/cache.go b/pkg/fixer/fileprovider/cache.go index c3d7bc0b..3ed18aed 100644 --- a/pkg/fixer/fileprovider/cache.go +++ b/pkg/fixer/fileprovider/cache.go @@ -8,7 +8,6 @@ import ( "github.com/styrainc/regal/internal/lsp/cache" "github.com/styrainc/regal/internal/lsp/clients" "github.com/styrainc/regal/internal/lsp/uri" - "github.com/styrainc/regal/internal/parse" "github.com/styrainc/regal/internal/util" "github.com/styrainc/regal/pkg/rules" ) @@ -90,23 +89,11 @@ func (c *CacheFileProvider) Rename(from, to string) error { return nil } -func (c *CacheFileProvider) ToInput(versionLookup func(string) ast.RegoVersion) (rules.Input, error) { - strContents := make(map[string]string) - modules := make(map[string]*ast.Module) - - for filename, content := range c.Cache.GetAllFiles() { - var err error - - strContents[filename] = content - - po := parse.ParserOptions() - po.RegoVersion = versionLookup(filename) - - modules[filename], err = parse.ModuleWithOpts(filename, strContents[filename], po) - if err != nil { - return rules.Input{}, fmt.Errorf("failed to parse module %s: %w", filename, err) - } +func (c *CacheFileProvider) ToInput(versionsMap map[string]ast.RegoVersion) (rules.Input, error) { + input, err := rules.InputFromMap(c.Cache.GetAllFiles(), versionsMap) + if err != nil { + return rules.Input{}, fmt.Errorf("failed to create input: %w", err) } - return rules.NewInput(strContents, modules), nil + return input, nil } diff --git a/pkg/fixer/fileprovider/fp.go b/pkg/fixer/fileprovider/fp.go index bec60456..f5402e7e 100644 --- a/pkg/fixer/fileprovider/fp.go +++ b/pkg/fixer/fileprovider/fp.go @@ -16,7 +16,7 @@ type FileProvider interface { Delete(string) error Rename(string, string) error - ToInput(versionLookup func(string) ast.RegoVersion) (rules.Input, error) + ToInput(versionsMap map[string]ast.RegoVersion) (rules.Input, error) } type RenameConflictError struct { diff --git a/pkg/fixer/fileprovider/inmem.go b/pkg/fixer/fileprovider/inmem.go index 623b9ffb..e9bb8496 100644 --- a/pkg/fixer/fileprovider/inmem.go +++ b/pkg/fixer/fileprovider/inmem.go @@ -6,7 +6,6 @@ import ( "github.com/open-policy-agent/opa/v1/ast" - "github.com/styrainc/regal/internal/parse" "github.com/styrainc/regal/internal/util" "github.com/styrainc/regal/pkg/rules" ) @@ -112,20 +111,11 @@ func (p *InMemoryFileProvider) DeletedFiles() []string { return util.Keys(p.deletedFiles) } -func (p *InMemoryFileProvider) ToInput(versionLookup func(string) ast.RegoVersion) (rules.Input, error) { - modules := make(map[string]*ast.Module) - - for filename, content := range p.files { - var err error - - po := parse.ParserOptions() - po.RegoVersion = versionLookup(filename) - - modules[filename], err = parse.ModuleWithOpts(filename, content, po) - if err != nil { - return rules.Input{}, fmt.Errorf("failed to parse module %s: %w", filename, err) - } +func (p *InMemoryFileProvider) ToInput(versionsMap map[string]ast.RegoVersion) (rules.Input, error) { + input, err := rules.InputFromMap(p.files, versionsMap) + if err != nil { + return rules.Input{}, fmt.Errorf("failed to create input: %w", err) } - return rules.NewInput(p.files, modules), nil + return input, nil } diff --git a/pkg/fixer/fixer.go b/pkg/fixer/fixer.go index f7b0fbfb..b47e9a63 100644 --- a/pkg/fixer/fixer.go +++ b/pkg/fixer/fixer.go @@ -40,7 +40,7 @@ type Fixer struct { registeredMandatoryFixes map[string]any onConflictOperation OnConflictOperation registeredRoots []string - getRegoVersion func(string) ast.RegoVersion + versionsMap map[string]ast.RegoVersion } // SetOnConflictOperation sets the fixer's behavior when a conflict occurs. @@ -48,8 +48,10 @@ func (f *Fixer) SetOnConflictOperation(operation OnConflictOperation) { f.onConflictOperation = operation } -func (f *Fixer) SetRegoVersionLookup(fn func(string) ast.RegoVersion) { - f.getRegoVersion = fn +// SetRegoVersionsMap sets the mapping of path prefixes to versions for the +// fixer to use when creating input for fixer runs. +func (f *Fixer) SetRegoVersionsMap(versionsMap map[string]ast.RegoVersion) { + f.versionsMap = versionsMap } // RegisterFixes sets the fixes that will be fixed if there are related linter @@ -293,16 +295,14 @@ func (f *Fixer) applyLinterFixes( return fmt.Errorf("failed to list files: %w", err) } + if f.versionsMap == nil { + return errors.New("rego versions map not set") + } + for { fixMadeInIteration := false - in, err := fp.ToInput(func(fileName string) ast.RegoVersion { - if f.getRegoVersion == nil { - return ast.RegoV1 - } - - return f.getRegoVersion(fileName) - }) + in, err := fp.ToInput(f.versionsMap) if err != nil { return fmt.Errorf("failed to generate linter input: %w", err) } diff --git a/pkg/fixer/fixer_test.go b/pkg/fixer/fixer_test.go index 8f93ff6a..9de762a4 100644 --- a/pkg/fixer/fixer_test.go +++ b/pkg/fixer/fixer_test.go @@ -30,14 +30,8 @@ deny = true memfp := fileprovider.NewInMemoryFileProvider(policies) - input, err := memfp.ToInput(func(fileName string) ast.RegoVersion { - if fileName == "/root/main/main.rego" { - return ast.RegoV1 - } - - t.Fatalf("unexpected file when looking up version %s", fileName) - - return ast.RegoUndefined + input, err := memfp.ToInput(map[string]ast.RegoVersion{ + "/root/main": ast.RegoV1, }) if err != nil { t.Fatalf("failed to create input: %v", err) @@ -50,6 +44,9 @@ deny = true f := NewFixer() f.RegisterFixes(fixes.NewDefaultFixes()...) f.RegisterRoots("/root") + f.SetRegoVersionsMap(map[string]ast.RegoVersion{ + "/root/main": ast.RegoV1, + }) fixReport, err := f.Fix(context.Background(), &l, memfp) if err != nil { @@ -129,7 +126,7 @@ func TestFixerWithRegisterMandatoryFixes(t *testing.T) { t.Parallel() policies := map[string]string{ - "main.rego": `package test + "/root/main/main.rego": `package test allow { true #no space @@ -141,14 +138,8 @@ deny = true memfp := fileprovider.NewInMemoryFileProvider(policies) - input, err := memfp.ToInput(func(fileName string) ast.RegoVersion { - if fileName == "main.rego" { - return ast.RegoV0 - } - - t.Fatalf("unexpected file when looking up version %s", fileName) - - return ast.RegoUndefined + input, err := memfp.ToInput(map[string]ast.RegoVersion{ + "/root/main": ast.RegoV0, }) if err != nil { t.Fatalf("failed to create input: %v", err) @@ -181,12 +172,12 @@ deny = true } expectedFileFixedViolations := map[string][]string{ - "main.rego": {"use-rego-v1"}, + "/root/main/main.rego": {"use-rego-v1"}, } expectedFileContents := map[string]string{ // note that since only the rego-v1-format fix is run, the // no-whitespace-comment fix is not applied - "main.rego": `package test + "/root/main/main.rego": `package test import rego.v1 diff --git a/pkg/linter/linter.go b/pkg/linter/linter.go index 3be7cd3f..bbeff70a 100644 --- a/pkg/linter/linter.go +++ b/pkg/linter/linter.go @@ -292,9 +292,6 @@ func (l Linter) Lint(ctx context.Context) (report.Report, error) { var versionsMap map[string]ast.RegoVersion - // TODO: How should we deal with this in the language server? - // AllRegoVersions will call WalkDir on the root to find manifests, but that's obviously not - // going to work for a file:// path.. if l.pathPrefix != "" && !strings.HasPrefix(l.pathPrefix, "file://") { versionsMap, err = config.AllRegoVersions(l.pathPrefix, conf) if err != nil && l.debugMode { @@ -302,7 +299,7 @@ func (l Linter) Lint(ctx context.Context) (report.Report, error) { } } - inputFromPaths, err := rules.InputFromPaths(filtered, versionsMap) + inputFromPaths, err := rules.InputFromPaths(filtered, l.pathPrefix, versionsMap) if err != nil { return report.Report{}, fmt.Errorf("errors encountered when reading files to lint: %w", err) } diff --git a/pkg/rules/rules.go b/pkg/rules/rules.go index cb19a39b..43e04026 100644 --- a/pkg/rules/rules.go +++ b/pkg/rules/rules.go @@ -71,7 +71,7 @@ func NewInput(fileContent map[string]string, modules map[string]*ast.Module) Inp // function. When the versionsMap is not nil/empty, files in a directory matching a key in the map will be parsed with // the corresponding Rego version. If not provided, the file may be parsed multiple times in order to determine the // version (best-effort and may include false positives). -func InputFromPaths(paths []string, versionsMap map[string]ast.RegoVersion) (Input, error) { +func InputFromPaths(paths []string, prefix string, versionsMap map[string]ast.RegoVersion) (Input, error) { if len(paths) == 1 && paths[0] == "-" { return inputFromStdin() } @@ -96,30 +96,13 @@ func InputFromPaths(paths []string, versionsMap map[string]ast.RegoVersion) (Inp errors := make([]error, 0, len(paths)) - parserOptions := parse.ParserOptions() - for _, path := range paths { go func(path string) { defer wg.Done() - parserOptions.RegoVersion = ast.RegoUndefined - - // Check if the path matches any directory where a specific Rego version is set, - // and if so use that instead of having to parse the file (potentially multiple times) - // in order to determine the Rego version. - // If a project-wide version has been set, it'll be found under the path "", which will - // always be the last entry in versionedDirs, and only match if no specific directory - // matches. - if len(versionsMap) > 0 { - dir := filepath.Dir(path) - for _, versionedDir := range versionedDirs { - if strings.HasPrefix(dir, versionedDir) { - parserOptions.RegoVersion = versionsMap[versionedDir] - - break - } - } - } + parserOptions := parse.ParserOptions() + + parserOptions.RegoVersion = RegoVersionFromVersionsMap(versionsMap, strings.TrimPrefix(path, prefix), ast.RegoUndefined) result, err := regoWithOpts(path, parserOptions) @@ -146,6 +129,58 @@ func InputFromPaths(paths []string, versionsMap map[string]ast.RegoVersion) (Inp return NewInput(fileContent, modules), nil } +// InputFromMap creates a new Input from a map of file paths to their contents. +// This function uses a vesrionsMap to determine the parser version for each +// file before parsing the module. +func InputFromMap(files map[string]string, versionsMap map[string]ast.RegoVersion) (Input, error) { + fileContent := make(map[string]string, len(files)) + modules := make(map[string]*ast.Module, len(files)) + parserOptions := parse.ParserOptions() + + for path, content := range files { + fileContent[path] = content + + parserOptions.RegoVersion = RegoVersionFromVersionsMap(versionsMap, path, ast.RegoUndefined) + + mod, err := parse.ModuleWithOpts(path, content, parserOptions) + if err != nil { + return Input{}, fmt.Errorf("failed to parse module %s: %w", path, err) + } + + modules[path] = mod + } + + return NewInput(fileContent, modules), nil +} + +// RegoVersionFromVersionsMap takes a mapping of file path prefixes, typically +// representing the roots of the project, and the expected Rego version for +// each. Using this, it finds the longest matching prefix for the given filename +// and returns the defaultVersion if to matching prefix is found. +func RegoVersionFromVersionsMap(versionsMap map[string]ast.RegoVersion, filename string, defaultVersion ast.RegoVersion) ast.RegoVersion { + if len(versionsMap) == 0 { + return defaultVersion + } + + selectedVersion := defaultVersion + + var longestMatch int + + dir := filepath.Dir(filename) + for versionedDir := range versionsMap { + matchingVersionedDir := versionedDir + "/" + + if strings.HasPrefix(dir+"/", matchingVersionedDir) { + if len(versionedDir) > longestMatch { + longestMatch = len(versionedDir) + selectedVersion = versionsMap[versionedDir] + } + } + } + + return selectedVersion +} + func regoWithOpts(path string, opts ast.ParserOptions) (*regoFile, error) { path = filepath.Clean(path) diff --git a/pkg/rules/rules_test.go b/pkg/rules/rules_test.go index 20604cf4..3b8d7c20 100644 --- a/pkg/rules/rules_test.go +++ b/pkg/rules/rules_test.go @@ -38,3 +38,97 @@ p { true }`, }) } } + +func TestRegoVersionFromVersionMap(t *testing.T) { + testCases := map[string]struct { + VersionsMap map[string]ast.RegoVersion + Filename string + ExpectedVersion ast.RegoVersion + }{ + "file has no root in version map": { + VersionsMap: map[string]ast.RegoVersion{ + "/foo/bar": ast.RegoV1, + "/bar": ast.RegoV0, + "/unknown": ast.RegoUndefined, + }, + Filename: "/baz/qux.rego", + ExpectedVersion: ast.RegoUndefined, + }, + "file has version from current dir": { + VersionsMap: map[string]ast.RegoVersion{ + "/foo": ast.RegoV1, + "/bar": ast.RegoV0, + "/unknown": ast.RegoUndefined, + }, + Filename: "/foo/bar.rego", + ExpectedVersion: ast.RegoV1, + }, + "file has version from parent dir": { + VersionsMap: map[string]ast.RegoVersion{ + "/foo": ast.RegoV1, + "/bar": ast.RegoV0, + "/unknown": ast.RegoUndefined, + }, + Filename: "/foo/bar/baz.rego", + ExpectedVersion: ast.RegoV1, + }, + "file has version from grandparent dir": { + VersionsMap: map[string]ast.RegoVersion{ + "/foo": ast.RegoV1, + "/bar": ast.RegoV0, + "/unknown": ast.RegoUndefined, + }, + Filename: "/foo/bar/baz/qux.rego", + ExpectedVersion: ast.RegoV1, + }, + "project roots are subdirs and overlap": { + VersionsMap: map[string]ast.RegoVersion{ + "/foo/bar": ast.RegoV1, + "/foo": ast.RegoV0, + "/unknown": ast.RegoUndefined, + }, + Filename: "/foo/bar/baz/qux.rego", + ExpectedVersion: ast.RegoV1, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + actualVersion := RegoVersionFromVersionsMap(tc.VersionsMap, tc.Filename, ast.RegoUndefined) + if actualVersion != tc.ExpectedVersion { + t.Errorf("Expected %v, got %v", tc.ExpectedVersion, actualVersion) + } + }) + } +} + +func TestInputFromMap(t *testing.T) { + t.Parallel() + + versionsMap := map[string]ast.RegoVersion{ + "/foo/bar": ast.RegoV1, + "/foo": ast.RegoV0, + } + + files := map[string]string{ + "/foo/bar/main.rego": `package main +# v1 syntax is allowed + +allow if input.admin +`, + "/foo/main.rego": `package main +# v0 syntax is allowed + +allow[msg] { msg := "hello" } +`, + } + + input, err := InputFromMap(files, versionsMap) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if len(input.Modules) != 2 { + t.Fatalf("Expected 2 modules, got %d", len(input.Modules)) + } +}