diff --git a/pkg/iac/rego/provider.go b/pkg/iac/rego/provider.go new file mode 100644 index 000000000000..a49e700717c0 --- /dev/null +++ b/pkg/iac/rego/provider.go @@ -0,0 +1,43 @@ +package rego + +import ( + "fmt" + "io/fs" + "sync" + + "github.com/aquasecurity/trivy/pkg/iac/scanners/options" +) + +func WithRegoScanner(rs *Scanner) options.ScannerOption { + return func(s options.ConfigurableScanner) { + if ss, ok := s.(*RegoScannerProvider); ok { + ss.regoScanner = rs + } + } +} + +type RegoScannerProvider struct { + mu sync.Mutex + regoScanner *Scanner +} + +func NewRegoScannerProvider(opts ...options.ScannerOption) *RegoScannerProvider { + s := &RegoScannerProvider{} + for _, o := range opts { + o(s) + } + return s +} + +func (s *RegoScannerProvider) InitRegoScanner(fsys fs.FS, opts []options.ScannerOption) (*Scanner, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.regoScanner != nil { + return s.regoScanner, nil + } + s.regoScanner = NewScanner(opts...) + if err := s.regoScanner.LoadPolicies(fsys); err != nil { + return nil, fmt.Errorf("load checks: %w", err) + } + return s.regoScanner, nil +} diff --git a/pkg/iac/rego/scanner.go b/pkg/iac/rego/scanner.go index de80bedc6883..9bc5d8206ad1 100644 --- a/pkg/iac/rego/scanner.go +++ b/pkg/iac/rego/scanner.go @@ -105,6 +105,7 @@ func NewScanner(opts ...options.ScannerOption) *Scanner { func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, disableTracing bool) (rego.ResultSet, []string, error) { + // TODO: (s.traceWriter != nil && s.tracePerResult) && !disableTracing trace := (s.traceWriter != nil || s.tracePerResult) && !disableTracing regoOptions := []func(*rego.Rego){ diff --git a/pkg/iac/scanners/azure/arm/scanner.go b/pkg/iac/scanners/azure/arm/scanner.go index fe8b3e8a33d9..7869569b1b46 100644 --- a/pkg/iac/scanners/azure/arm/scanner.go +++ b/pkg/iac/scanners/azure/arm/scanner.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io/fs" - "sync" "github.com/aquasecurity/trivy/pkg/iac/adapters/arm" "github.com/aquasecurity/trivy/pkg/iac/rego" @@ -13,7 +12,6 @@ import ( "github.com/aquasecurity/trivy/pkg/iac/scanners/azure" "github.com/aquasecurity/trivy/pkg/iac/scanners/azure/arm/parser" "github.com/aquasecurity/trivy/pkg/iac/scanners/options" - "github.com/aquasecurity/trivy/pkg/iac/state" "github.com/aquasecurity/trivy/pkg/iac/types" "github.com/aquasecurity/trivy/pkg/log" ) @@ -22,16 +20,16 @@ var _ scanners.FSScanner = (*Scanner)(nil) var _ options.ConfigurableScanner = (*Scanner)(nil) type Scanner struct { - mu sync.Mutex - scannerOptions []options.ScannerOption - logger *log.Logger - regoScanner *rego.Scanner + *rego.RegoScannerProvider + opts []options.ScannerOption + logger *log.Logger } func New(opts ...options.ScannerOption) *Scanner { scanner := &Scanner{ - scannerOptions: opts, - logger: log.WithPrefix("azure-arm"), + RegoScannerProvider: rego.NewRegoScannerProvider(opts...), + opts: opts, + logger: log.WithPrefix("azure-arm"), } for _, opt := range opts { opt(scanner) @@ -43,29 +41,12 @@ func (s *Scanner) Name() string { return "Azure ARM" } -func (s *Scanner) initRegoScanner(srcFS fs.FS) error { - s.mu.Lock() - defer s.mu.Unlock() - if s.regoScanner != nil { - return nil - } - regoScanner := rego.NewScanner(s.scannerOptions...) - if err := regoScanner.LoadPolicies(srcFS); err != nil { - return err - } - s.regoScanner = regoScanner - return nil -} - func (s *Scanner) ScanFS(ctx context.Context, fsys fs.FS, dir string) (scan.Results, error) { p := parser.New(fsys) deployments, err := p.ParseFS(ctx, dir) if err != nil { return nil, err } - if err := s.initRegoScanner(fsys); err != nil { - return nil, err - } return s.scanDeployments(ctx, deployments, fsys) } @@ -87,12 +68,17 @@ func (s *Scanner) scanDeployments(ctx context.Context, deployments []azure.Deplo } func (s *Scanner) scanDeployment(ctx context.Context, deployment azure.Deployment, fsys fs.FS) (scan.Results, error) { - deploymentState := s.adaptDeployment(ctx, deployment) + state := arm.Adapt(ctx, deployment) - results, err := s.regoScanner.ScanInput(ctx, types.SourceCloud, rego.Input{ + rs, err := s.InitRegoScanner(fsys, s.opts) + if err != nil { + return nil, fmt.Errorf("init rego scanner: %w", err) + } + + results, err := rs.ScanInput(ctx, types.SourceCloud, rego.Input{ Path: deployment.Metadata.Range().GetFilename(), FS: fsys, - Contents: deploymentState.ToRego(), + Contents: state.ToRego(), }) if err != nil { return nil, fmt.Errorf("rego scan error: %w", err) @@ -100,7 +86,3 @@ func (s *Scanner) scanDeployment(ctx context.Context, deployment azure.Deploymen return results, nil } - -func (s *Scanner) adaptDeployment(ctx context.Context, deployment azure.Deployment) *state.State { - return arm.Adapt(ctx, deployment) -} diff --git a/pkg/iac/scanners/cloudformation/scanner.go b/pkg/iac/scanners/cloudformation/scanner.go index f29e98213291..63b00e412fb0 100644 --- a/pkg/iac/scanners/cloudformation/scanner.go +++ b/pkg/iac/scanners/cloudformation/scanner.go @@ -5,7 +5,6 @@ import ( "fmt" "io/fs" "sort" - "sync" adapter "github.com/aquasecurity/trivy/pkg/iac/adapters/cloudformation" "github.com/aquasecurity/trivy/pkg/iac/rego" @@ -45,10 +44,9 @@ var _ scanners.FSScanner = (*Scanner)(nil) var _ options.ConfigurableScanner = (*Scanner)(nil) type Scanner struct { - mu sync.Mutex + *rego.RegoScannerProvider logger *log.Logger parser *parser.Parser - regoScanner *rego.Scanner options []options.ScannerOption parserOptions []parser.Option } @@ -64,8 +62,9 @@ func (s *Scanner) Name() string { // New creates a new Scanner func New(opts ...options.ScannerOption) *Scanner { s := &Scanner{ - options: opts, - logger: log.WithPrefix("cloudformation scanner"), + RegoScannerProvider: rego.NewRegoScannerProvider(opts...), + options: opts, + logger: log.WithPrefix("cloudformation scanner"), } for _, opt := range opts { opt(s) @@ -74,20 +73,6 @@ func New(opts ...options.ScannerOption) *Scanner { return s } -func (s *Scanner) initRegoScanner(srcFS fs.FS) (*rego.Scanner, error) { - s.mu.Lock() - defer s.mu.Unlock() - if s.regoScanner != nil { - return s.regoScanner, nil - } - regoScanner := rego.NewScanner(s.options...) - if err := regoScanner.LoadPolicies(srcFS); err != nil { - return nil, err - } - s.regoScanner = regoScanner - return regoScanner, nil -} - func (s *Scanner) ScanFS(ctx context.Context, fsys fs.FS, dir string) (results scan.Results, err error) { contexts, err := s.parser.ParseFS(ctx, fsys, dir) @@ -99,16 +84,16 @@ func (s *Scanner) ScanFS(ctx context.Context, fsys fs.FS, dir string) (results s return nil, nil } - regoScanner, err := s.initRegoScanner(fsys) + rs, err := s.InitRegoScanner(fsys, s.options) if err != nil { - return nil, err + return nil, fmt.Errorf("init rego scanner: %w", err) } for _, cfCtx := range contexts { if cfCtx == nil { continue } - fileResults, err := s.scanFileContext(ctx, regoScanner, cfCtx, fsys) + fileResults, err := s.scanFileContext(ctx, rs, cfCtx, fsys) if err != nil { return nil, err } @@ -127,12 +112,12 @@ func (s *Scanner) ScanFile(ctx context.Context, fsys fs.FS, path string) (scan.R return nil, err } - regoScanner, err := s.initRegoScanner(fsys) + rs, err := s.InitRegoScanner(fsys, s.options) if err != nil { - return nil, err + return nil, fmt.Errorf("init rego scanner: %w", err) } - results, err := s.scanFileContext(ctx, regoScanner, cfCtx, fsys) + results, err := s.scanFileContext(ctx, rs, cfCtx, fsys) if err != nil { return nil, err } diff --git a/pkg/iac/scanners/dockerfile/scanner_test.go b/pkg/iac/scanners/dockerfile/scanner_test.go index 44ab743e1778..37647ded9e25 100644 --- a/pkg/iac/scanners/dockerfile/scanner_test.go +++ b/pkg/iac/scanners/dockerfile/scanner_test.go @@ -573,7 +573,7 @@ COPY --from=dep /binary /` results, err := scanner.ScanFS(context.TODO(), fsys, "code") if tc.expectedError != "" && err != nil { - require.Equal(t, tc.expectedError, err.Error(), tc.name) + require.ErrorContainsf(t, err, tc.expectedError, tc.name) } else { require.NoError(t, err) require.Len(t, results.GetFailed(), 1) diff --git a/pkg/iac/scanners/generic/scanner.go b/pkg/iac/scanners/generic/scanner.go index f556f1168103..0ac53b3525e9 100644 --- a/pkg/iac/scanners/generic/scanner.go +++ b/pkg/iac/scanners/generic/scanner.go @@ -9,7 +9,6 @@ import ( "io/fs" "path/filepath" "strings" - "sync" "github.com/BurntSushi/toml" "github.com/samber/lo" @@ -41,12 +40,11 @@ type configParser interface { // GenericScanner is a scanner that scans a file as is without processing it type GenericScanner struct { - mu sync.Mutex - name string - source types.Source - logger *log.Logger - options []options.ScannerOption - regoScanner *rego.Scanner + *rego.RegoScannerProvider + name string + source types.Source + logger *log.Logger + options []options.ScannerOption parser configParser } @@ -59,11 +57,12 @@ func (f ParseFunc) Parse(ctx context.Context, r io.Reader, path string) (any, er func NewScanner(name string, source types.Source, parser configParser, opts ...options.ScannerOption) *GenericScanner { s := &GenericScanner{ - name: name, - options: opts, - source: source, - logger: log.WithPrefix(fmt.Sprintf("%s scanner", source)), - parser: parser, + RegoScannerProvider: rego.NewRegoScannerProvider(opts...), + name: name, + options: opts, + source: source, + logger: log.WithPrefix(fmt.Sprintf("%s scanner", source)), + parser: parser, } for _, opt := range opts { @@ -113,13 +112,13 @@ func (s *GenericScanner) ScanFS(ctx context.Context, fsys fs.FS, dir string) (sc } } - regoScanner, err := s.initRegoScanner(fsys) + rs, err := s.InitRegoScanner(fsys, s.options) if err != nil { - return nil, err + return nil, fmt.Errorf("init rego scanner: %w", err) } s.logger.Debug("Scanning files...", log.Int("count", len(inputs))) - results, err := regoScanner.ScanInput(ctx, s.source, inputs...) + results, err := rs.ScanInput(ctx, s.source, inputs...) if err != nil { return nil, err } @@ -170,20 +169,6 @@ func (s *GenericScanner) parseFS(ctx context.Context, fsys fs.FS, path string) ( return files, nil } -func (s *GenericScanner) initRegoScanner(srcFS fs.FS) (*rego.Scanner, error) { - s.mu.Lock() - defer s.mu.Unlock() - if s.regoScanner != nil { - return s.regoScanner, nil - } - regoScanner := rego.NewScanner(s.options...) - if err := regoScanner.LoadPolicies(srcFS); err != nil { - return nil, err - } - s.regoScanner = regoScanner - return regoScanner, nil -} - func (s *GenericScanner) applyIgnoreRules(fsys fs.FS, results scan.Results) error { if !s.supportsIgnoreRules() { return nil diff --git a/pkg/iac/scanners/helm/scanner.go b/pkg/iac/scanners/helm/scanner.go index 6ab324f66f28..634e40d200d0 100644 --- a/pkg/iac/scanners/helm/scanner.go +++ b/pkg/iac/scanners/helm/scanner.go @@ -7,7 +7,6 @@ import ( "io/fs" "path/filepath" "strings" - "sync" "github.com/liamg/memoryfs" @@ -27,18 +26,18 @@ var _ scanners.FSScanner = (*Scanner)(nil) var _ options.ConfigurableScanner = (*Scanner)(nil) type Scanner struct { - mu sync.Mutex + *rego.RegoScannerProvider logger *log.Logger options []options.ScannerOption parserOptions []parser.Option - regoScanner *rego.Scanner } // New creates a new Scanner func New(opts ...options.ScannerOption) *Scanner { s := &Scanner{ - options: opts, - logger: log.WithPrefix("helm scanner"), + RegoScannerProvider: rego.NewRegoScannerProvider(opts...), + options: opts, + logger: log.WithPrefix("helm scanner"), } for _, option := range opts { @@ -56,11 +55,6 @@ func (s *Scanner) Name() string { } func (s *Scanner) ScanFS(ctx context.Context, target fs.FS, path string) (scan.Results, error) { - - if err := s.initRegoScanner(target); err != nil { - return nil, fmt.Errorf("failed to init rego scanner: %w", err) - } - var results []scan.Result if err := fs.WalkDir(target, path, func(path string, d fs.DirEntry, err error) error { select { @@ -122,6 +116,11 @@ func (s *Scanner) getScanResults(path string, ctx context.Context, target fs.FS) return nil, nil } + rs, err := s.InitRegoScanner(target, s.options) + if err != nil { + return nil, fmt.Errorf("init rego scanner: %w", err) + } + for _, file := range chartFiles { file := file s.logger.Debug("Processing rendered chart file", log.FilePath(file.TemplateFilePath)) @@ -132,7 +131,7 @@ func (s *Scanner) getScanResults(path string, ctx context.Context, target fs.FS) return nil, fmt.Errorf("unmarshal yaml: %w", err) } for _, manifest := range manifests { - fileResults, err := s.regoScanner.ScanInput(ctx, types.SourceKubernetes, rego.Input{ + fileResults, err := rs.ScanInput(ctx, types.SourceKubernetes, rego.Input{ Path: file.TemplateFilePath, Contents: manifest, FS: target, @@ -161,17 +160,3 @@ func (s *Scanner) getScanResults(path string, ctx context.Context, target fs.FS) } return results, nil } - -func (s *Scanner) initRegoScanner(srcFS fs.FS) error { - s.mu.Lock() - defer s.mu.Unlock() - if s.regoScanner != nil { - return nil - } - regoScanner := rego.NewScanner(s.options...) - if err := regoScanner.LoadPolicies(srcFS); err != nil { - return err - } - s.regoScanner = regoScanner - return nil -} diff --git a/pkg/iac/scanners/scanner.go b/pkg/iac/scanners/scanner.go index e792545b9e97..0702726b119a 100644 --- a/pkg/iac/scanners/scanner.go +++ b/pkg/iac/scanners/scanner.go @@ -3,15 +3,10 @@ package scanners import ( "context" "io/fs" - "os" "github.com/aquasecurity/trivy/pkg/iac/scan" ) -type WriteFileFS interface { - WriteFile(name string, data []byte, perm os.FileMode) error -} - type FSScanner interface { // Name provides the human-readable name of the scanner e.g. "CloudFormation" Name() string diff --git a/pkg/iac/scanners/terraform/scanner.go b/pkg/iac/scanners/terraform/scanner.go index 21735bff1a52..5975facab448 100644 --- a/pkg/iac/scanners/terraform/scanner.go +++ b/pkg/iac/scanners/terraform/scanner.go @@ -26,14 +26,13 @@ var _ options.ConfigurableScanner = (*Scanner)(nil) var _ ConfigurableTerraformScanner = (*Scanner)(nil) type Scanner struct { - mu sync.Mutex + *rego.RegoScannerProvider logger *log.Logger options []options.ScannerOption parserOpt []parser.Option executorOpt []executor.Option dirs set.Set[string] forceAllDirs bool - regoScanner *rego.Scanner execLock sync.RWMutex } @@ -55,9 +54,10 @@ func (s *Scanner) AddExecutorOptions(opts ...executor.Option) { func New(opts ...options.ScannerOption) *Scanner { s := &Scanner{ - dirs: set.New[string](), - options: opts, - logger: log.WithPrefix("terraform scanner"), + RegoScannerProvider: rego.NewRegoScannerProvider(opts...), + dirs: set.New[string](), + options: opts, + logger: log.WithPrefix("terraform scanner"), } for _, opt := range opts { opt(s) @@ -65,20 +65,6 @@ func New(opts ...options.ScannerOption) *Scanner { return s } -func (s *Scanner) initRegoScanner(srcFS fs.FS) (*rego.Scanner, error) { - s.mu.Lock() - defer s.mu.Unlock() - if s.regoScanner != nil { - return s.regoScanner, nil - } - regoScanner := rego.NewScanner(s.options...) - if err := regoScanner.LoadPolicies(srcFS); err != nil { - return nil, err - } - s.regoScanner = regoScanner - return regoScanner, nil -} - // terraformRootModule represents the module to be used as the root module for Terraform deployment. type terraformRootModule struct { rootPath string @@ -99,13 +85,13 @@ func (s *Scanner) ScanFS(ctx context.Context, target fs.FS, dir string) (scan.Re return nil, nil } - regoScanner, err := s.initRegoScanner(target) + rs, err := s.InitRegoScanner(target, s.options) if err != nil { - return nil, err + return nil, fmt.Errorf("init rego scanner: %w", err) } s.execLock.Lock() - s.executorOpt = append(s.executorOpt, executor.OptionWithRegoScanner(regoScanner)) + s.executorOpt = append(s.executorOpt, executor.OptionWithRegoScanner(rs)) s.execLock.Unlock() var allResults scan.Results diff --git a/pkg/iac/scanners/terraformplan/snapshot/scanner.go b/pkg/iac/scanners/terraformplan/snapshot/scanner.go index 3c8dcc8fce0b..654c9dcaa078 100644 --- a/pkg/iac/scanners/terraformplan/snapshot/scanner.go +++ b/pkg/iac/scanners/terraformplan/snapshot/scanner.go @@ -9,12 +9,12 @@ import ( "github.com/aquasecurity/trivy/pkg/iac/scan" "github.com/aquasecurity/trivy/pkg/iac/scanners/options" - terraformScanner "github.com/aquasecurity/trivy/pkg/iac/scanners/terraform" + tfscanner "github.com/aquasecurity/trivy/pkg/iac/scanners/terraform" tfparser "github.com/aquasecurity/trivy/pkg/iac/scanners/terraform/parser" ) type Scanner struct { - inner *terraformScanner.Scanner + inner *tfscanner.Scanner } func (s *Scanner) Name() string { @@ -23,7 +23,7 @@ func (s *Scanner) Name() string { func New(opts ...options.ScannerOption) *Scanner { scanner := &Scanner{ - inner: terraformScanner.New(opts...), + inner: tfscanner.New(opts...), } return scanner } diff --git a/pkg/iac/scanners/terraformplan/tfjson/scanner.go b/pkg/iac/scanners/terraformplan/tfjson/scanner.go index b3e8725c42ae..7fe6969beef0 100644 --- a/pkg/iac/scanners/terraformplan/tfjson/scanner.go +++ b/pkg/iac/scanners/terraformplan/tfjson/scanner.go @@ -14,10 +14,10 @@ import ( ) type Scanner struct { - parser *parser.Parser - logger *log.Logger - options []options.ScannerOption - tfScanner *terraform.Scanner + inner *terraform.Scanner + parser *parser.Parser + logger *log.Logger + options []options.ScannerOption } func (s *Scanner) Name() string { @@ -55,10 +55,10 @@ func (s *Scanner) ScanFS(ctx context.Context, fsys fs.FS, dir string) (scan.Resu func New(opts ...options.ScannerOption) *Scanner { scanner := &Scanner{ - options: opts, - logger: log.WithPrefix("tfjson scanner"), - parser: parser.New(), - tfScanner: terraform.New(opts...), + inner: terraform.New(opts...), + parser: parser.New(), + logger: log.WithPrefix("tfjson scanner"), + options: opts, } return scanner @@ -87,5 +87,5 @@ func (s *Scanner) Scan(reader io.Reader) (scan.Results, error) { return nil, fmt.Errorf("failed to convert plan to FS: %w", err) } - return s.tfScanner.ScanFS(context.TODO(), planFS, ".") + return s.inner.ScanFS(context.TODO(), planFS, ".") } diff --git a/pkg/misconf/scanner.go b/pkg/misconf/scanner.go index a32ef5c105e0..aca664fc2c54 100644 --- a/pkg/misconf/scanner.go +++ b/pkg/misconf/scanner.go @@ -103,6 +103,13 @@ func NewScanner(t detection.FileType, opt ScannerOption) (*Scanner, error) { return nil, err } + rs := rego.NewScanner(opts...) + if err := rs.LoadPolicies(nil); err != nil { + return nil, xerrors.Errorf("load checks: %w", err) + } + + opts = append(opts, rego.WithRegoScanner(rs)) + var scanner scanners.FSScanner switch t { case detection.FileTypeAzureARM: