From 4be02bab8c45f4f387e32b7a3f718ea7a56d2192 Mon Sep 17 00:00:00 2001 From: Teppei Fukuda Date: Thu, 27 Jun 2024 11:04:01 +0400 Subject: [PATCH] refactor: use google/wire for cache (#7024) Signed-off-by: knqyf263 --- pkg/cache/cache.go | 2 +- pkg/cache/client.go | 167 ++++++------------------------ pkg/cache/client_test.go | 150 +++++++++++++-------------- pkg/cache/fs.go | 7 +- pkg/cache/fs_test.go | 3 +- pkg/cache/redis.go | 117 +++++++++++++++++++-- pkg/cache/redis_test.go | 91 ++++++++-------- pkg/cache/remote.go | 16 ++- pkg/cache/remote_test.go | 24 ++++- pkg/commands/artifact/inject.go | 37 +++---- pkg/commands/artifact/run.go | 67 +++--------- pkg/commands/artifact/scanner.go | 29 +++--- pkg/commands/artifact/wire_gen.go | 122 +++++++++++++++------- pkg/commands/clean/run.go | 4 +- pkg/commands/clean/run_test.go | 4 + pkg/commands/server/run.go | 5 +- pkg/flag/cache_flags.go | 29 +++--- pkg/flag/global_flags.go | 3 + pkg/flag/options.go | 23 ++++ pkg/k8s/wire_gen.go | 2 +- pkg/rpc/server/wire_gen.go | 2 +- pkg/scanner/scan.go | 10 ++ 22 files changed, 496 insertions(+), 418 deletions(-) diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index b2f5fa704ae7..1280c84fe156 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -5,7 +5,7 @@ import ( ) const ( - cacheDirName = "fanal" + scanCacheDirName = "fanal" // artifactBucket stores artifact information with artifact ID such as image ID artifactBucket = "artifact" diff --git a/pkg/cache/client.go b/pkg/cache/client.go index ab9dd4799428..46bced1771aa 100644 --- a/pkg/cache/client.go +++ b/pkg/cache/client.go @@ -1,166 +1,65 @@ package cache import ( - "crypto/tls" - "crypto/x509" - "fmt" - "os" "strings" "time" - "github.com/go-redis/redis/v8" - "github.com/samber/lo" "golang.org/x/xerrors" - - "github.com/aquasecurity/trivy/pkg/log" ) const ( - TypeFS Type = "fs" - TypeRedis Type = "redis" + TypeUnknown Type = "unknown" + TypeFS Type = "fs" + TypeRedis Type = "redis" ) type Type string type Options struct { - Type Type - TTL time.Duration - Redis RedisOptions -} - -func NewOptions(backend, redisCACert, redisCert, redisKey string, redisTLS bool, ttl time.Duration) (Options, error) { - t, err := NewType(backend) - if err != nil { - return Options{}, xerrors.Errorf("cache type error: %w", err) - } - - var redisOpts RedisOptions - if t == TypeRedis { - redisTLSOpts, err := NewRedisTLSOptions(redisCACert, redisCert, redisKey) - if err != nil { - return Options{}, xerrors.Errorf("redis TLS option error: %w", err) - } - redisOpts = RedisOptions{ - Backend: backend, - TLS: redisTLS, - TLSOptions: redisTLSOpts, - } - } else if ttl != 0 { - log.Warn("'--cache-ttl' is only available with Redis cache backend") - } - - return Options{ - Type: t, - TTL: ttl, - Redis: redisOpts, - }, nil + Backend string + CacheDir string + RedisCACert string + RedisCert string + RedisKey string + RedisTLS bool + TTL time.Duration } -type RedisOptions struct { - Backend string - TLS bool - TLSOptions RedisTLSOptions -} - -// BackendMasked returns the redis connection string masking credentials -func (o *RedisOptions) BackendMasked() string { - endIndex := strings.Index(o.Backend, "@") - if endIndex == -1 { - return o.Backend - } - - startIndex := strings.Index(o.Backend, "//") - - return fmt.Sprintf("%s****%s", o.Backend[:startIndex+2], o.Backend[endIndex:]) -} - -// RedisTLSOptions holds the options for redis cache -type RedisTLSOptions struct { - CACert string - Cert string - Key string -} - -func NewRedisTLSOptions(caCert, cert, key string) (RedisTLSOptions, error) { - opts := RedisTLSOptions{ - CACert: caCert, - Cert: cert, - Key: key, - } - - // If one of redis option not nil, make sure CA, cert, and key provided - if !lo.IsEmpty(opts) { - if opts.CACert == "" || opts.Cert == "" || opts.Key == "" { - return RedisTLSOptions{}, xerrors.Errorf("you must provide Redis CA, cert and key file path when using TLS") - } - } - return opts, nil -} - -func NewType(backend string) (Type, error) { +func NewType(backend string) Type { // "redis://" or "fs" are allowed for now // An empty value is also allowed for testability switch { case strings.HasPrefix(backend, "redis://"): - return TypeRedis, nil + return TypeRedis case backend == "fs", backend == "": - return TypeFS, nil + return TypeFS default: - return "", xerrors.Errorf("unknown cache backend: %s", backend) + return TypeUnknown } } // New returns a new cache client -func New(dir string, opts Options) (Cache, error) { - if opts.Type == TypeRedis { - log.Info("Redis cache", log.String("url", opts.Redis.BackendMasked())) - options, err := redis.ParseURL(opts.Redis.Backend) +func New(opts Options) (Cache, func(), error) { + cleanup := func() {} // To avoid panic + + var cache Cache + t := NewType(opts.Backend) + switch t { + case TypeRedis: + redisCache, err := NewRedisCache(opts.Backend, opts.RedisCACert, opts.RedisCert, opts.RedisKey, opts.RedisTLS, opts.TTL) if err != nil { - return nil, err + return nil, cleanup, xerrors.Errorf("unable to initialize redis cache: %w", err) } - - if tlsOpts := opts.Redis.TLSOptions; !lo.IsEmpty(tlsOpts) { - caCert, cert, err := GetTLSConfig(tlsOpts.CACert, tlsOpts.Cert, tlsOpts.Key) - if err != nil { - return nil, err - } - - options.TLSConfig = &tls.Config{ - RootCAs: caCert, - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, - } - } else if opts.Redis.TLS { - options.TLSConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - } + cache = redisCache + case TypeFS: + // standalone mode + fsCache, err := NewFSCache(opts.CacheDir) + if err != nil { + return nil, cleanup, xerrors.Errorf("unable to initialize fs cache: %w", err) } - - return NewRedisCache(options, opts.TTL), nil - } - - // standalone mode - fsCache, err := NewFSCache(dir) - if err != nil { - return nil, xerrors.Errorf("unable to initialize fs cache: %w", err) - } - return fsCache, nil -} - -// GetTLSConfig gets tls config from CA, Cert and Key file -func GetTLSConfig(caCertPath, certPath, keyPath string) (*x509.CertPool, tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(certPath, keyPath) - if err != nil { - return nil, tls.Certificate{}, err - } - - caCert, err := os.ReadFile(caCertPath) - if err != nil { - return nil, tls.Certificate{}, err + cache = fsCache + default: + return nil, cleanup, xerrors.Errorf("unknown cache type: %s", t) } - - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - - return caCertPool, cert, nil + return cache, func() { _ = cache.Close() }, nil } diff --git a/pkg/cache/client_test.go b/pkg/cache/client_test.go index f22ce4f93e2f..c72eb3de4d13 100644 --- a/pkg/cache/client_test.go +++ b/pkg/cache/client_test.go @@ -2,7 +2,6 @@ package cache_test import ( "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -10,120 +9,113 @@ import ( "github.com/aquasecurity/trivy/pkg/cache" ) -func TestNewOptions(t *testing.T) { - type args struct { - backend string - redisCACert string - redisCert string - redisKey string - redisTLS bool - ttl time.Duration - } +func TestNew(t *testing.T) { tests := []struct { - name string - args args - want cache.Options - assertion require.ErrorAssertionFunc + name string + opts cache.Options + wantType any + wantErr string }{ { - name: "fs", - args: args{backend: "fs"}, - want: cache.Options{Type: cache.TypeFS}, - assertion: require.NoError, + name: "fs backend", + opts: cache.Options{ + Backend: "fs", + CacheDir: "/tmp/cache", + }, + wantType: cache.FSCache{}, }, { - name: "redis", - args: args{backend: "redis://localhost:6379"}, - want: cache.Options{ - Type: cache.TypeRedis, - Redis: cache.RedisOptions{Backend: "redis://localhost:6379"}, + name: "redis backend", + opts: cache.Options{ + Backend: "redis://localhost:6379", }, - assertion: require.NoError, + wantType: cache.RedisCache{}, }, { - name: "redis tls", - args: args{ - backend: "redis://localhost:6379", - redisCACert: "ca-cert.pem", - redisCert: "cert.pem", - redisKey: "key.pem", - }, - want: cache.Options{ - Type: cache.TypeRedis, - Redis: cache.RedisOptions{ - Backend: "redis://localhost:6379", - TLSOptions: cache.RedisTLSOptions{ - CACert: "ca-cert.pem", - Cert: "cert.pem", - Key: "key.pem", - }, - }, + name: "unknown backend", + opts: cache.Options{ + Backend: "unknown", }, - assertion: require.NoError, + wantErr: "unknown cache type", }, { - name: "redis tls with public certificates", - args: args{ - backend: "redis://localhost:6379", - redisTLS: true, + name: "invalid redis URL", + opts: cache.Options{ + Backend: "redis://invalid-url:foo/bar", }, - want: cache.Options{ - Type: cache.TypeRedis, - Redis: cache.RedisOptions{ - Backend: "redis://localhost:6379", - TLS: true, - }, - }, - assertion: require.NoError, + wantErr: "failed to parse Redis URL", }, { - name: "unknown backend", - args: args{backend: "unknown"}, - assertion: func(t require.TestingT, err error, msgs ...any) { - require.ErrorContains(t, err, "unknown cache backend") + name: "incomplete TLS options", + opts: cache.Options{ + Backend: "redis://localhost:6379", + RedisCACert: "testdata/ca-cert.pem", + RedisTLS: true, }, + wantErr: "you must provide Redis CA, cert and key file path when using TLS", }, { - name: "sad redis tls", - args: args{ - backend: "redis://localhost:6379", - redisCACert: "ca-cert.pem", - }, - assertion: func(t require.TestingT, err error, msgs ...any) { - require.ErrorContains(t, err, "you must provide Redis CA, cert and key file path when using TLS") + name: "invalid TLS file paths", + opts: cache.Options{ + Backend: "redis://localhost:6379", + RedisCACert: "testdata/non-existent-ca-cert.pem", + RedisCert: "testdata/non-existent-cert.pem", + RedisKey: "testdata/non-existent-key.pem", + RedisTLS: true, }, + wantErr: "failed to get TLS config", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := cache.NewOptions(tt.args.backend, tt.args.redisCACert, tt.args.redisCert, tt.args.redisKey, tt.args.redisTLS, tt.args.ttl) - tt.assertion(t, err) - assert.Equal(t, tt.want, got) + c, cleanup, err := cache.New(tt.opts) + defer cleanup() + + if tt.wantErr != "" { + assert.ErrorContains(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + assert.NotNil(t, c) + assert.IsType(t, tt.wantType, c) }) } } -func TestRedisOptions_BackendMasked(t *testing.T) { +func TestNewType(t *testing.T) { tests := []struct { - name string - fields cache.RedisOptions - want string + name string + backend string + wantType cache.Type }{ { - name: "redis cache backend masked", - fields: cache.RedisOptions{Backend: "redis://root:password@localhost:6379"}, - want: "redis://****@localhost:6379", + name: "redis backend", + backend: "redis://localhost:6379", + wantType: cache.TypeRedis, + }, + { + name: "fs backend", + backend: "fs", + wantType: cache.TypeFS, }, { - name: "redis cache backend masked does nothing", - fields: cache.RedisOptions{Backend: "redis://localhost:6379"}, - want: "redis://localhost:6379", + name: "empty backend", + backend: "", + wantType: cache.TypeFS, + }, + { + name: "unknown backend", + backend: "unknown", + wantType: cache.TypeUnknown, }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.want, tt.fields.BackendMasked()) + got := cache.NewType(tt.backend) + assert.Equal(t, tt.wantType, got) }) } } diff --git a/pkg/cache/fs.go b/pkg/cache/fs.go index 08fa696e6555..edfac70b04e5 100644 --- a/pkg/cache/fs.go +++ b/pkg/cache/fs.go @@ -20,7 +20,7 @@ type FSCache struct { } func NewFSCache(cacheDir string) (FSCache, error) { - dir := filepath.Join(cacheDir, cacheDirName) + dir := filepath.Join(cacheDir, scanCacheDirName) if err := os.MkdirAll(dir, 0700); err != nil { return FSCache{}, xerrors.Errorf("failed to create cache dir: %w", err) } @@ -31,7 +31,10 @@ func NewFSCache(cacheDir string) (FSCache, error) { } err = db.Update(func(tx *bolt.Tx) error { - for _, bucket := range []string{artifactBucket, blobBucket} { + for _, bucket := range []string{ + artifactBucket, + blobBucket, + } { if _, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil { return xerrors.Errorf("unable to create %s bucket: %w", bucket, err) } diff --git a/pkg/cache/fs_test.go b/pkg/cache/fs_test.go index 4eb059f5c508..9323391a3af4 100644 --- a/pkg/cache/fs_test.go +++ b/pkg/cache/fs_test.go @@ -373,7 +373,7 @@ func TestFSCache_PutArtifact(t *testing.T) { require.NoError(t, err, tt.name) } - fs.db.View(func(tx *bolt.Tx) error { + err = fs.db.View(func(tx *bolt.Tx) error { // check decompressedDigestBucket imageBucket := tx.Bucket([]byte(artifactBucket)) b := imageBucket.Get([]byte(tt.args.imageID)) @@ -381,6 +381,7 @@ func TestFSCache_PutArtifact(t *testing.T) { return nil }) + require.NoError(t, err) }) } } diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index af9d2622b531..2a4a12bda3f7 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -2,33 +2,118 @@ package cache import ( "context" + "crypto/tls" + "crypto/x509" "encoding/json" "fmt" + "os" + "strings" "time" "github.com/go-redis/redis/v8" "github.com/hashicorp/go-multierror" + "github.com/samber/lo" "golang.org/x/xerrors" "github.com/aquasecurity/trivy/pkg/fanal/types" + "github.com/aquasecurity/trivy/pkg/log" ) -var _ Cache = &RedisCache{} +var _ Cache = (*RedisCache)(nil) -const ( - redisPrefix = "fanal" -) +const redisPrefix = "fanal" + +type RedisOptions struct { + Backend string + TLS bool + TLSOptions RedisTLSOptions +} + +func NewRedisOptions(backend, caCert, cert, key string, enableTLS bool) (RedisOptions, error) { + tlsOpts, err := NewRedisTLSOptions(caCert, cert, key) + if err != nil { + return RedisOptions{}, xerrors.Errorf("redis TLS option error: %w", err) + } + + return RedisOptions{ + Backend: backend, + TLS: enableTLS, + TLSOptions: tlsOpts, + }, nil +} + +// BackendMasked returns the redis connection string masking credentials +func (o *RedisOptions) BackendMasked() string { + endIndex := strings.Index(o.Backend, "@") + if endIndex == -1 { + return o.Backend + } + + startIndex := strings.Index(o.Backend, "//") + + return fmt.Sprintf("%s****%s", o.Backend[:startIndex+2], o.Backend[endIndex:]) +} + +// RedisTLSOptions holds the options for redis cache +type RedisTLSOptions struct { + CACert string + Cert string + Key string +} + +func NewRedisTLSOptions(caCert, cert, key string) (RedisTLSOptions, error) { + opts := RedisTLSOptions{ + CACert: caCert, + Cert: cert, + Key: key, + } + + // If one of redis option not nil, make sure CA, cert, and key provided + if !lo.IsEmpty(opts) { + if opts.CACert == "" || opts.Cert == "" || opts.Key == "" { + return RedisTLSOptions{}, xerrors.Errorf("you must provide Redis CA, cert and key file path when using TLS") + } + } + return opts, nil +} type RedisCache struct { client *redis.Client expiration time.Duration } -func NewRedisCache(options *redis.Options, expiration time.Duration) RedisCache { +func NewRedisCache(backend, caCertPath, certPath, keyPath string, enableTLS bool, ttl time.Duration) (RedisCache, error) { + opts, err := NewRedisOptions(backend, caCertPath, certPath, keyPath, enableTLS) + if err != nil { + return RedisCache{}, xerrors.Errorf("failed to create Redis options: %w", err) + } + + log.Info("Redis scan cache", log.String("url", opts.BackendMasked())) + options, err := redis.ParseURL(opts.Backend) + if err != nil { + return RedisCache{}, xerrors.Errorf("failed to parse Redis URL: %w", err) + } + + if tlsOpts := opts.TLSOptions; !lo.IsEmpty(tlsOpts) { + caCert, cert, err := GetTLSConfig(tlsOpts.CACert, tlsOpts.Cert, tlsOpts.Key) + if err != nil { + return RedisCache{}, xerrors.Errorf("failed to get TLS config: %w", err) + } + + options.TLSConfig = &tls.Config{ + RootCAs: caCert, + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + } + } else if opts.TLS { + options.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } + } return RedisCache{ client: redis.NewClient(options), - expiration: expiration, - } + expiration: ttl, + }, nil } func (c RedisCache) PutArtifact(artifactID string, artifactConfig types.ArtifactInfo) error { @@ -145,3 +230,21 @@ func (c RedisCache) Clear() error { } return nil } + +// GetTLSConfig gets tls config from CA, Cert and Key file +func GetTLSConfig(caCertPath, certPath, keyPath string) (*x509.CertPool, tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, tls.Certificate{}, err + } + + caCert, err := os.ReadFile(caCertPath) + if err != nil { + return nil, tls.Certificate{}, err + } + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + return caCertPool, cert, nil +} diff --git a/pkg/cache/redis_test.go b/pkg/cache/redis_test.go index 46716a7d7bfe..3cc8bbd702ad 100644 --- a/pkg/cache/redis_test.go +++ b/pkg/cache/redis_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/alicebob/miniredis/v2" - "github.com/go-redis/redis/v8" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -67,18 +66,15 @@ func TestRedisCache_PutArtifact(t *testing.T) { addr = "dummy:16379" } - c := cache.NewRedisCache(&redis.Options{ - Addr: addr, - }, 0) + c, err := cache.NewRedisCache(fmt.Sprintf("redis://%s", addr), "", "", "", false, 0) + require.NoError(t, err) err = c.PutArtifact(tt.args.artifactID, tt.args.artifactConfig) if tt.wantErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantErr) + require.ErrorContains(t, err, tt.wantErr) return - } else { - require.NoError(t, err) } + require.NoError(t, err) got, err := s.Get(tt.wantKey) require.NoError(t, err) @@ -156,18 +152,15 @@ func TestRedisCache_PutBlob(t *testing.T) { addr = "dummy:16379" } - c := cache.NewRedisCache(&redis.Options{ - Addr: addr, - }, 0) + c, err := cache.NewRedisCache(fmt.Sprintf("redis://%s", addr), "", "", "", false, 0) + require.NoError(t, err) err = c.PutBlob(tt.args.blobID, tt.args.blobConfig) if tt.wantErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantErr) + require.ErrorContains(t, err, tt.wantErr) return - } else { - require.NoError(t, err) } + require.NoError(t, err) got, err := s.Get(tt.wantKey) require.NoError(t, err) @@ -241,18 +234,15 @@ func TestRedisCache_GetArtifact(t *testing.T) { addr = "dummy:16379" } - c := cache.NewRedisCache(&redis.Options{ - Addr: addr, - }, 0) + c, err := cache.NewRedisCache(fmt.Sprintf("redis://%s", addr), "", "", "", false, 0) + require.NoError(t, err) got, err := c.GetArtifact(tt.artifactID) if tt.wantErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantErr) + require.ErrorContains(t, err, tt.wantErr) return - } else { - require.NoError(t, err) } + require.NoError(t, err) assert.Equal(t, tt.want, got) }) @@ -334,14 +324,12 @@ func TestRedisCache_GetBlob(t *testing.T) { addr = "dummy:16379" } - c := cache.NewRedisCache(&redis.Options{ - Addr: addr, - }, 0) + c, err := cache.NewRedisCache(fmt.Sprintf("redis://%s", addr), "", "", "", false, 0) + require.NoError(t, err) got, err := c.GetBlob(tt.blobID) if tt.wantErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantErr) + require.ErrorContains(t, err, tt.wantErr) return } @@ -445,14 +433,12 @@ func TestRedisCache_MissingBlobs(t *testing.T) { addr = "dummy:6379" } - c := cache.NewRedisCache(&redis.Options{ - Addr: addr, - }, 0) + c, err := cache.NewRedisCache(fmt.Sprintf("redis://%s", addr), "", "", "", false, 0) + require.NoError(t, err) missingArtifact, missingBlobIDs, err := c.MissingBlobs(tt.args.artifactID, tt.args.blobIDs) if tt.wantErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantErr) + require.ErrorContains(t, err, tt.wantErr) return } @@ -470,9 +456,9 @@ func TestRedisCache_Close(t *testing.T) { defer s.Close() t.Run("close", func(t *testing.T) { - c := cache.NewRedisCache(&redis.Options{ - Addr: s.Addr(), - }, 0) + c, err := cache.NewRedisCache(fmt.Sprintf("redis://%s", s.Addr()), "", "", "", false, 0) + require.NoError(t, err) + closeErr := c.Close() require.NoError(t, closeErr) time.Sleep(3 * time.Second) // give it some time @@ -492,9 +478,9 @@ func TestRedisCache_Clear(t *testing.T) { s.Set("foo", "bar") t.Run("clear", func(t *testing.T) { - c := cache.NewRedisCache(&redis.Options{ - Addr: s.Addr(), - }, 0) + c, err := cache.NewRedisCache(fmt.Sprintf("redis://%s", s.Addr()), "", "", "", false, 0) + require.NoError(t, err) + require.NoError(t, c.Clear()) for i := 0; i < 200; i++ { assert.False(t, s.Exists(fmt.Sprintf("fanal::key%d", i))) @@ -546,9 +532,8 @@ func TestRedisCache_DeleteBlobs(t *testing.T) { addr = "dummy:16379" } - c := cache.NewRedisCache(&redis.Options{ - Addr: addr, - }, 0) + c, err := cache.NewRedisCache(fmt.Sprintf("redis://%s", addr), "", "", "", false, 0) + require.NoError(t, err) err = c.DeleteBlobs(tt.args.blobIDs) if tt.wantErr != "" { @@ -560,3 +545,27 @@ func TestRedisCache_DeleteBlobs(t *testing.T) { }) } } + +func TestRedisOptions_BackendMasked(t *testing.T) { + tests := []struct { + name string + fields cache.RedisOptions + want string + }{ + { + name: "redis cache backend masked", + fields: cache.RedisOptions{Backend: "redis://root:password@localhost:6379"}, + want: "redis://****@localhost:6379", + }, + { + name: "redis cache backend masked does nothing", + fields: cache.RedisOptions{Backend: "redis://localhost:6379"}, + want: "redis://localhost:6379", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.fields.BackendMasked()) + }) + } +} diff --git a/pkg/cache/remote.go b/pkg/cache/remote.go index f55d877d5f3d..44c9f63c92d8 100644 --- a/pkg/cache/remote.go +++ b/pkg/cache/remote.go @@ -13,6 +13,14 @@ import ( rpcCache "github.com/aquasecurity/trivy/rpc/cache" ) +var _ ArtifactCache = (*RemoteCache)(nil) + +type RemoteOptions struct { + ServerAddr string + CustomHeaders http.Header + Insecure bool +} + // RemoteCache implements remote cache type RemoteCache struct { ctx context.Context // for custom header @@ -20,18 +28,18 @@ type RemoteCache struct { } // NewRemoteCache is the factory method for RemoteCache -func NewRemoteCache(url string, customHeaders http.Header, insecure bool) ArtifactCache { - ctx := client.WithCustomHeaders(context.Background(), customHeaders) +func NewRemoteCache(opts RemoteOptions) *RemoteCache { + ctx := client.WithCustomHeaders(context.Background(), opts.CustomHeaders) httpClient := &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{ - InsecureSkipVerify: insecure, + InsecureSkipVerify: opts.Insecure, }, }, } - c := rpcCache.NewCacheProtobufClient(url, httpClient) + c := rpcCache.NewCacheProtobufClient(opts.ServerAddr, httpClient) return &RemoteCache{ ctx: ctx, client: c, diff --git a/pkg/cache/remote_test.go b/pkg/cache/remote_test.go index bbfd72e3b20d..3e1363d5dd4d 100644 --- a/pkg/cache/remote_test.go +++ b/pkg/cache/remote_test.go @@ -145,7 +145,11 @@ func TestRemoteCache_PutArtifact(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := cache.NewRemoteCache(ts.URL, tt.args.customHeaders, false) + c := cache.NewRemoteCache(cache.RemoteOptions{ + ServerAddr: ts.URL, + CustomHeaders: tt.args.customHeaders, + Insecure: false, + }) err := c.PutArtifact(tt.args.imageID, tt.args.imageInfo) if tt.wantErr != "" { require.Error(t, err, tt.name) @@ -206,7 +210,11 @@ func TestRemoteCache_PutBlob(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := cache.NewRemoteCache(ts.URL, tt.args.customHeaders, false) + c := cache.NewRemoteCache(cache.RemoteOptions{ + ServerAddr: ts.URL, + CustomHeaders: tt.args.customHeaders, + Insecure: false, + }) err := c.PutBlob(tt.args.diffID, tt.args.layerInfo) if tt.wantErr != "" { require.Error(t, err, tt.name) @@ -284,7 +292,11 @@ func TestRemoteCache_MissingBlobs(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := cache.NewRemoteCache(ts.URL, tt.args.customHeaders, false) + c := cache.NewRemoteCache(cache.RemoteOptions{ + ServerAddr: ts.URL, + CustomHeaders: tt.args.customHeaders, + Insecure: false, + }) gotMissingImage, gotMissingLayerIDs, err := c.MissingBlobs(tt.args.imageID, tt.args.layerIDs) if tt.wantErr != "" { require.Error(t, err, tt.name) @@ -334,7 +346,11 @@ func TestRemoteCache_PutArtifactInsecure(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := cache.NewRemoteCache(ts.URL, nil, tt.args.insecure) + c := cache.NewRemoteCache(cache.RemoteOptions{ + ServerAddr: ts.URL, + CustomHeaders: nil, + Insecure: tt.args.insecure, + }) err := c.PutArtifact(tt.args.imageID, tt.args.imageInfo) if tt.wantErr != "" { require.Error(t, err) diff --git a/pkg/commands/artifact/inject.go b/pkg/commands/artifact/inject.go index 174af5045f74..da2c05ac91e4 100644 --- a/pkg/commands/artifact/inject.go +++ b/pkg/commands/artifact/inject.go @@ -21,8 +21,7 @@ import ( // initializeImageScanner is for container image scanning in standalone mode // e.g. dockerd, container registry, podman, etc. -func initializeImageScanner(ctx context.Context, imageName string, artifactCache cache.ArtifactCache, - localArtifactCache cache.LocalArtifactCache, imageOpt types.ImageOptions, artifactOption artifact.Option) ( +func initializeImageScanner(ctx context.Context, imageName string, imageOpt types.ImageOptions, cacheOptions cache.Options, artifactOption artifact.Option) ( scanner.Scanner, func(), error) { wire.Build(scanner.StandaloneDockerSet) return scanner.Scanner{}, nil, nil @@ -30,33 +29,29 @@ func initializeImageScanner(ctx context.Context, imageName string, artifactCache // initializeArchiveScanner is for container image archive scanning in standalone mode // e.g. docker save -o alpine.tar alpine:3.15 -func initializeArchiveScanner(ctx context.Context, filePath string, artifactCache cache.ArtifactCache, - localArtifactCache cache.LocalArtifactCache, artifactOption artifact.Option) (scanner.Scanner, error) { +func initializeArchiveScanner(ctx context.Context, filePath string, cacheOptions cache.Options, artifactOption artifact.Option) ( + scanner.Scanner, func(), error) { wire.Build(scanner.StandaloneArchiveSet) - return scanner.Scanner{}, nil + return scanner.Scanner{}, nil, nil } // initializeFilesystemScanner is for filesystem scanning in standalone mode -func initializeFilesystemScanner(ctx context.Context, path string, artifactCache cache.ArtifactCache, - localArtifactCache cache.LocalArtifactCache, artifactOption artifact.Option) (scanner.Scanner, func(), error) { +func initializeFilesystemScanner(ctx context.Context, path string, cacheOptions cache.Options, artifactOption artifact.Option) (scanner.Scanner, func(), error) { wire.Build(scanner.StandaloneFilesystemSet) return scanner.Scanner{}, nil, nil } -func initializeRepositoryScanner(ctx context.Context, url string, artifactCache cache.ArtifactCache, - localArtifactCache cache.LocalArtifactCache, artifactOption artifact.Option) (scanner.Scanner, func(), error) { +func initializeRepositoryScanner(ctx context.Context, url string, cacheOptions cache.Options, artifactOption artifact.Option) (scanner.Scanner, func(), error) { wire.Build(scanner.StandaloneRepositorySet) return scanner.Scanner{}, nil, nil } -func initializeSBOMScanner(ctx context.Context, filePath string, artifactCache cache.ArtifactCache, - localArtifactCache cache.LocalArtifactCache, artifactOption artifact.Option) (scanner.Scanner, func(), error) { +func initializeSBOMScanner(ctx context.Context, filePath string, cacheOptions cache.Options, artifactOption artifact.Option) (scanner.Scanner, func(), error) { wire.Build(scanner.StandaloneSBOMSet) return scanner.Scanner{}, nil, nil } -func initializeVMScanner(ctx context.Context, filePath string, artifactCache cache.ArtifactCache, - localArtifactCache cache.LocalArtifactCache, artifactOption artifact.Option) ( +func initializeVMScanner(ctx context.Context, filePath string, cacheOptions cache.Options, artifactOption artifact.Option) ( scanner.Scanner, func(), error) { wire.Build(scanner.StandaloneVMSet) return scanner.Scanner{}, nil, nil @@ -68,7 +63,7 @@ func initializeVMScanner(ctx context.Context, filePath string, artifactCache cac // initializeRemoteImageScanner is for container image scanning in client/server mode // e.g. dockerd, container registry, podman, etc. -func initializeRemoteImageScanner(ctx context.Context, imageName string, artifactCache cache.ArtifactCache, +func initializeRemoteImageScanner(ctx context.Context, imageName string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ScannerOption, imageOpt types.ImageOptions, artifactOption artifact.Option) ( scanner.Scanner, func(), error) { wire.Build(scanner.RemoteDockerSet) @@ -77,21 +72,21 @@ func initializeRemoteImageScanner(ctx context.Context, imageName string, artifac // initializeRemoteArchiveScanner is for container image archive scanning in client/server mode // e.g. docker save -o alpine.tar alpine:3.15 -func initializeRemoteArchiveScanner(ctx context.Context, filePath string, artifactCache cache.ArtifactCache, - remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, error) { +func initializeRemoteArchiveScanner(ctx context.Context, filePath string, remoteCacheOptions cache.RemoteOptions, + remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { wire.Build(scanner.RemoteArchiveSet) - return scanner.Scanner{}, nil + return scanner.Scanner{}, nil, nil } // initializeRemoteFilesystemScanner is for filesystem scanning in client/server mode -func initializeRemoteFilesystemScanner(ctx context.Context, path string, artifactCache cache.ArtifactCache, +func initializeRemoteFilesystemScanner(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { wire.Build(scanner.RemoteFilesystemSet) return scanner.Scanner{}, nil, nil } // initializeRemoteRepositoryScanner is for repository scanning in client/server mode -func initializeRemoteRepositoryScanner(ctx context.Context, url string, artifactCache cache.ArtifactCache, +func initializeRemoteRepositoryScanner(ctx context.Context, url string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) ( scanner.Scanner, func(), error) { wire.Build(scanner.RemoteRepositorySet) @@ -99,14 +94,14 @@ func initializeRemoteRepositoryScanner(ctx context.Context, url string, artifact } // initializeRemoteSBOMScanner is for sbom scanning in client/server mode -func initializeRemoteSBOMScanner(ctx context.Context, path string, artifactCache cache.ArtifactCache, +func initializeRemoteSBOMScanner(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { wire.Build(scanner.RemoteSBOMSet) return scanner.Scanner{}, nil, nil } // initializeRemoteVMScanner is for vm scanning in client/server mode -func initializeRemoteVMScanner(ctx context.Context, path string, artifactCache cache.ArtifactCache, +func initializeRemoteVMScanner(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { wire.Build(scanner.RemoteVMSet) return scanner.Scanner{}, nil, nil diff --git a/pkg/commands/artifact/run.go b/pkg/commands/artifact/run.go index bb2e144c5050..db73cf58b391 100644 --- a/pkg/commands/artifact/run.go +++ b/pkg/commands/artifact/run.go @@ -57,8 +57,8 @@ type ScannerConfig struct { Target string // Cache - ArtifactCache cache.ArtifactCache - LocalArtifactCache cache.LocalArtifactCache + CacheOptions cache.Options + RemoteCacheOptions cache.RemoteOptions // Client/Server options ServerOption client.ScannerOption @@ -89,37 +89,31 @@ type Runner interface { } type runner struct { - cache cache.ArtifactCache - localCache cache.LocalArtifactCache - dbOpen bool + initializeScanner InitializeScanner + dbOpen bool // WASM modules module *module.Manager } -type runnerOption func(*runner) +type RunnerOption func(*runner) -// WithCacheClient takes a custom cache implementation +// WithInitializeScanner takes a custom scanner initialization function. // It is useful when Trivy is imported as a library. -func WithCacheClient(c cache.Cache) runnerOption { +func WithInitializeScanner(f InitializeScanner) RunnerOption { return func(r *runner) { - r.cache = c - r.localCache = c + r.initializeScanner = f } } // NewRunner initializes Runner that provides scanning functionalities. // It is possible to return SkipScan and it must be handled by caller. -func NewRunner(ctx context.Context, cliOptions flag.Options, opts ...runnerOption) (Runner, error) { +func NewRunner(ctx context.Context, cliOptions flag.Options, opts ...RunnerOption) (Runner, error) { r := &runner{} for _, opt := range opts { opt(r) } - if err := r.initCache(cliOptions); err != nil { - return nil, xerrors.Errorf("cache error: %w", err) - } - // Update the vulnerability database if needed. if err := r.initDB(ctx, cliOptions); err != nil { return nil, xerrors.Errorf("DB error: %w", err) @@ -142,10 +136,6 @@ func NewRunner(ctx context.Context, cliOptions flag.Options, opts ...runnerOptio // Close closes everything func (r *runner) Close(ctx context.Context) error { var errs error - if err := r.localCache.Close(); err != nil { - errs = multierror.Append(errs, err) - } - if r.dbOpen { if err := db.Close(); err != nil { errs = multierror.Append(errs, err) @@ -258,6 +248,9 @@ func (r *runner) ScanVM(ctx context.Context, opts flag.Options) (types.Report, e } func (r *runner) scanArtifact(ctx context.Context, opts flag.Options, initializeScanner InitializeScanner) (types.Report, error) { + if r.initializeScanner != nil { + initializeScanner = r.initializeScanner + } report, err := r.scan(ctx, opts, initializeScanner) if err != nil { return types.Report{}, xerrors.Errorf("scan error: %w", err) @@ -335,31 +328,6 @@ func (r *runner) initJavaDB(opts flag.Options) error { return nil } -func (r *runner) initCache(opts flag.Options) error { - // Skip initializing cache when custom cache is passed - if r.cache != nil { - return nil - } - - // client/server mode - if opts.ServerAddr != "" { - r.cache = cache.NewRemoteCache(opts.ServerAddr, opts.CustomHeaders, opts.Insecure) - r.localCache = cache.NewNopCache() // No need to use local cache in client/server mode - return nil - } - - // standalone mode - cacheClient, err := cache.New(opts.CacheDir, opts.CacheOptions.CacheBackendOptions) - if err != nil { - return xerrors.Errorf("unable to initialize the cache: %w", err) - } - log.Debug("Cache dir", log.String("dir", opts.CacheDir)) - - r.cache = cacheClient - r.localCache = cacheClient - return nil -} - // Run performs artifact scanning func Run(ctx context.Context, opts flag.Options, targetKind TargetKind) (err error) { ctx, cancel := context.WithTimeout(ctx, opts.Timeout) @@ -588,8 +556,8 @@ func (r *runner) initScannerConfig(opts flag.Options) (ScannerConfig, types.Scan return ScannerConfig{ Target: target, - ArtifactCache: r.cache, - LocalArtifactCache: r.localCache, + CacheOptions: opts.CacheOpts(), + RemoteCacheOptions: opts.RemoteCacheOpts(), ServerOption: client.ScannerOption{ RemoteURL: opts.ServerAddr, CustomHeaders: opts.CustomHeaders, @@ -607,10 +575,9 @@ func (r *runner) initScannerConfig(opts flag.Options) (ScannerConfig, types.Scan RepoTag: opts.RepoTag, SBOMSources: opts.SBOMSources, RekorURL: opts.RekorURL, - //Platform: opts.Platform, - AWSRegion: opts.Region, - AWSEndpoint: opts.Endpoint, - FileChecksum: fileChecksum, + AWSRegion: opts.Region, + AWSEndpoint: opts.Endpoint, + FileChecksum: fileChecksum, // For image scanning ImageOption: ftypes.ImageOptions{ diff --git a/pkg/commands/artifact/scanner.go b/pkg/commands/artifact/scanner.go index cf7a58c52693..88430a09961b 100644 --- a/pkg/commands/artifact/scanner.go +++ b/pkg/commands/artifact/scanner.go @@ -11,8 +11,7 @@ import ( // imageStandaloneScanner initializes a container image scanner in standalone mode // $ trivy image alpine:3.15 func imageStandaloneScanner(ctx context.Context, conf ScannerConfig) (scanner.Scanner, func(), error) { - s, cleanup, err := initializeImageScanner(ctx, conf.Target, conf.ArtifactCache, conf.LocalArtifactCache, - conf.ArtifactOption.ImageOption, conf.ArtifactOption) + s, cleanup, err := initializeImageScanner(ctx, conf.Target, conf.ArtifactOption.ImageOption, conf.CacheOptions, conf.ArtifactOption) if err != nil { return scanner.Scanner{}, func() {}, xerrors.Errorf("unable to initialize an image scanner: %w", err) } @@ -22,18 +21,18 @@ func imageStandaloneScanner(ctx context.Context, conf ScannerConfig) (scanner.Sc // archiveStandaloneScanner initializes an image archive scanner in standalone mode // $ trivy image --input alpine.tar func archiveStandaloneScanner(ctx context.Context, conf ScannerConfig) (scanner.Scanner, func(), error) { - s, err := initializeArchiveScanner(ctx, conf.Target, conf.ArtifactCache, conf.LocalArtifactCache, conf.ArtifactOption) + s, cleanup, err := initializeArchiveScanner(ctx, conf.Target, conf.CacheOptions, conf.ArtifactOption) if err != nil { return scanner.Scanner{}, func() {}, xerrors.Errorf("unable to initialize the archive scanner: %w", err) } - return s, func() {}, nil + return s, cleanup, nil } // imageRemoteScanner initializes a container image scanner in client/server mode // $ trivy image --server localhost:4954 alpine:3.15 func imageRemoteScanner(ctx context.Context, conf ScannerConfig) ( scanner.Scanner, func(), error) { - s, cleanup, err := initializeRemoteImageScanner(ctx, conf.Target, conf.ArtifactCache, conf.ServerOption, + s, cleanup, err := initializeRemoteImageScanner(ctx, conf.Target, conf.RemoteCacheOptions, conf.ServerOption, conf.ArtifactOption.ImageOption, conf.ArtifactOption) if err != nil { return scanner.Scanner{}, nil, xerrors.Errorf("unable to initialize a remote image scanner: %w", err) @@ -45,16 +44,16 @@ func imageRemoteScanner(ctx context.Context, conf ScannerConfig) ( // $ trivy image --server localhost:4954 --input alpine.tar func archiveRemoteScanner(ctx context.Context, conf ScannerConfig) (scanner.Scanner, func(), error) { // Scan tar file - s, err := initializeRemoteArchiveScanner(ctx, conf.Target, conf.ArtifactCache, conf.ServerOption, conf.ArtifactOption) + s, cleanup, err := initializeRemoteArchiveScanner(ctx, conf.Target, conf.RemoteCacheOptions, conf.ServerOption, conf.ArtifactOption) if err != nil { return scanner.Scanner{}, nil, xerrors.Errorf("unable to initialize the remote archive scanner: %w", err) } - return s, func() {}, nil + return s, cleanup, nil } // filesystemStandaloneScanner initializes a filesystem scanner in standalone mode func filesystemStandaloneScanner(ctx context.Context, conf ScannerConfig) (scanner.Scanner, func(), error) { - s, cleanup, err := initializeFilesystemScanner(ctx, conf.Target, conf.ArtifactCache, conf.LocalArtifactCache, conf.ArtifactOption) + s, cleanup, err := initializeFilesystemScanner(ctx, conf.Target, conf.CacheOptions, conf.ArtifactOption) if err != nil { return scanner.Scanner{}, func() {}, xerrors.Errorf("unable to initialize a filesystem scanner: %w", err) } @@ -63,7 +62,7 @@ func filesystemStandaloneScanner(ctx context.Context, conf ScannerConfig) (scann // filesystemRemoteScanner initializes a filesystem scanner in client/server mode func filesystemRemoteScanner(ctx context.Context, conf ScannerConfig) (scanner.Scanner, func(), error) { - s, cleanup, err := initializeRemoteFilesystemScanner(ctx, conf.Target, conf.ArtifactCache, conf.ServerOption, conf.ArtifactOption) + s, cleanup, err := initializeRemoteFilesystemScanner(ctx, conf.Target, conf.RemoteCacheOptions, conf.ServerOption, conf.ArtifactOption) if err != nil { return scanner.Scanner{}, func() {}, xerrors.Errorf("unable to initialize a remote filesystem scanner: %w", err) } @@ -72,7 +71,7 @@ func filesystemRemoteScanner(ctx context.Context, conf ScannerConfig) (scanner.S // repositoryStandaloneScanner initializes a repository scanner in standalone mode func repositoryStandaloneScanner(ctx context.Context, conf ScannerConfig) (scanner.Scanner, func(), error) { - s, cleanup, err := initializeRepositoryScanner(ctx, conf.Target, conf.ArtifactCache, conf.LocalArtifactCache, conf.ArtifactOption) + s, cleanup, err := initializeRepositoryScanner(ctx, conf.Target, conf.CacheOptions, conf.ArtifactOption) if err != nil { return scanner.Scanner{}, func() {}, xerrors.Errorf("unable to initialize a repository scanner: %w", err) } @@ -81,7 +80,7 @@ func repositoryStandaloneScanner(ctx context.Context, conf ScannerConfig) (scann // repositoryRemoteScanner initializes a repository scanner in client/server mode func repositoryRemoteScanner(ctx context.Context, conf ScannerConfig) (scanner.Scanner, func(), error) { - s, cleanup, err := initializeRemoteRepositoryScanner(ctx, conf.Target, conf.ArtifactCache, conf.ServerOption, + s, cleanup, err := initializeRemoteRepositoryScanner(ctx, conf.Target, conf.RemoteCacheOptions, conf.ServerOption, conf.ArtifactOption) if err != nil { return scanner.Scanner{}, func() {}, xerrors.Errorf("unable to initialize a remote repository scanner: %w", err) @@ -91,7 +90,7 @@ func repositoryRemoteScanner(ctx context.Context, conf ScannerConfig) (scanner.S // sbomStandaloneScanner initializes a SBOM scanner in standalone mode func sbomStandaloneScanner(ctx context.Context, conf ScannerConfig) (scanner.Scanner, func(), error) { - s, cleanup, err := initializeSBOMScanner(ctx, conf.Target, conf.ArtifactCache, conf.LocalArtifactCache, conf.ArtifactOption) + s, cleanup, err := initializeSBOMScanner(ctx, conf.Target, conf.CacheOptions, conf.ArtifactOption) if err != nil { return scanner.Scanner{}, func() {}, xerrors.Errorf("unable to initialize a cycloneDX scanner: %w", err) } @@ -100,7 +99,7 @@ func sbomStandaloneScanner(ctx context.Context, conf ScannerConfig) (scanner.Sca // sbomRemoteScanner initializes a SBOM scanner in client/server mode func sbomRemoteScanner(ctx context.Context, conf ScannerConfig) (scanner.Scanner, func(), error) { - s, cleanup, err := initializeRemoteSBOMScanner(ctx, conf.Target, conf.ArtifactCache, conf.ServerOption, conf.ArtifactOption) + s, cleanup, err := initializeRemoteSBOMScanner(ctx, conf.Target, conf.RemoteCacheOptions, conf.ServerOption, conf.ArtifactOption) if err != nil { return scanner.Scanner{}, func() {}, xerrors.Errorf("unable to initialize a remote cycloneDX scanner: %w", err) } @@ -109,7 +108,7 @@ func sbomRemoteScanner(ctx context.Context, conf ScannerConfig) (scanner.Scanner // vmStandaloneScanner initializes a VM scanner in standalone mode func vmStandaloneScanner(ctx context.Context, conf ScannerConfig) (scanner.Scanner, func(), error) { - s, cleanup, err := initializeVMScanner(ctx, conf.Target, conf.ArtifactCache, conf.LocalArtifactCache, conf.ArtifactOption) + s, cleanup, err := initializeVMScanner(ctx, conf.Target, conf.CacheOptions, conf.ArtifactOption) if err != nil { return scanner.Scanner{}, func() {}, xerrors.Errorf("unable to initialize a vm scanner: %w", err) } @@ -118,7 +117,7 @@ func vmStandaloneScanner(ctx context.Context, conf ScannerConfig) (scanner.Scann // vmRemoteScanner initializes a VM scanner in client/server mode func vmRemoteScanner(ctx context.Context, conf ScannerConfig) (scanner.Scanner, func(), error) { - s, cleanup, err := initializeRemoteVMScanner(ctx, conf.Target, conf.ArtifactCache, conf.ServerOption, conf.ArtifactOption) + s, cleanup, err := initializeRemoteVMScanner(ctx, conf.Target, conf.RemoteCacheOptions, conf.ServerOption, conf.ArtifactOption) if err != nil { return scanner.Scanner{}, func() {}, xerrors.Errorf("unable to initialize a remote vm scanner: %w", err) } diff --git a/pkg/commands/artifact/wire_gen.go b/pkg/commands/artifact/wire_gen.go index 85e66b8214f2..f47c0b0f5649 100644 --- a/pkg/commands/artifact/wire_gen.go +++ b/pkg/commands/artifact/wire_gen.go @@ -9,6 +9,7 @@ package artifact import ( "context" "github.com/aquasecurity/trivy-db/pkg/db" + "github.com/aquasecurity/trivy/pkg/cache" "github.com/aquasecurity/trivy/pkg/fanal/applier" "github.com/aquasecurity/trivy/pkg/fanal/artifact" image2 "github.com/aquasecurity/trivy/pkg/fanal/artifact/image" @@ -16,7 +17,6 @@ import ( "github.com/aquasecurity/trivy/pkg/fanal/artifact/repo" "github.com/aquasecurity/trivy/pkg/fanal/artifact/sbom" "github.com/aquasecurity/trivy/pkg/fanal/artifact/vm" - "github.com/aquasecurity/trivy/pkg/cache" "github.com/aquasecurity/trivy/pkg/fanal/image" "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/fanal/walker" @@ -32,32 +32,43 @@ import ( // initializeImageScanner is for container image scanning in standalone mode // e.g. dockerd, container registry, podman, etc. -func initializeImageScanner(ctx context.Context, imageName string, artifactCache cache.ArtifactCache, localArtifactCache cache.LocalArtifactCache, imageOpt types.ImageOptions, artifactOption artifact.Option) (scanner.Scanner, func(), error) { - applierApplier := applier.NewApplier(localArtifactCache) +func initializeImageScanner(ctx context.Context, imageName string, imageOpt types.ImageOptions, cacheOptions cache.Options, artifactOption artifact.Option) (scanner.Scanner, func(), error) { + cacheCache, cleanup, err := cache.New(cacheOptions) + if err != nil { + return scanner.Scanner{}, nil, err + } + applierApplier := applier.NewApplier(cacheCache) ospkgScanner := ospkg.NewScanner() langpkgScanner := langpkg.NewScanner() config := db.Config{} client := vulnerability.NewClient(config) localScanner := local.NewScanner(applierApplier, ospkgScanner, langpkgScanner, client) - typesImage, cleanup, err := image.NewContainerImage(ctx, imageName, imageOpt) + typesImage, cleanup2, err := image.NewContainerImage(ctx, imageName, imageOpt) if err != nil { + cleanup() return scanner.Scanner{}, nil, err } - artifactArtifact, err := image2.NewArtifact(typesImage, artifactCache, artifactOption) + artifactArtifact, err := image2.NewArtifact(typesImage, cacheCache, artifactOption) if err != nil { + cleanup2() cleanup() return scanner.Scanner{}, nil, err } scannerScanner := scanner.NewScanner(localScanner, artifactArtifact) return scannerScanner, func() { + cleanup2() cleanup() }, nil } // initializeArchiveScanner is for container image archive scanning in standalone mode // e.g. docker save -o alpine.tar alpine:3.15 -func initializeArchiveScanner(ctx context.Context, filePath string, artifactCache cache.ArtifactCache, localArtifactCache cache.LocalArtifactCache, artifactOption artifact.Option) (scanner.Scanner, error) { - applierApplier := applier.NewApplier(localArtifactCache) +func initializeArchiveScanner(ctx context.Context, filePath string, cacheOptions cache.Options, artifactOption artifact.Option) (scanner.Scanner, func(), error) { + cacheCache, cleanup, err := cache.New(cacheOptions) + if err != nil { + return scanner.Scanner{}, nil, err + } + applierApplier := applier.NewApplier(cacheCache) ospkgScanner := ospkg.NewScanner() langpkgScanner := langpkg.NewScanner() config := db.Config{} @@ -65,95 +76,124 @@ func initializeArchiveScanner(ctx context.Context, filePath string, artifactCach localScanner := local.NewScanner(applierApplier, ospkgScanner, langpkgScanner, client) typesImage, err := image.NewArchiveImage(filePath) if err != nil { - return scanner.Scanner{}, err + cleanup() + return scanner.Scanner{}, nil, err } - artifactArtifact, err := image2.NewArtifact(typesImage, artifactCache, artifactOption) + artifactArtifact, err := image2.NewArtifact(typesImage, cacheCache, artifactOption) if err != nil { - return scanner.Scanner{}, err + cleanup() + return scanner.Scanner{}, nil, err } scannerScanner := scanner.NewScanner(localScanner, artifactArtifact) - return scannerScanner, nil + return scannerScanner, func() { + cleanup() + }, nil } // initializeFilesystemScanner is for filesystem scanning in standalone mode -func initializeFilesystemScanner(ctx context.Context, path string, artifactCache cache.ArtifactCache, localArtifactCache cache.LocalArtifactCache, artifactOption artifact.Option) (scanner.Scanner, func(), error) { - applierApplier := applier.NewApplier(localArtifactCache) +func initializeFilesystemScanner(ctx context.Context, path string, cacheOptions cache.Options, artifactOption artifact.Option) (scanner.Scanner, func(), error) { + cacheCache, cleanup, err := cache.New(cacheOptions) + if err != nil { + return scanner.Scanner{}, nil, err + } + applierApplier := applier.NewApplier(cacheCache) ospkgScanner := ospkg.NewScanner() langpkgScanner := langpkg.NewScanner() config := db.Config{} client := vulnerability.NewClient(config) localScanner := local.NewScanner(applierApplier, ospkgScanner, langpkgScanner, client) fs := walker.NewFS() - artifactArtifact, err := local2.NewArtifact(path, artifactCache, fs, artifactOption) + artifactArtifact, err := local2.NewArtifact(path, cacheCache, fs, artifactOption) if err != nil { + cleanup() return scanner.Scanner{}, nil, err } scannerScanner := scanner.NewScanner(localScanner, artifactArtifact) return scannerScanner, func() { + cleanup() }, nil } -func initializeRepositoryScanner(ctx context.Context, url string, artifactCache cache.ArtifactCache, localArtifactCache cache.LocalArtifactCache, artifactOption artifact.Option) (scanner.Scanner, func(), error) { - applierApplier := applier.NewApplier(localArtifactCache) +func initializeRepositoryScanner(ctx context.Context, url string, cacheOptions cache.Options, artifactOption artifact.Option) (scanner.Scanner, func(), error) { + cacheCache, cleanup, err := cache.New(cacheOptions) + if err != nil { + return scanner.Scanner{}, nil, err + } + applierApplier := applier.NewApplier(cacheCache) ospkgScanner := ospkg.NewScanner() langpkgScanner := langpkg.NewScanner() config := db.Config{} client := vulnerability.NewClient(config) localScanner := local.NewScanner(applierApplier, ospkgScanner, langpkgScanner, client) fs := walker.NewFS() - artifactArtifact, cleanup, err := repo.NewArtifact(url, artifactCache, fs, artifactOption) + artifactArtifact, cleanup2, err := repo.NewArtifact(url, cacheCache, fs, artifactOption) if err != nil { + cleanup() return scanner.Scanner{}, nil, err } scannerScanner := scanner.NewScanner(localScanner, artifactArtifact) return scannerScanner, func() { + cleanup2() cleanup() }, nil } -func initializeSBOMScanner(ctx context.Context, filePath string, artifactCache cache.ArtifactCache, localArtifactCache cache.LocalArtifactCache, artifactOption artifact.Option) (scanner.Scanner, func(), error) { - applierApplier := applier.NewApplier(localArtifactCache) +func initializeSBOMScanner(ctx context.Context, filePath string, cacheOptions cache.Options, artifactOption artifact.Option) (scanner.Scanner, func(), error) { + cacheCache, cleanup, err := cache.New(cacheOptions) + if err != nil { + return scanner.Scanner{}, nil, err + } + applierApplier := applier.NewApplier(cacheCache) ospkgScanner := ospkg.NewScanner() langpkgScanner := langpkg.NewScanner() config := db.Config{} client := vulnerability.NewClient(config) localScanner := local.NewScanner(applierApplier, ospkgScanner, langpkgScanner, client) - artifactArtifact, err := sbom.NewArtifact(filePath, artifactCache, artifactOption) + artifactArtifact, err := sbom.NewArtifact(filePath, cacheCache, artifactOption) if err != nil { + cleanup() return scanner.Scanner{}, nil, err } scannerScanner := scanner.NewScanner(localScanner, artifactArtifact) return scannerScanner, func() { + cleanup() }, nil } -func initializeVMScanner(ctx context.Context, filePath string, artifactCache cache.ArtifactCache, localArtifactCache cache.LocalArtifactCache, artifactOption artifact.Option) (scanner.Scanner, func(), error) { - applierApplier := applier.NewApplier(localArtifactCache) +func initializeVMScanner(ctx context.Context, filePath string, cacheOptions cache.Options, artifactOption artifact.Option) (scanner.Scanner, func(), error) { + cacheCache, cleanup, err := cache.New(cacheOptions) + if err != nil { + return scanner.Scanner{}, nil, err + } + applierApplier := applier.NewApplier(cacheCache) ospkgScanner := ospkg.NewScanner() langpkgScanner := langpkg.NewScanner() config := db.Config{} client := vulnerability.NewClient(config) localScanner := local.NewScanner(applierApplier, ospkgScanner, langpkgScanner, client) walkerVM := walker.NewVM() - artifactArtifact, err := vm.NewArtifact(filePath, artifactCache, walkerVM, artifactOption) + artifactArtifact, err := vm.NewArtifact(filePath, cacheCache, walkerVM, artifactOption) if err != nil { + cleanup() return scanner.Scanner{}, nil, err } scannerScanner := scanner.NewScanner(localScanner, artifactArtifact) return scannerScanner, func() { + cleanup() }, nil } // initializeRemoteImageScanner is for container image scanning in client/server mode // e.g. dockerd, container registry, podman, etc. -func initializeRemoteImageScanner(ctx context.Context, imageName string, artifactCache cache.ArtifactCache, remoteScanOptions client.ScannerOption, imageOpt types.ImageOptions, artifactOption artifact.Option) (scanner.Scanner, func(), error) { +func initializeRemoteImageScanner(ctx context.Context, imageName string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ScannerOption, imageOpt types.ImageOptions, artifactOption artifact.Option) (scanner.Scanner, func(), error) { v := _wireValue clientScanner := client.NewScanner(remoteScanOptions, v...) typesImage, cleanup, err := image.NewContainerImage(ctx, imageName, imageOpt) if err != nil { return scanner.Scanner{}, nil, err } - artifactArtifact, err := image2.NewArtifact(typesImage, artifactCache, artifactOption) + remoteCache := cache.NewRemoteCache(remoteCacheOptions) + artifactArtifact, err := image2.NewArtifact(typesImage, remoteCache, artifactOption) if err != nil { cleanup() return scanner.Scanner{}, nil, err @@ -170,27 +210,30 @@ var ( // initializeRemoteArchiveScanner is for container image archive scanning in client/server mode // e.g. docker save -o alpine.tar alpine:3.15 -func initializeRemoteArchiveScanner(ctx context.Context, filePath string, artifactCache cache.ArtifactCache, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, error) { +func initializeRemoteArchiveScanner(ctx context.Context, filePath string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { v := _wireValue clientScanner := client.NewScanner(remoteScanOptions, v...) typesImage, err := image.NewArchiveImage(filePath) if err != nil { - return scanner.Scanner{}, err + return scanner.Scanner{}, nil, err } - artifactArtifact, err := image2.NewArtifact(typesImage, artifactCache, artifactOption) + remoteCache := cache.NewRemoteCache(remoteCacheOptions) + artifactArtifact, err := image2.NewArtifact(typesImage, remoteCache, artifactOption) if err != nil { - return scanner.Scanner{}, err + return scanner.Scanner{}, nil, err } scannerScanner := scanner.NewScanner(clientScanner, artifactArtifact) - return scannerScanner, nil + return scannerScanner, func() { + }, nil } // initializeRemoteFilesystemScanner is for filesystem scanning in client/server mode -func initializeRemoteFilesystemScanner(ctx context.Context, path string, artifactCache cache.ArtifactCache, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { +func initializeRemoteFilesystemScanner(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { v := _wireValue clientScanner := client.NewScanner(remoteScanOptions, v...) + remoteCache := cache.NewRemoteCache(remoteCacheOptions) fs := walker.NewFS() - artifactArtifact, err := local2.NewArtifact(path, artifactCache, fs, artifactOption) + artifactArtifact, err := local2.NewArtifact(path, remoteCache, fs, artifactOption) if err != nil { return scanner.Scanner{}, nil, err } @@ -200,11 +243,12 @@ func initializeRemoteFilesystemScanner(ctx context.Context, path string, artifac } // initializeRemoteRepositoryScanner is for repository scanning in client/server mode -func initializeRemoteRepositoryScanner(ctx context.Context, url string, artifactCache cache.ArtifactCache, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { +func initializeRemoteRepositoryScanner(ctx context.Context, url string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { v := _wireValue clientScanner := client.NewScanner(remoteScanOptions, v...) + remoteCache := cache.NewRemoteCache(remoteCacheOptions) fs := walker.NewFS() - artifactArtifact, cleanup, err := repo.NewArtifact(url, artifactCache, fs, artifactOption) + artifactArtifact, cleanup, err := repo.NewArtifact(url, remoteCache, fs, artifactOption) if err != nil { return scanner.Scanner{}, nil, err } @@ -215,10 +259,11 @@ func initializeRemoteRepositoryScanner(ctx context.Context, url string, artifact } // initializeRemoteSBOMScanner is for sbom scanning in client/server mode -func initializeRemoteSBOMScanner(ctx context.Context, path string, artifactCache cache.ArtifactCache, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { +func initializeRemoteSBOMScanner(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { v := _wireValue clientScanner := client.NewScanner(remoteScanOptions, v...) - artifactArtifact, err := sbom.NewArtifact(path, artifactCache, artifactOption) + remoteCache := cache.NewRemoteCache(remoteCacheOptions) + artifactArtifact, err := sbom.NewArtifact(path, remoteCache, artifactOption) if err != nil { return scanner.Scanner{}, nil, err } @@ -228,11 +273,12 @@ func initializeRemoteSBOMScanner(ctx context.Context, path string, artifactCache } // initializeRemoteVMScanner is for vm scanning in client/server mode -func initializeRemoteVMScanner(ctx context.Context, path string, artifactCache cache.ArtifactCache, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { +func initializeRemoteVMScanner(ctx context.Context, path string, remoteCacheOptions cache.RemoteOptions, remoteScanOptions client.ScannerOption, artifactOption artifact.Option) (scanner.Scanner, func(), error) { v := _wireValue clientScanner := client.NewScanner(remoteScanOptions, v...) + remoteCache := cache.NewRemoteCache(remoteCacheOptions) walkerVM := walker.NewVM() - artifactArtifact, err := vm.NewArtifact(path, artifactCache, walkerVM, artifactOption) + artifactArtifact, err := vm.NewArtifact(path, remoteCache, walkerVM, artifactOption) if err != nil { return scanner.Scanner{}, nil, err } diff --git a/pkg/commands/clean/run.go b/pkg/commands/clean/run.go index f2ac57539d54..fb20799a571b 100644 --- a/pkg/commands/clean/run.go +++ b/pkg/commands/clean/run.go @@ -62,10 +62,12 @@ func cleanAll(ctx context.Context, opts flag.Options) error { func cleanScanCache(ctx context.Context, opts flag.Options) error { log.InfoContext(ctx, "Removing scan cache...") - c, err := cache.New(opts.CacheDir, opts.CacheBackendOptions) + c, cleanup, err := cache.New(opts.CacheOpts()) if err != nil { return xerrors.Errorf("failed to instantiate cache client: %w", err) } + defer cleanup() + if err = c.Clear(); err != nil { return xerrors.Errorf("clear scan cache: %w", err) } diff --git a/pkg/commands/clean/run_test.go b/pkg/commands/clean/run_test.go index a26fef86a572..9b301d238219 100644 --- a/pkg/commands/clean/run_test.go +++ b/pkg/commands/clean/run_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/aquasecurity/trivy/pkg/cache" "github.com/aquasecurity/trivy/pkg/commands/clean" "github.com/aquasecurity/trivy/pkg/flag" ) @@ -100,6 +101,9 @@ func TestRun(t *testing.T) { GlobalOptions: flag.GlobalOptions{ CacheDir: tempDir, }, + CacheOptions: flag.CacheOptions{ + CacheBackend: string(cache.TypeFS), + }, CleanOptions: tt.cleanOpts, } diff --git a/pkg/commands/server/run.go b/pkg/commands/server/run.go index 917f1b4aa459..c5f7b0da2f0b 100644 --- a/pkg/commands/server/run.go +++ b/pkg/commands/server/run.go @@ -19,12 +19,11 @@ func Run(ctx context.Context, opts flag.Options) (err error) { log.InitLogger(opts.Debug, opts.Quiet) // configure cache dir - cacheClient, err := cache.New(opts.CacheDir, opts.CacheOptions.CacheBackendOptions) + cacheClient, cleanup, err := cache.New(opts.CacheOpts()) if err != nil { return xerrors.Errorf("server cache error: %w", err) } - defer cacheClient.Close() - log.Debug("Cache", log.String("dir", opts.CacheDir)) + defer cleanup() // download the database file if err = operation.DownloadDB(ctx, opts.AppVersion, opts.CacheDir, opts.DBRepository, diff --git a/pkg/flag/cache_flags.go b/pkg/flag/cache_flags.go index 73a31fd2684d..786a0c9c7ffe 100644 --- a/pkg/flag/cache_flags.go +++ b/pkg/flag/cache_flags.go @@ -2,10 +2,6 @@ package flag import ( "time" - - "golang.org/x/xerrors" - - "github.com/aquasecurity/trivy/pkg/cache" ) // e.g. config yaml: @@ -71,14 +67,19 @@ type CacheFlagGroup struct { } type CacheOptions struct { - ClearCache bool - CacheBackendOptions cache.Options + ClearCache bool + + CacheBackend string + CacheTTL time.Duration + RedisTLS bool + RedisCACert string + RedisCert string + RedisKey string } // NewCacheFlagGroup returns a default CacheFlagGroup func NewCacheFlagGroup() *CacheFlagGroup { return &CacheFlagGroup{ - ClearCache: ClearCacheFlag.Clone(), CacheBackend: CacheBackendFlag.Clone(), CacheTTL: CacheTTLFlag.Clone(), RedisTLS: RedisTLSFlag.Clone(), @@ -109,14 +110,12 @@ func (fg *CacheFlagGroup) ToOptions() (CacheOptions, error) { return CacheOptions{}, err } - backendOpts, err := cache.NewOptions(fg.CacheBackend.Value(), fg.RedisCACert.Value(), fg.RedisCert.Value(), - fg.RedisKey.Value(), fg.RedisTLS.Value(), fg.CacheTTL.Value()) - if err != nil { - return CacheOptions{}, xerrors.Errorf("failed to initialize cache options: %w", err) - } - return CacheOptions{ - ClearCache: fg.ClearCache.Value(), - CacheBackendOptions: backendOpts, + CacheBackend: fg.CacheBackend.Value(), + CacheTTL: fg.CacheTTL.Value(), + RedisTLS: fg.RedisTLS.Value(), + RedisCACert: fg.RedisCACert.Value(), + RedisCert: fg.RedisCert.Value(), + RedisKey: fg.RedisKey.Value(), }, nil } diff --git a/pkg/flag/global_flags.go b/pkg/flag/global_flags.go index ef19a09dd4e8..ebd79bd5a06c 100644 --- a/pkg/flag/global_flags.go +++ b/pkg/flag/global_flags.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/cobra" "github.com/aquasecurity/trivy/pkg/cache" + "github.com/aquasecurity/trivy/pkg/log" ) var ( @@ -144,6 +145,8 @@ func (f *GlobalFlagGroup) ToOptions() (GlobalOptions, error) { // Keep TRIVY_NON_SSL for backward compatibility insecure := f.Insecure.Value() || os.Getenv("TRIVY_NON_SSL") != "" + log.Debug("Cache dir", log.String("dir", f.CacheDir.Value())) + return GlobalOptions{ ConfigFile: f.ConfigFile.Value(), ShowVersion: f.ShowVersion.Value(), diff --git a/pkg/flag/options.go b/pkg/flag/options.go index 70470dc05e2f..33190fb76fbe 100644 --- a/pkg/flag/options.go +++ b/pkg/flag/options.go @@ -17,6 +17,7 @@ import ( "github.com/spf13/viper" "golang.org/x/xerrors" + "github.com/aquasecurity/trivy/pkg/cache" "github.com/aquasecurity/trivy/pkg/fanal/analyzer" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" @@ -448,6 +449,28 @@ func (o *Options) FilterOpts() result.FilterOption { } } +// CacheOpts returns options for scan cache +func (o *Options) CacheOpts() cache.Options { + return cache.Options{ + Backend: o.CacheBackend, + CacheDir: o.CacheDir, + RedisCACert: o.RedisCACert, + RedisCert: o.RedisCert, + RedisKey: o.RedisKey, + RedisTLS: o.RedisTLS, + TTL: o.CacheTTL, + } +} + +// RemoteCacheOpts returns options for remote scan cache +func (o *Options) RemoteCacheOpts() cache.RemoteOptions { + return cache.RemoteOptions{ + ServerAddr: o.ServerAddr, + CustomHeaders: o.CustomHeaders, + Insecure: o.Insecure, + } +} + // SetOutputWriter sets an output writer. func (o *Options) SetOutputWriter(w io.Writer) { o.outputWriter = w diff --git a/pkg/k8s/wire_gen.go b/pkg/k8s/wire_gen.go index 134fa8c1ec49..e6c4f7e0dff7 100644 --- a/pkg/k8s/wire_gen.go +++ b/pkg/k8s/wire_gen.go @@ -8,8 +8,8 @@ package k8s import ( "github.com/aquasecurity/trivy-db/pkg/db" - "github.com/aquasecurity/trivy/pkg/fanal/applier" "github.com/aquasecurity/trivy/pkg/cache" + "github.com/aquasecurity/trivy/pkg/fanal/applier" "github.com/aquasecurity/trivy/pkg/scanner/langpkg" "github.com/aquasecurity/trivy/pkg/scanner/local" "github.com/aquasecurity/trivy/pkg/scanner/ospkg" diff --git a/pkg/rpc/server/wire_gen.go b/pkg/rpc/server/wire_gen.go index 4d667cbe9dfd..bcba35941b9e 100644 --- a/pkg/rpc/server/wire_gen.go +++ b/pkg/rpc/server/wire_gen.go @@ -8,8 +8,8 @@ package server import ( "github.com/aquasecurity/trivy-db/pkg/db" - "github.com/aquasecurity/trivy/pkg/fanal/applier" "github.com/aquasecurity/trivy/pkg/cache" + "github.com/aquasecurity/trivy/pkg/fanal/applier" "github.com/aquasecurity/trivy/pkg/scanner/langpkg" "github.com/aquasecurity/trivy/pkg/scanner/local" "github.com/aquasecurity/trivy/pkg/scanner/ospkg" diff --git a/pkg/scanner/scan.go b/pkg/scanner/scan.go index 7094e38c71fe..f1e4cf68c515 100644 --- a/pkg/scanner/scan.go +++ b/pkg/scanner/scan.go @@ -6,6 +6,7 @@ import ( "github.com/google/wire" "golang.org/x/xerrors" + "github.com/aquasecurity/trivy/pkg/cache" "github.com/aquasecurity/trivy/pkg/clock" "github.com/aquasecurity/trivy/pkg/fanal/artifact" aimage "github.com/aquasecurity/trivy/pkg/fanal/artifact/image" @@ -28,6 +29,11 @@ import ( // StandaloneSuperSet is used in the standalone mode var StandaloneSuperSet = wire.NewSet( + // Cache + cache.New, + wire.Bind(new(cache.ArtifactCache), new(cache.Cache)), + wire.Bind(new(cache.LocalArtifactCache), new(cache.Cache)), + local.SuperSet, wire.Bind(new(Driver), new(local.Scanner)), NewScanner, @@ -77,6 +83,10 @@ var StandaloneVMSet = wire.NewSet( // RemoteSuperSet is used in the client mode var RemoteSuperSet = wire.NewSet( + // Cache + cache.NewRemoteCache, + wire.Bind(new(cache.ArtifactCache), new(*cache.RemoteCache)), // No need for LocalArtifactCache + client.NewScanner, wire.Value([]client.Option(nil)), wire.Bind(new(Driver), new(client.Scanner)),