Skip to content

Commit

Permalink
feat(misconf): add option to pass Rego scanner to IaC scanner
Browse files Browse the repository at this point in the history
Signed-off-by: nikpivkin <[email protected]>
  • Loading branch information
nikpivkin committed Feb 7, 2025
1 parent e8c1d45 commit db15a17
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 151 deletions.
43 changes: 43 additions & 0 deletions pkg/iac/rego/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package rego

import (
"fmt"
"io/fs"
"sync"

"github.com/aquasecurity/trivy/pkg/iac/scanners/options"
)

func WithRegoScanner(rs *Scanner) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*RegoScannerProvider); ok {
ss.regoScanner = rs
}
}
}

type RegoScannerProvider struct {
mu sync.Mutex
regoScanner *Scanner
}

func NewRegoScannerProvider(opts ...options.ScannerOption) *RegoScannerProvider {
s := &RegoScannerProvider{}
for _, o := range opts {
o(s)
}
return s
}

func (s *RegoScannerProvider) InitRegoScanner(fsys fs.FS, opts []options.ScannerOption) (*Scanner, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.regoScanner != nil {
return s.regoScanner, nil
}
s.regoScanner = NewScanner(opts...)
if err := s.regoScanner.LoadPolicies(fsys); err != nil {
return nil, fmt.Errorf("load checks: %w", err)
}
return s.regoScanner, nil
}
1 change: 1 addition & 0 deletions pkg/iac/rego/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ func NewScanner(opts ...options.ScannerOption) *Scanner {

func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, disableTracing bool) (rego.ResultSet, []string, error) {

// TODO: (s.traceWriter != nil && s.tracePerResult) && !disableTracing
trace := (s.traceWriter != nil || s.tracePerResult) && !disableTracing

regoOptions := []func(*rego.Rego){
Expand Down
46 changes: 14 additions & 32 deletions pkg/iac/scanners/azure/arm/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"io/fs"
"sync"

"github.com/aquasecurity/trivy/pkg/iac/adapters/arm"
"github.com/aquasecurity/trivy/pkg/iac/rego"
Expand All @@ -13,7 +12,6 @@ import (
"github.com/aquasecurity/trivy/pkg/iac/scanners/azure"
"github.com/aquasecurity/trivy/pkg/iac/scanners/azure/arm/parser"
"github.com/aquasecurity/trivy/pkg/iac/scanners/options"
"github.com/aquasecurity/trivy/pkg/iac/state"
"github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/aquasecurity/trivy/pkg/log"
)
Expand All @@ -22,16 +20,16 @@ var _ scanners.FSScanner = (*Scanner)(nil)
var _ options.ConfigurableScanner = (*Scanner)(nil)

type Scanner struct {
mu sync.Mutex
scannerOptions []options.ScannerOption
logger *log.Logger
regoScanner *rego.Scanner
*rego.RegoScannerProvider
opts []options.ScannerOption
logger *log.Logger
}

func New(opts ...options.ScannerOption) *Scanner {
scanner := &Scanner{
scannerOptions: opts,
logger: log.WithPrefix("azure-arm"),
RegoScannerProvider: rego.NewRegoScannerProvider(opts...),
opts: opts,
logger: log.WithPrefix("azure-arm"),
}
for _, opt := range opts {
opt(scanner)
Expand All @@ -43,29 +41,12 @@ func (s *Scanner) Name() string {
return "Azure ARM"
}

func (s *Scanner) initRegoScanner(srcFS fs.FS) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.regoScanner != nil {
return nil
}
regoScanner := rego.NewScanner(s.scannerOptions...)
if err := regoScanner.LoadPolicies(srcFS); err != nil {
return err
}
s.regoScanner = regoScanner
return nil
}

func (s *Scanner) ScanFS(ctx context.Context, fsys fs.FS, dir string) (scan.Results, error) {
p := parser.New(fsys)
deployments, err := p.ParseFS(ctx, dir)
if err != nil {
return nil, err
}
if err := s.initRegoScanner(fsys); err != nil {
return nil, err
}

return s.scanDeployments(ctx, deployments, fsys)
}
Expand All @@ -87,20 +68,21 @@ func (s *Scanner) scanDeployments(ctx context.Context, deployments []azure.Deplo
}

func (s *Scanner) scanDeployment(ctx context.Context, deployment azure.Deployment, fsys fs.FS) (scan.Results, error) {
deploymentState := s.adaptDeployment(ctx, deployment)
state := arm.Adapt(ctx, deployment)

results, err := s.regoScanner.ScanInput(ctx, types.SourceCloud, rego.Input{
rs, err := s.InitRegoScanner(fsys, s.opts)
if err != nil {
return nil, fmt.Errorf("init rego scanner: %w", err)
}

results, err := rs.ScanInput(ctx, types.SourceCloud, rego.Input{
Path: deployment.Metadata.Range().GetFilename(),
FS: fsys,
Contents: deploymentState.ToRego(),
Contents: state.ToRego(),
})
if err != nil {
return nil, fmt.Errorf("rego scan error: %w", err)
}

return results, nil
}

func (s *Scanner) adaptDeployment(ctx context.Context, deployment azure.Deployment) *state.State {
return arm.Adapt(ctx, deployment)
}
35 changes: 10 additions & 25 deletions pkg/iac/scanners/cloudformation/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"io/fs"
"sort"
"sync"

adapter "github.com/aquasecurity/trivy/pkg/iac/adapters/cloudformation"
"github.com/aquasecurity/trivy/pkg/iac/rego"
Expand Down Expand Up @@ -45,10 +44,9 @@ var _ scanners.FSScanner = (*Scanner)(nil)
var _ options.ConfigurableScanner = (*Scanner)(nil)

type Scanner struct {
mu sync.Mutex
*rego.RegoScannerProvider
logger *log.Logger
parser *parser.Parser
regoScanner *rego.Scanner
options []options.ScannerOption
parserOptions []parser.Option
}
Expand All @@ -64,8 +62,9 @@ func (s *Scanner) Name() string {
// New creates a new Scanner
func New(opts ...options.ScannerOption) *Scanner {
s := &Scanner{
options: opts,
logger: log.WithPrefix("cloudformation scanner"),
RegoScannerProvider: rego.NewRegoScannerProvider(opts...),
options: opts,
logger: log.WithPrefix("cloudformation scanner"),
}
for _, opt := range opts {
opt(s)
Expand All @@ -74,20 +73,6 @@ func New(opts ...options.ScannerOption) *Scanner {
return s
}

func (s *Scanner) initRegoScanner(srcFS fs.FS) (*rego.Scanner, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.regoScanner != nil {
return s.regoScanner, nil
}
regoScanner := rego.NewScanner(s.options...)
if err := regoScanner.LoadPolicies(srcFS); err != nil {
return nil, err
}
s.regoScanner = regoScanner
return regoScanner, nil
}

func (s *Scanner) ScanFS(ctx context.Context, fsys fs.FS, dir string) (results scan.Results, err error) {

contexts, err := s.parser.ParseFS(ctx, fsys, dir)
Expand All @@ -99,16 +84,16 @@ func (s *Scanner) ScanFS(ctx context.Context, fsys fs.FS, dir string) (results s
return nil, nil
}

regoScanner, err := s.initRegoScanner(fsys)
rs, err := s.InitRegoScanner(fsys, s.options)
if err != nil {
return nil, err
return nil, fmt.Errorf("init rego scanner: %w", err)
}

for _, cfCtx := range contexts {
if cfCtx == nil {
continue
}
fileResults, err := s.scanFileContext(ctx, regoScanner, cfCtx, fsys)
fileResults, err := s.scanFileContext(ctx, rs, cfCtx, fsys)
if err != nil {
return nil, err
}
Expand All @@ -127,12 +112,12 @@ func (s *Scanner) ScanFile(ctx context.Context, fsys fs.FS, path string) (scan.R
return nil, err
}

regoScanner, err := s.initRegoScanner(fsys)
rs, err := s.InitRegoScanner(fsys, s.options)
if err != nil {
return nil, err
return nil, fmt.Errorf("init rego scanner: %w", err)
}

results, err := s.scanFileContext(ctx, regoScanner, cfCtx, fsys)
results, err := s.scanFileContext(ctx, rs, cfCtx, fsys)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/iac/scanners/dockerfile/scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ COPY --from=dep /binary /`

results, err := scanner.ScanFS(context.TODO(), fsys, "code")
if tc.expectedError != "" && err != nil {
require.Equal(t, tc.expectedError, err.Error(), tc.name)
require.ErrorContainsf(t, err, tc.expectedError, tc.name)
} else {
require.NoError(t, err)
require.Len(t, results.GetFailed(), 1)
Expand Down
43 changes: 14 additions & 29 deletions pkg/iac/scanners/generic/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"io/fs"
"path/filepath"
"strings"
"sync"

"github.com/BurntSushi/toml"
"github.com/samber/lo"
Expand Down Expand Up @@ -41,12 +40,11 @@ type configParser interface {

// GenericScanner is a scanner that scans a file as is without processing it
type GenericScanner struct {
mu sync.Mutex
name string
source types.Source
logger *log.Logger
options []options.ScannerOption
regoScanner *rego.Scanner
*rego.RegoScannerProvider
name string
source types.Source
logger *log.Logger
options []options.ScannerOption

parser configParser
}
Expand All @@ -59,11 +57,12 @@ func (f ParseFunc) Parse(ctx context.Context, r io.Reader, path string) (any, er

func NewScanner(name string, source types.Source, parser configParser, opts ...options.ScannerOption) *GenericScanner {
s := &GenericScanner{
name: name,
options: opts,
source: source,
logger: log.WithPrefix(fmt.Sprintf("%s scanner", source)),
parser: parser,
RegoScannerProvider: rego.NewRegoScannerProvider(opts...),
name: name,
options: opts,
source: source,
logger: log.WithPrefix(fmt.Sprintf("%s scanner", source)),
parser: parser,
}

for _, opt := range opts {
Expand Down Expand Up @@ -113,13 +112,13 @@ func (s *GenericScanner) ScanFS(ctx context.Context, fsys fs.FS, dir string) (sc
}
}

regoScanner, err := s.initRegoScanner(fsys)
rs, err := s.InitRegoScanner(fsys, s.options)
if err != nil {
return nil, err
return nil, fmt.Errorf("init rego scanner: %w", err)
}

s.logger.Debug("Scanning files...", log.Int("count", len(inputs)))
results, err := regoScanner.ScanInput(ctx, s.source, inputs...)
results, err := rs.ScanInput(ctx, s.source, inputs...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -170,20 +169,6 @@ func (s *GenericScanner) parseFS(ctx context.Context, fsys fs.FS, path string) (
return files, nil
}

func (s *GenericScanner) initRegoScanner(srcFS fs.FS) (*rego.Scanner, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.regoScanner != nil {
return s.regoScanner, nil
}
regoScanner := rego.NewScanner(s.options...)
if err := regoScanner.LoadPolicies(srcFS); err != nil {
return nil, err
}
s.regoScanner = regoScanner
return regoScanner, nil
}

func (s *GenericScanner) applyIgnoreRules(fsys fs.FS, results scan.Results) error {
if !s.supportsIgnoreRules() {
return nil
Expand Down
Loading

0 comments on commit db15a17

Please sign in to comment.