Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
charlieegan3 committed Jan 15, 2025
1 parent ada2575 commit 8ce3897
Show file tree
Hide file tree
Showing 12 changed files with 196 additions and 96 deletions.
6 changes: 6 additions & 0 deletions cmd/fix.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,15 @@ func fix(args []string, params *fixCommandParams) error {
return fmt.Errorf("could not find potential roots: %w", err)
}

versionsMap, err := config.AllRegoVersions(regalDir.Name(), &userConfig)
if err != nil {
return fmt.Errorf("failed to get all Rego versions: %w", err)
}

f := fixer.NewFixer()
f.RegisterRoots(roots...)
f.RegisterFixes(fixes.NewDefaultFixes()...)
f.SetRegoVersionsMap(versionsMap)

if !slices.Contains([]string{"error", "rename"}, params.conflictMode) {
return fmt.Errorf("invalid conflict mode: %s, expected 'error' or 'rename'", params.conflictMode)
Expand Down
10 changes: 6 additions & 4 deletions cmd/lint.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,10 @@ func init() {
warningsFound := 0

for _, violation := range rep.Violations {
if violation.Level == "error" {
switch violation.Level {
case "error":
errorsFound++
} else if violation.Level == "warning" {
case "warning":
warningsFound++
}
}
Expand Down Expand Up @@ -407,7 +408,8 @@ func getWriterForOutputFile(filename string) (io.Writer, error) {

func formatError(format string, err error) error {
// currently, JSON and SARIF will get the same generic JSON error format
if format == formatJSON || format == formatSarif {
switch format {
case formatJSON, formatSarif:
bs, err := json.MarshalIndent(map[string]interface{}{
"errors": []string{err.Error()},
}, "", " ")
Expand All @@ -416,7 +418,7 @@ func formatError(format string, err error) error {
}

return fmt.Errorf("%s", string(bs))
} else if format == formatJunit {
case formatJunit:
testSuites := junit.Testsuites{
Name: "regal",
}
Expand Down
2 changes: 1 addition & 1 deletion e2e/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ project:
- path: v0
rego-version: 0
- path: v1
rego-version: 0
rego-version: 1
`,
"foo/main.rego": `package wow
Expand Down
4 changes: 1 addition & 3 deletions internal/lsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2223,9 +2223,7 @@ func (l *LanguageServer) handleTextDocumentFormatting(
params.TextDocument.URI: oldContent,
})

input, err := memfp.ToInput(func(fileName string) ast.RegoVersion {
return l.determineVersionForFile(uri.FromPath(l.clientIdentifier, fileName))
})
input, err := memfp.ToInput(l.loadedConfigAllRegoVersions.Clone())
if err != nil {
return nil, fmt.Errorf("failed to create fixer input: %w", err)
}
Expand Down
23 changes: 5 additions & 18 deletions pkg/fixer/fileprovider/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/styrainc/regal/internal/lsp/cache"
"github.com/styrainc/regal/internal/lsp/clients"
"github.com/styrainc/regal/internal/lsp/uri"
"github.com/styrainc/regal/internal/parse"
"github.com/styrainc/regal/internal/util"
"github.com/styrainc/regal/pkg/rules"
)
Expand Down Expand Up @@ -90,23 +89,11 @@ func (c *CacheFileProvider) Rename(from, to string) error {
return nil
}

func (c *CacheFileProvider) ToInput(versionLookup func(string) ast.RegoVersion) (rules.Input, error) {
strContents := make(map[string]string)
modules := make(map[string]*ast.Module)

for filename, content := range c.Cache.GetAllFiles() {
var err error

strContents[filename] = content

po := parse.ParserOptions()
po.RegoVersion = versionLookup(filename)

modules[filename], err = parse.ModuleWithOpts(filename, strContents[filename], po)
if err != nil {
return rules.Input{}, fmt.Errorf("failed to parse module %s: %w", filename, err)
}
func (c *CacheFileProvider) ToInput(versionsMap map[string]ast.RegoVersion) (rules.Input, error) {
input, err := rules.InputFromMap(c.Cache.GetAllFiles(), versionsMap)
if err != nil {
return rules.Input{}, fmt.Errorf("failed to create input: %w", err)
}

return rules.NewInput(strContents, modules), nil
return input, nil
}
2 changes: 1 addition & 1 deletion pkg/fixer/fileprovider/fp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type FileProvider interface {
Delete(string) error
Rename(string, string) error

ToInput(versionLookup func(string) ast.RegoVersion) (rules.Input, error)
ToInput(versionsMap map[string]ast.RegoVersion) (rules.Input, error)
}

type RenameConflictError struct {
Expand Down
20 changes: 5 additions & 15 deletions pkg/fixer/fileprovider/inmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

"github.com/open-policy-agent/opa/v1/ast"

"github.com/styrainc/regal/internal/parse"
"github.com/styrainc/regal/internal/util"
"github.com/styrainc/regal/pkg/rules"
)
Expand Down Expand Up @@ -112,20 +111,11 @@ func (p *InMemoryFileProvider) DeletedFiles() []string {
return util.Keys(p.deletedFiles)
}

func (p *InMemoryFileProvider) ToInput(versionLookup func(string) ast.RegoVersion) (rules.Input, error) {
modules := make(map[string]*ast.Module)

for filename, content := range p.files {
var err error

po := parse.ParserOptions()
po.RegoVersion = versionLookup(filename)

modules[filename], err = parse.ModuleWithOpts(filename, content, po)
if err != nil {
return rules.Input{}, fmt.Errorf("failed to parse module %s: %w", filename, err)
}
func (p *InMemoryFileProvider) ToInput(versionsMap map[string]ast.RegoVersion) (rules.Input, error) {
input, err := rules.InputFromMap(p.files, versionsMap)
if err != nil {
return rules.Input{}, fmt.Errorf("failed to create input: %w", err)
}

return rules.NewInput(p.files, modules), nil
return input, nil
}
20 changes: 10 additions & 10 deletions pkg/fixer/fixer.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,18 @@ type Fixer struct {
registeredMandatoryFixes map[string]any
onConflictOperation OnConflictOperation
registeredRoots []string
getRegoVersion func(string) ast.RegoVersion
versionsMap map[string]ast.RegoVersion
}

// SetOnConflictOperation sets the fixer's behavior when a conflict occurs.
func (f *Fixer) SetOnConflictOperation(operation OnConflictOperation) {
f.onConflictOperation = operation
}

func (f *Fixer) SetRegoVersionLookup(fn func(string) ast.RegoVersion) {
f.getRegoVersion = fn
// SetRegoVersionsMap sets the mapping of path prefixes to versions for the
// fixer to use when creating input for fixer runs.
func (f *Fixer) SetRegoVersionsMap(versionsMap map[string]ast.RegoVersion) {
f.versionsMap = versionsMap
}

// RegisterFixes sets the fixes that will be fixed if there are related linter
Expand Down Expand Up @@ -293,16 +295,14 @@ func (f *Fixer) applyLinterFixes(
return fmt.Errorf("failed to list files: %w", err)
}

if f.versionsMap == nil {
return errors.New("rego versions map not set")
}

for {
fixMadeInIteration := false

in, err := fp.ToInput(func(fileName string) ast.RegoVersion {
if f.getRegoVersion == nil {
return ast.RegoV1
}

return f.getRegoVersion(fileName)
})
in, err := fp.ToInput(f.versionsMap)
if err != nil {
return fmt.Errorf("failed to generate linter input: %w", err)
}
Expand Down
29 changes: 10 additions & 19 deletions pkg/fixer/fixer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,8 @@ deny = true

memfp := fileprovider.NewInMemoryFileProvider(policies)

input, err := memfp.ToInput(func(fileName string) ast.RegoVersion {
if fileName == "/root/main/main.rego" {
return ast.RegoV1
}

t.Fatalf("unexpected file when looking up version %s", fileName)

return ast.RegoUndefined
input, err := memfp.ToInput(map[string]ast.RegoVersion{
"/root/main": ast.RegoV1,
})
if err != nil {
t.Fatalf("failed to create input: %v", err)
Expand All @@ -50,6 +44,9 @@ deny = true
f := NewFixer()
f.RegisterFixes(fixes.NewDefaultFixes()...)
f.RegisterRoots("/root")
f.SetRegoVersionsMap(map[string]ast.RegoVersion{
"/root/main": ast.RegoV1,
})

fixReport, err := f.Fix(context.Background(), &l, memfp)
if err != nil {
Expand Down Expand Up @@ -129,7 +126,7 @@ func TestFixerWithRegisterMandatoryFixes(t *testing.T) {
t.Parallel()

policies := map[string]string{
"main.rego": `package test
"/root/main/main.rego": `package test
allow {
true #no space
Expand All @@ -141,14 +138,8 @@ deny = true

memfp := fileprovider.NewInMemoryFileProvider(policies)

input, err := memfp.ToInput(func(fileName string) ast.RegoVersion {
if fileName == "main.rego" {
return ast.RegoV0
}

t.Fatalf("unexpected file when looking up version %s", fileName)

return ast.RegoUndefined
input, err := memfp.ToInput(map[string]ast.RegoVersion{
"/root/main": ast.RegoV0,
})
if err != nil {
t.Fatalf("failed to create input: %v", err)
Expand Down Expand Up @@ -181,12 +172,12 @@ deny = true
}

expectedFileFixedViolations := map[string][]string{
"main.rego": {"use-rego-v1"},
"/root/main/main.rego": {"use-rego-v1"},
}
expectedFileContents := map[string]string{
// note that since only the rego-v1-format fix is run, the
// no-whitespace-comment fix is not applied
"main.rego": `package test
"/root/main/main.rego": `package test
import rego.v1
Expand Down
5 changes: 1 addition & 4 deletions pkg/linter/linter.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,17 +292,14 @@ func (l Linter) Lint(ctx context.Context) (report.Report, error) {

var versionsMap map[string]ast.RegoVersion

// TODO: How should we deal with this in the language server?
// AllRegoVersions will call WalkDir on the root to find manifests, but that's obviously not
// going to work for a file:// path..
if l.pathPrefix != "" && !strings.HasPrefix(l.pathPrefix, "file://") {
versionsMap, err = config.AllRegoVersions(l.pathPrefix, conf)
if err != nil && l.debugMode {
log.Printf("failed to get configured Rego versions: %v", err)
}
}

inputFromPaths, err := rules.InputFromPaths(filtered, versionsMap)
inputFromPaths, err := rules.InputFromPaths(filtered, l.pathPrefix, versionsMap)
if err != nil {
return report.Report{}, fmt.Errorf("errors encountered when reading files to lint: %w", err)
}
Expand Down
77 changes: 56 additions & 21 deletions pkg/rules/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func NewInput(fileContent map[string]string, modules map[string]*ast.Module) Inp
// function. When the versionsMap is not nil/empty, files in a directory matching a key in the map will be parsed with
// the corresponding Rego version. If not provided, the file may be parsed multiple times in order to determine the
// version (best-effort and may include false positives).
func InputFromPaths(paths []string, versionsMap map[string]ast.RegoVersion) (Input, error) {
func InputFromPaths(paths []string, prefix string, versionsMap map[string]ast.RegoVersion) (Input, error) {
if len(paths) == 1 && paths[0] == "-" {
return inputFromStdin()
}
Expand All @@ -96,30 +96,13 @@ func InputFromPaths(paths []string, versionsMap map[string]ast.RegoVersion) (Inp

errors := make([]error, 0, len(paths))

parserOptions := parse.ParserOptions()

for _, path := range paths {
go func(path string) {
defer wg.Done()

parserOptions.RegoVersion = ast.RegoUndefined

// Check if the path matches any directory where a specific Rego version is set,
// and if so use that instead of having to parse the file (potentially multiple times)
// in order to determine the Rego version.
// If a project-wide version has been set, it'll be found under the path "", which will
// always be the last entry in versionedDirs, and only match if no specific directory
// matches.
if len(versionsMap) > 0 {
dir := filepath.Dir(path)
for _, versionedDir := range versionedDirs {
if strings.HasPrefix(dir, versionedDir) {
parserOptions.RegoVersion = versionsMap[versionedDir]

break
}
}
}
parserOptions := parse.ParserOptions()

parserOptions.RegoVersion = RegoVersionFromVersionsMap(versionsMap, strings.TrimPrefix(path, prefix), ast.RegoUndefined)

result, err := regoWithOpts(path, parserOptions)

Expand All @@ -146,6 +129,58 @@ func InputFromPaths(paths []string, versionsMap map[string]ast.RegoVersion) (Inp
return NewInput(fileContent, modules), nil
}

// InputFromMap creates a new Input from a map of file paths to their contents.
// This function uses a vesrionsMap to determine the parser version for each
// file before parsing the module.
func InputFromMap(files map[string]string, versionsMap map[string]ast.RegoVersion) (Input, error) {
fileContent := make(map[string]string, len(files))
modules := make(map[string]*ast.Module, len(files))
parserOptions := parse.ParserOptions()

for path, content := range files {
fileContent[path] = content

parserOptions.RegoVersion = RegoVersionFromVersionsMap(versionsMap, path, ast.RegoUndefined)

mod, err := parse.ModuleWithOpts(path, content, parserOptions)
if err != nil {
return Input{}, fmt.Errorf("failed to parse module %s: %w", path, err)
}

modules[path] = mod
}

return NewInput(fileContent, modules), nil
}

// RegoVersionFromVersionsMap takes a mapping of file path prefixes, typically
// representing the roots of the project, and the expected Rego version for
// each. Using this, it finds the longest matching prefix for the given filename
// and returns the defaultVersion if to matching prefix is found.
func RegoVersionFromVersionsMap(versionsMap map[string]ast.RegoVersion, filename string, defaultVersion ast.RegoVersion) ast.RegoVersion {
if len(versionsMap) == 0 {
return defaultVersion
}

selectedVersion := defaultVersion

var longestMatch int

dir := filepath.Dir(filename)
for versionedDir := range versionsMap {
matchingVersionedDir := versionedDir + "/"

if strings.HasPrefix(dir+"/", matchingVersionedDir) {
if len(versionedDir) > longestMatch {
longestMatch = len(versionedDir)
selectedVersion = versionsMap[versionedDir]
}
}
}

return selectedVersion
}

func regoWithOpts(path string, opts ast.ParserOptions) (*regoFile, error) {
path = filepath.Clean(path)

Expand Down
Loading

0 comments on commit 8ce3897

Please sign in to comment.