diff --git a/internal/dbtest/fake.go b/internal/dbtest/fake.go index 528c549bda7f..3a585adb99f5 100644 --- a/internal/dbtest/fake.go +++ b/internal/dbtest/fake.go @@ -13,8 +13,8 @@ import ( "github.com/samber/lo" "github.com/stretchr/testify/require" + "github.com/aquasecurity/trivy/pkg/asset" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" - "github.com/aquasecurity/trivy/pkg/oci" ) const defaultMediaType = "application/vnd.aquasec.trivy.db.layer.v1.tar+gzip" @@ -38,7 +38,7 @@ type FakeDBOptions struct { MediaType types.MediaType } -func NewFakeDB(t *testing.T, dbPath string, opts FakeDBOptions) *oci.Artifact { +func NewFakeDB(t *testing.T, dbPath string, opts FakeDBOptions) *asset.OCI { mediaType := lo.Ternary(opts.MediaType != "", opts.MediaType, defaultMediaType) img := new(fakei.FakeImage) img.LayersReturns([]v1.Layer{NewFakeLayer(t, dbPath, mediaType)}, nil) @@ -59,10 +59,13 @@ func NewFakeDB(t *testing.T, dbPath string, opts FakeDBOptions) *oci.Artifact { }, nil) // Mock OCI artifact - opt := ftypes.RegistryOptions{ - Insecure: false, + assetOpts := asset.Options{ + MediaType: defaultMediaType, + RegistryOptions: ftypes.RegistryOptions{ + Insecure: false, + }, } - return oci.NewArtifact("dummy", opt, oci.WithImage(img)) + return asset.NewOCI("dummy", assetOpts, asset.WithImage(img)) } func ArchiveDir(t *testing.T, dir string) string { diff --git a/pkg/asset/asset.go b/pkg/asset/asset.go new file mode 100644 index 000000000000..d8fde714d0bb --- /dev/null +++ b/pkg/asset/asset.go @@ -0,0 +1,91 @@ +package asset + +import ( + "context" + "errors" + "strings" + + "github.com/google/go-containerregistry/pkg/v1/remote/transport" + "github.com/hashicorp/go-multierror" + "golang.org/x/xerrors" + + "github.com/aquasecurity/trivy/pkg/fanal/types" + "github.com/aquasecurity/trivy/pkg/log" + "github.com/aquasecurity/trivy/pkg/version/doc" +) + +type Options struct { + // For OCI + MediaType string // Accept any media type if not specified + + // Common + Filename string // Use the annotation if not specified + Quiet bool + + types.RegistryOptions +} + +type Assets []Asset + +type Asset interface { + Location() string + Download(ctx context.Context, dst string) error +} + +func NewAssets(locations []string, assetOpts Options, opts ...Option) Assets { + var assets Assets + for _, location := range locations { + switch { + case strings.HasPrefix(location, "https://"): + assets = append(assets, NewHTTP(location, assetOpts)) + default: + assets = append(assets, NewOCI(location, assetOpts, opts...)) + } + } + return assets +} + +// Download downloads artifacts until one of them succeeds. +// Attempts to download next artifact if the first one fails due to a temporary error. +func (a Assets) Download(ctx context.Context, dst string) error { + var errs error + for i, art := range a { + logger := log.With("location", art.Location()) + logger.InfoContext(ctx, "Downloading artifact...") + err := art.Download(ctx, dst) + if err == nil { + logger.InfoContext(ctx, "OCI successfully downloaded") + return nil + } + + if !shouldTryOtherRepo(err) { + return xerrors.Errorf("failed to download artifact from %s: %w", art.Location(), err) + } + logger.ErrorContext(ctx, "Failed to download artifact", log.Err(err)) + if i < len(a)-1 { + log.InfoContext(ctx, "Trying to download artifact from other repository...") // Use the default logger + } + errs = multierror.Append(errs, err) + } + + return xerrors.Errorf("failed to download artifact from any source: %w", errs) +} + +func shouldTryOtherRepo(err error) bool { + var terr *transport.Error + if !errors.As(err, &terr) { + return false + } + + for _, diagnostic := range terr.Errors { + // For better user experience + if diagnostic.Code == transport.DeniedErrorCode || diagnostic.Code == transport.UnauthorizedErrorCode { + // e.g. https://aquasecurity.github.io/trivy/latest/docs/references/troubleshooting/#db + log.Warnf("See %s", doc.URL("/docs/references/troubleshooting/", "db")) + break + } + } + + // try the following artifact only if a temporary error occurs + return terr.Temporary() +} diff --git a/pkg/asset/http.go b/pkg/asset/http.go new file mode 100644 index 000000000000..0666256c0ab2 --- /dev/null +++ b/pkg/asset/http.go @@ -0,0 +1,35 @@ +package asset + +import ( + "context" + + "golang.org/x/xerrors" + + "github.com/aquasecurity/trivy/pkg/downloader" +) + +type HTTP struct { + url string + opts Options +} + +func NewHTTP(location string, assetOpts Options) *HTTP { + return &HTTP{ + url: location, + opts: assetOpts, + } +} + +func (h *HTTP) Location() string { + return h.url +} + +func (h *HTTP) Download(ctx context.Context, dir string) error { + _, err := downloader.Download(ctx, h.url, dir, ".", downloader.Options{ + Insecure: h.opts.Insecure, + }) + if err != nil { + return xerrors.Errorf("failed to download artifact via HTTP: %w", err) + } + return nil +} diff --git a/pkg/asset/oci.go b/pkg/asset/oci.go new file mode 100644 index 000000000000..93bc62face98 --- /dev/null +++ b/pkg/asset/oci.go @@ -0,0 +1,203 @@ +package asset + +import ( + "context" + "io" + "os" + "path/filepath" + "sync" + + "github.com/cheggaaa/pb/v3" + "github.com/google/go-containerregistry/pkg/name" + v1 "github.com/google/go-containerregistry/pkg/v1" + "golang.org/x/xerrors" + + "github.com/aquasecurity/trivy/pkg/downloader" + "github.com/aquasecurity/trivy/pkg/remote" +) + +const ( + // Artifact types + CycloneDXArtifactType = "application/vnd.cyclonedx+json" + SPDXArtifactType = "application/spdx+json" + + // Media types + OCIImageManifest = "application/vnd.oci.image.manifest.v1+json" + + // Annotations + titleAnnotation = "org.opencontainers.image.title" +) + +var SupportedSBOMArtifactTypes = []string{ + CycloneDXArtifactType, + SPDXArtifactType, +} + +// Option is a functional option +type Option func(*OCI) + +// WithImage takes an OCI v1 Image +func WithImage(img v1.Image) Option { + return func(a *OCI) { + a.image = img + } +} + +// OCI is used to download OCI artifacts such as vulnerability database and policies from OCI registries. +type OCI struct { + m sync.Mutex + repository string + opts Options + + image v1.Image // For testing +} + +// NewOCI returns a new instance of the OCI artifact +func NewOCI(repo string, assetOpts Options, opts ...Option) *OCI { + art := &OCI{ + repository: repo, + opts: assetOpts, + } + + for _, o := range opts { + o(art) + } + return art +} + +func (o *OCI) populate(ctx context.Context) error { + if o.image != nil { + return nil + } + + o.m.Lock() + defer o.m.Unlock() + + var nameOpts []name.Option + if o.opts.Insecure { + nameOpts = append(nameOpts, name.Insecure) + } + + ref, err := name.ParseReference(o.repository, nameOpts...) + if err != nil { + return xerrors.Errorf("repository name error (%s): %w", o.repository, err) + } + + o.image, err = remote.Image(ctx, ref, o.opts.RegistryOptions) + if err != nil { + return xerrors.Errorf("OCI repository error: %w", err) + } + return nil +} + +func (o *OCI) Location() string { + return o.repository +} + +func (o *OCI) Download(ctx context.Context, dir string) error { + if err := o.populate(ctx); err != nil { + return err + } + + layers, err := o.image.Layers() + if err != nil { + return xerrors.Errorf("OCI layer error: %w", err) + } + + manifest, err := o.image.Manifest() + if err != nil { + return xerrors.Errorf("OCI manifest error: %w", err) + } + + // A single layer is only supported now. + if len(layers) != 1 || len(manifest.Layers) != 1 { + return xerrors.Errorf("OCI artifact must be a single layer") + } + + // Take the first layer + layer := layers[0] + + // Take the file name of the first layer if not specified + fileName := o.opts.Filename + if fileName == "" { + if v, ok := manifest.Layers[0].Annotations[titleAnnotation]; !ok { + return xerrors.Errorf("annotation %s is missing", titleAnnotation) + } else { + fileName = v + } + } + + layerMediaType, err := layer.MediaType() + if err != nil { + return xerrors.Errorf("media type error: %w", err) + } else if o.opts.MediaType != "" && o.opts.MediaType != string(layerMediaType) { + return xerrors.Errorf("unacceptable media type: %s", string(layerMediaType)) + } + + if err = o.download(ctx, layer, fileName, dir, o.opts.Quiet); err != nil { + return xerrors.Errorf("oci download error: %w", err) + } + + return nil +} + +func (o *OCI) download(ctx context.Context, layer v1.Layer, fileName, dir string, quiet bool) error { + size, err := layer.Size() + if err != nil { + return xerrors.Errorf("size error: %w", err) + } + + rc, err := layer.Compressed() + if err != nil { + return xerrors.Errorf("failed to fetch the layer: %w", err) + } + defer rc.Close() + + // Show progress bar + bar := pb.Full.Start64(size) + if quiet { + bar.SetWriter(io.Discard) + } + pr := bar.NewProxyReader(rc) + defer bar.Finish() + + // https://github.com/hashicorp/go-getter/issues/326 + tempDir, err := os.MkdirTemp("", "trivy") + if err != nil { + return xerrors.Errorf("failed to create o temp dir: %w", err) + } + + f, err := os.Create(filepath.Join(tempDir, fileName)) + if err != nil { + return xerrors.Errorf("failed to create o temp file: %w", err) + } + defer func() { + _ = f.Close() + _ = os.RemoveAll(tempDir) + }() + + // Download the layer content into o temporal file + if _, err = io.Copy(f, pr); err != nil { + return xerrors.Errorf("copy error: %w", err) + } + + // Decompress the downloaded file if it is compressed and copy it into the dst + // NOTE: it's local copying, the insecure option doesn't matter. + if _, err = downloader.Download(ctx, f.Name(), dir, dir, downloader.Options{}); err != nil { + return xerrors.Errorf("download error: %w", err) + } + + return nil +} + +func (o *OCI) Digest(ctx context.Context) (string, error) { + if err := o.populate(ctx); err != nil { + return "", err + } + + digest, err := o.image.Digest() + if err != nil { + return "", xerrors.Errorf("digest error: %w", err) + } + return digest.String(), nil +} diff --git a/pkg/oci/artifact_test.go b/pkg/asset/oci_test.go similarity index 91% rename from pkg/oci/artifact_test.go rename to pkg/asset/oci_test.go index a8ce6e542641..2a9ac3b9d5f8 100644 --- a/pkg/oci/artifact_test.go +++ b/pkg/asset/oci_test.go @@ -1,4 +1,4 @@ -package oci_test +package asset_test import ( "context" @@ -14,8 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" - "github.com/aquasecurity/trivy/pkg/oci" + "github.com/aquasecurity/trivy/pkg/asset" ) type fakeLayer struct { @@ -116,11 +115,12 @@ func TestArtifact_Download(t *testing.T) { }, }, nil) - artifact := oci.NewArtifact("repo", ftypes.RegistryOptions{}, oci.WithImage(img)) - err = artifact.Download(context.Background(), tempDir, oci.DownloadOption{ + artifact := asset.NewOCI("repo", asset.Options{ MediaType: tt.mediaType, Quiet: true, - }) + }, asset.WithImage(img)) + + err = artifact.Download(context.Background(), tempDir) if tt.wantErr != "" { assert.ErrorContains(t, err, tt.wantErr) return diff --git a/pkg/oci/testdata/test.tar.gz b/pkg/asset/testdata/test.tar.gz similarity index 100% rename from pkg/oci/testdata/test.tar.gz rename to pkg/asset/testdata/test.tar.gz diff --git a/pkg/oci/testdata/test.txt b/pkg/asset/testdata/test.txt similarity index 100% rename from pkg/oci/testdata/test.txt rename to pkg/asset/testdata/test.txt diff --git a/pkg/commands/artifact/run.go b/pkg/commands/artifact/run.go index 5018434d10c2..e1898d2fcda5 100644 --- a/pkg/commands/artifact/run.go +++ b/pkg/commands/artifact/run.go @@ -292,7 +292,7 @@ func (r *runner) initDB(ctx context.Context, opts flag.Options) error { // download the database file noProgress := opts.Quiet || opts.NoProgress - if err := operation.DownloadDB(ctx, opts.AppVersion, opts.CacheDir, opts.DBRepositories, noProgress, opts.SkipDBUpdate, opts.RegistryOpts()); err != nil { + if err := operation.DownloadDB(ctx, opts.AppVersion, opts.CacheDir, opts.DBLocations, noProgress, opts.SkipDBUpdate, opts.RegistryOpts()); err != nil { return err } @@ -322,7 +322,7 @@ func (r *runner) initJavaDB(opts flag.Options) error { // Update the Java DB noProgress := opts.Quiet || opts.NoProgress - javadb.Init(opts.CacheDir, opts.JavaDBRepositories, opts.SkipJavaDBUpdate, noProgress, opts.RegistryOpts()) + javadb.Init(opts.CacheDir, opts.JavaDBLocations, opts.SkipJavaDBUpdate, noProgress, opts.RegistryOpts()) if opts.DownloadJavaDBOnly { if err := javadb.Update(); err != nil { return xerrors.Errorf("Java DB error: %w", err) diff --git a/pkg/commands/operation/operation.go b/pkg/commands/operation/operation.go index ac52eee7fb1e..40ebe0450e7d 100644 --- a/pkg/commands/operation/operation.go +++ b/pkg/commands/operation/operation.go @@ -4,7 +4,6 @@ import ( "context" "sync" - "github.com/google/go-containerregistry/pkg/name" "github.com/samber/lo" "golang.org/x/xerrors" @@ -21,7 +20,7 @@ import ( var mu sync.Mutex // DownloadDB downloads the DB -func DownloadDB(ctx context.Context, appVersion, cacheDir string, dbRepositories []name.Reference, quiet, skipUpdate bool, +func DownloadDB(ctx context.Context, appVersion, cacheDir string, dbRepositories []string, quiet, skipUpdate bool, opt ftypes.RegistryOptions) error { mu.Lock() defer mu.Unlock() @@ -97,7 +96,7 @@ func InitBuiltinChecks(ctx context.Context, cacheDir string, quiet, skipUpdate b if needsUpdate { log.InfoContext(ctx, "Need to update the built-in checks") - log.InfoContext(ctx, "Downloading the built-in checks...") + log.InfoContext(ctx, "Downloading the built-in checks...", log.String("repo", checkBundleRepository)) if err = client.DownloadBuiltinChecks(ctx, registryOpts); err != nil { return nil, xerrors.Errorf("failed to download built-in policies: %w", err) } diff --git a/pkg/commands/server/run.go b/pkg/commands/server/run.go index e9187b3442f5..31ceb5ecf45e 100644 --- a/pkg/commands/server/run.go +++ b/pkg/commands/server/run.go @@ -26,7 +26,7 @@ func Run(ctx context.Context, opts flag.Options) (err error) { defer cleanup() // download the database file - if err = operation.DownloadDB(ctx, opts.AppVersion, opts.CacheDir, opts.DBRepositories, + if err = operation.DownloadDB(ctx, opts.AppVersion, opts.CacheDir, opts.DBLocations, true, opts.SkipDBUpdate, opts.RegistryOpts()); err != nil { return err } @@ -50,6 +50,6 @@ func Run(ctx context.Context, opts flag.Options) (err error) { m.Register() server := rpcServer.NewServer(opts.AppVersion, opts.Listen, opts.CacheDir, opts.Token, opts.TokenHeader, - opts.PathPrefix, opts.DBRepositories, opts.RegistryOpts()) + opts.PathPrefix, opts.DBLocations, opts.RegistryOpts()) return server.ListenAndServe(ctx, cacheClient, opts.SkipDBUpdate) } diff --git a/pkg/db/db.go b/pkg/db/db.go index 70fbb93a5a91..f6001be9ba2c 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -7,16 +7,14 @@ import ( "path/filepath" "time" - "github.com/google/go-containerregistry/pkg/name" - "github.com/samber/lo" "golang.org/x/xerrors" "github.com/aquasecurity/trivy-db/pkg/db" "github.com/aquasecurity/trivy-db/pkg/metadata" + "github.com/aquasecurity/trivy/pkg/asset" "github.com/aquasecurity/trivy/pkg/clock" "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/oci" ) const ( @@ -27,7 +25,6 @@ const ( var ( // GitHub Container Registry DefaultGHCRRepository = fmt.Sprintf("%s:%d", "ghcr.io/aquasecurity/trivy-db", db.SchemaVersion) - defaultGHCRRepository = lo.Must(name.NewTag(DefaultGHCRRepository)) Init = db.Init Close = db.Close @@ -35,24 +32,24 @@ var ( ) type options struct { - artifact *oci.Artifact - dbRepositories []name.Reference + artifact *asset.OCI + dbLocations []string } // Option is a functional option type Option func(*options) // WithOCIArtifact takes an OCI artifact -func WithOCIArtifact(art *oci.Artifact) Option { +func WithOCIArtifact(art *asset.OCI) Option { return func(opts *options) { opts.artifact = art } } // WithDBRepository takes a dbRepository -func WithDBRepository(dbRepository []name.Reference) Option { +func WithDBRepository(dbLocations []string) Option { return func(opts *options) { - opts.dbRepositories = dbRepository + opts.dbLocations = dbLocations } } @@ -72,8 +69,8 @@ func Dir(cacheDir string) string { // NewClient is the factory method for DB client func NewClient(dbDir string, quiet bool, opts ...Option) *Client { o := &options{ - dbRepositories: []name.Reference{ - defaultGHCRRepository, + dbLocations: []string{ + DefaultGHCRRepository, }, } @@ -190,20 +187,21 @@ func (c *Client) updateDownloadedAt(ctx context.Context, dbDir string) error { return nil } -func (c *Client) initArtifacts(opt types.RegistryOptions) oci.Artifacts { +func (c *Client) initArtifacts(opts types.RegistryOptions) asset.Assets { if c.artifact != nil { - return oci.Artifacts{c.artifact} + return asset.Assets{c.artifact} } - return oci.NewArtifacts(c.dbRepositories, opt) + return asset.NewAssets(c.dbLocations, asset.Options{ + MediaType: dbMediaType, + Quiet: c.quiet, + + RegistryOptions: opts, + }) } func (c *Client) downloadDB(ctx context.Context, opt types.RegistryOptions, dst string) error { log.InfoContext(ctx, "Downloading vulnerability DB...") - downloadOpt := oci.DownloadOption{ - MediaType: dbMediaType, - Quiet: c.quiet, - } - if err := c.initArtifacts(opt).Download(ctx, dst, downloadOpt); err != nil { + if err := c.initArtifacts(opt).Download(ctx, dst); err != nil { return xerrors.Errorf("failed to download vulnerability DB: %w", err) } return nil diff --git a/pkg/fanal/analyzer/analyzer_test.go b/pkg/fanal/analyzer/analyzer_test.go index d8f93c0aa420..b1f27af897e8 100644 --- a/pkg/fanal/analyzer/analyzer_test.go +++ b/pkg/fanal/analyzer/analyzer_test.go @@ -7,7 +7,6 @@ import ( "sync" "testing" - "github.com/google/go-containerregistry/pkg/name" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sync/semaphore" @@ -623,9 +622,7 @@ func TestAnalyzerGroup_PostAnalyze(t *testing.T) { if tt.analyzerType == analyzer.TypeJar { // init java-trivy-db with skip update - repo, err := name.NewTag(javadb.DefaultGHCRRepository) - require.NoError(t, err) - javadb.Init("./language/java/jar/testdata", []name.Reference{repo}, true, false, types.RegistryOptions{Insecure: false}) + javadb.Init("./language/java/jar/testdata", []string{javadb.DefaultGHCRRepository}, true, false, types.RegistryOptions{Insecure: false}) } ctx := context.Background() diff --git a/pkg/fanal/analyzer/language/java/jar/jar_test.go b/pkg/fanal/analyzer/language/java/jar/jar_test.go index 58e7221066ac..7b0e7fa1c92d 100644 --- a/pkg/fanal/analyzer/language/java/jar/jar_test.go +++ b/pkg/fanal/analyzer/language/java/jar/jar_test.go @@ -6,7 +6,6 @@ import ( "path/filepath" "testing" - "github.com/google/go-containerregistry/pkg/name" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -128,15 +127,13 @@ func Test_javaLibraryAnalyzer_Analyze(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // init java-trivy-db with skip update - repo, err := name.NewTag(javadb.DefaultGHCRRepository) - require.NoError(t, err) - javadb.Init("testdata", []name.Reference{repo}, true, false, types.RegistryOptions{Insecure: false}) + javadb.Init("testdata", []string{javadb.DefaultGHCRRepository}, true, false, types.RegistryOptions{Insecure: false}) a := javaLibraryAnalyzer{} ctx := context.Background() mfs := mapfs.New() - err = mfs.MkdirAll(filepath.Dir(tt.inputFile), os.ModePerm) + err := mfs.MkdirAll(filepath.Dir(tt.inputFile), os.ModePerm) require.NoError(t, err) err = mfs.WriteFile(tt.inputFile, tt.inputFile) require.NoError(t, err) diff --git a/pkg/fanal/artifact/image/remote_sbom.go b/pkg/fanal/artifact/image/remote_sbom.go index d07d9cafe3ac..0abd20e4b70d 100644 --- a/pkg/fanal/artifact/image/remote_sbom.go +++ b/pkg/fanal/artifact/image/remote_sbom.go @@ -13,12 +13,12 @@ import ( "github.com/samber/lo" "golang.org/x/xerrors" + "github.com/aquasecurity/trivy/pkg/asset" sbomatt "github.com/aquasecurity/trivy/pkg/attestation/sbom" "github.com/aquasecurity/trivy/pkg/fanal/artifact" "github.com/aquasecurity/trivy/pkg/fanal/artifact/sbom" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/oci" "github.com/aquasecurity/trivy/pkg/remote" "github.com/aquasecurity/trivy/pkg/types" ) @@ -70,7 +70,7 @@ func (a Artifact) inspectOCIReferrerSBOM(ctx context.Context) (artifact.Referenc } for _, m := range lo.FromPtr(manifest).Manifests { // Unsupported artifact type - if !slices.Contains(oci.SupportedSBOMArtifactTypes, m.ArtifactType) { + if !slices.Contains(asset.SupportedSBOMArtifactTypes, m.ArtifactType) { continue } res, err := a.parseReferrer(ctx, digest.Context().String(), m) @@ -95,12 +95,13 @@ func (a Artifact) parseReferrer(ctx context.Context, repo string, desc v1.Descri defer os.RemoveAll(tmpDir) // Download SBOM to local filesystem - referrer := oci.NewArtifact(repoName, a.artifactOption.ImageOption.RegistryOptions) - if err = referrer.Download(ctx, tmpDir, oci.DownloadOption{ - MediaType: desc.ArtifactType, - Filename: fileName, - Quiet: true, - }); err != nil { + referrer := asset.NewOCI(repoName, asset.Options{ + MediaType: desc.ArtifactType, + Filename: fileName, + Quiet: true, + RegistryOptions: a.artifactOption.ImageOption.RegistryOptions, + }) + if err = referrer.Download(ctx, tmpDir); err != nil { return artifact.Reference{}, xerrors.Errorf("SBOM download error: %w", err) } diff --git a/pkg/flag/db_flags.go b/pkg/flag/db_flags.go index df0d6c6f5194..15d9e141ef0b 100644 --- a/pkg/flag/db_flags.go +++ b/pkg/flag/db_flags.go @@ -2,6 +2,8 @@ package flag import ( "fmt" + "net/url" + "strings" "github.com/google/go-containerregistry/pkg/name" "golang.org/x/xerrors" @@ -90,8 +92,8 @@ type DBOptions struct { DownloadJavaDBOnly bool SkipJavaDBUpdate bool NoProgress bool - DBRepositories []name.Reference - JavaDBRepositories []name.Reference + DBLocations []string + JavaDBLocations []string } // NewDBFlagGroup returns a default DBFlagGroup @@ -147,21 +149,21 @@ func (f *DBFlagGroup) ToOptions() (DBOptions, error) { return DBOptions{}, xerrors.New("--skip-java-db-update and --download-java-db-only options can not be specified both") } - var dbRepositories, javaDBRepositories []name.Reference + var dbLocations, javaDBLocations []string for _, repo := range f.DBRepositories.Value() { - ref, err := parseRepository(repo, db.SchemaVersion) + ref, err := parseLocation(repo, db.SchemaVersion) if err != nil { - return DBOptions{}, xerrors.Errorf("invalid DB repository: %w", err) + return DBOptions{}, xerrors.Errorf("invalid DB location: %w", err) } - dbRepositories = append(dbRepositories, ref) + dbLocations = append(dbLocations, ref) } for _, repo := range f.JavaDBRepositories.Value() { - ref, err := parseRepository(repo, javadb.SchemaVersion) + ref, err := parseLocation(repo, javadb.SchemaVersion) if err != nil { - return DBOptions{}, xerrors.Errorf("invalid javadb repository: %w", err) + return DBOptions{}, xerrors.Errorf("invalid javadb location: %w", err) } - javaDBRepositories = append(javaDBRepositories, ref) + javaDBLocations = append(javaDBLocations, ref) } return DBOptions{ @@ -171,26 +173,41 @@ func (f *DBFlagGroup) ToOptions() (DBOptions, error) { DownloadJavaDBOnly: downloadJavaDBOnly, SkipJavaDBUpdate: skipJavaDBUpdate, NoProgress: f.NoProgress.Value(), - DBRepositories: dbRepositories, - JavaDBRepositories: javaDBRepositories, + DBLocations: dbLocations, + JavaDBLocations: javaDBLocations, }, nil } -func parseRepository(repo string, dbSchemaVersion int) (name.Reference, error) { +// TODO: Remove this function after the deprecation period +func parseLocation(location string, dbSchemaVersion int) (string, error) { + if strings.HasPrefix(location, "https://") { + return parseURL(location) + } + return parseRepository(location, dbSchemaVersion) +} + +func parseURL(location string) (string, error) { + if _, err := url.Parse(location); err != nil { + return "", xerrors.Errorf("invalid URL format: %w", err) + } + return location, nil +} + +func parseRepository(repo string, dbSchemaVersion int) (string, error) { dbRepository, err := name.ParseReference(repo, name.WithDefaultTag("")) if err != nil { - return nil, err + return "", err } // Add the schema version if the tag is not specified for backward compatibility. t, ok := dbRepository.(name.Tag) if !ok || t.TagStr() != "" { - return dbRepository, nil + return dbRepository.String(), nil } dbRepository = t.Tag(fmt.Sprint(dbSchemaVersion)) log.Info("Adding schema version to the DB repository for backward compatibility", log.String("repository", dbRepository.String())) - return dbRepository, nil + return dbRepository.String(), nil } diff --git a/pkg/flag/db_flags_test.go b/pkg/flag/db_flags_test.go index 4f742e74ed68..77a2b43c42da 100644 --- a/pkg/flag/db_flags_test.go +++ b/pkg/flag/db_flags_test.go @@ -3,7 +3,6 @@ package flag_test import ( "testing" - "github.com/google/go-containerregistry/pkg/name" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,10 +35,10 @@ func TestDBFlagGroup_ToOptions(t *testing.T) { JavaDBRepository: []string{"ghcr.io/aquasecurity/trivy-java-db"}, }, want: flag.DBOptions{ - SkipDBUpdate: true, - DownloadDBOnly: false, - DBRepositories: []name.Reference{name.Tag{}}, // All fields are unexported - JavaDBRepositories: []name.Reference{name.Tag{}}, // All fields are unexported + SkipDBUpdate: true, + DownloadDBOnly: false, + DBLocations: []string{"ghcr.io/aquasecurity/trivy-db:2"}, + JavaDBLocations: []string{"ghcr.io/aquasecurity/trivy-java-db:1"}, }, wantLogs: []string{ `Adding schema version to the DB repository for backward compatibility repository="ghcr.io/aquasecurity/trivy-db:2"`, @@ -61,21 +60,33 @@ func TestDBFlagGroup_ToOptions(t *testing.T) { DownloadDBOnly: false, DBRepository: []string{"foo:bar:baz"}, }, - wantErr: "invalid DB repository", + wantErr: "invalid DB location", }, { name: "multiple repos", fields: fields{ - SkipDBUpdate: true, - DownloadDBOnly: false, - DBRepository: []string{"ghcr.io/aquasecurity/trivy-db:2", "gallery.ecr.aws/aquasecurity/trivy-db:2"}, - JavaDBRepository: []string{"ghcr.io/aquasecurity/trivy-java-db:1", "gallery.ecr.aws/aquasecurity/trivy-java-db:1"}, + SkipDBUpdate: true, + DownloadDBOnly: false, + DBRepository: []string{ + "ghcr.io/aquasecurity/trivy-db:2", + "gallery.ecr.aws/aquasecurity/trivy-db:2", + }, + JavaDBRepository: []string{ + "ghcr.io/aquasecurity/trivy-java-db:1", + "gallery.ecr.aws/aquasecurity/trivy-java-db:1", + }, }, want: flag.DBOptions{ - SkipDBUpdate: true, - DownloadDBOnly: false, - DBRepositories: []name.Reference{name.Tag{}, name.Tag{}}, // All fields are unexported - JavaDBRepositories: []name.Reference{name.Tag{}, name.Tag{}}, // All fields are unexported + SkipDBUpdate: true, + DownloadDBOnly: false, + DBLocations: []string{ + "ghcr.io/aquasecurity/trivy-db:2", + "gallery.ecr.aws/aquasecurity/trivy-db:2", + }, + JavaDBLocations: []string{ + "ghcr.io/aquasecurity/trivy-java-db:1", + "gallery.ecr.aws/aquasecurity/trivy-java-db:1", + }, }, }, } @@ -97,12 +108,10 @@ func TestDBFlagGroup_ToOptions(t *testing.T) { } got, err := f.ToOptions() if tt.wantErr != "" { - require.Error(t, err) assert.ErrorContains(t, err, tt.wantErr) return } require.NoError(t, err) - assert.EqualExportedValues(t, tt.want, got) // Assert log messages diff --git a/pkg/javadb/client.go b/pkg/javadb/client.go index 835730109b02..6974302f7c38 100644 --- a/pkg/javadb/client.go +++ b/pkg/javadb/client.go @@ -10,15 +10,14 @@ import ( "sync" "time" - "github.com/google/go-containerregistry/pkg/name" "golang.org/x/xerrors" "github.com/aquasecurity/trivy-java-db/pkg/db" "github.com/aquasecurity/trivy-java-db/pkg/types" + "github.com/aquasecurity/trivy/pkg/asset" "github.com/aquasecurity/trivy/pkg/dependency/parser/java/jar" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/oci" ) const ( @@ -34,7 +33,7 @@ var ( var updater *Updater type Updater struct { - repos []name.Reference + locations []string dbDir string skip bool quiet bool @@ -60,7 +59,7 @@ func (u *Updater) Update() error { // Download DB // TODO: support remote options if err := u.downloadDB(ctx); err != nil { - return xerrors.Errorf("OCI artifact error: %w", err) + return xerrors.Errorf("download error: %w", err) } // Parse the newly downloaded metadata.json @@ -98,21 +97,22 @@ func (u *Updater) isNewDB(ctx context.Context, meta db.Metadata) bool { func (u *Updater) downloadDB(ctx context.Context) error { log.InfoContext(ctx, "Downloading Java DB...") - artifacts := oci.NewArtifacts(u.repos, u.registryOption) - downloadOpt := oci.DownloadOption{ + assets := asset.NewAssets(u.locations, asset.Options{ MediaType: mediaType, Quiet: u.quiet, - } - if err := artifacts.Download(ctx, u.dbDir, downloadOpt); err != nil { + + RegistryOptions: u.registryOption, + }) + if err := assets.Download(ctx, u.dbDir); err != nil { return xerrors.Errorf("failed to download Java DB: %w", err) } return nil } -func Init(cacheDir string, javaDBRepositories []name.Reference, skip, quiet bool, registryOption ftypes.RegistryOptions) { +func Init(cacheDir string, javaDBLocations []string, skip, quiet bool, registryOption ftypes.RegistryOptions) { updater = &Updater{ - repos: javaDBRepositories, + locations: javaDBLocations, dbDir: dbDir(cacheDir), skip: skip, quiet: quiet, diff --git a/pkg/module/command.go b/pkg/module/command.go index a74da8384c3f..0dfa9a22d457 100644 --- a/pkg/module/command.go +++ b/pkg/module/command.go @@ -8,26 +8,30 @@ import ( "github.com/google/go-containerregistry/pkg/name" "golang.org/x/xerrors" + "github.com/aquasecurity/trivy/pkg/asset" "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/oci" ) const mediaType = "application/vnd.module.wasm.content.layer.v1+wasm" // Install installs a module -func Install(ctx context.Context, dir, repo string, quiet bool, opt types.RegistryOptions) error { +func Install(ctx context.Context, dir, repo string, quiet bool, opts types.RegistryOptions) error { ref, err := name.ParseReference(repo) if err != nil { return xerrors.Errorf("repository parse error: %w", err) } log.Info("Installing the module from the repository...", log.String("repo", repo)) - art := oci.NewArtifact(repo, opt) + art := asset.NewOCI(repo, asset.Options{ + MediaType: mediaType, + Quiet: quiet, + RegistryOptions: opts, + }) dst := filepath.Join(dir, ref.Context().Name()) log.Debug("Installing the module...", log.String("dst", dst)) - if err = art.Download(ctx, dst, oci.DownloadOption{MediaType: mediaType, Quiet: quiet}); err != nil { + if err = art.Download(ctx, dst); err != nil { return xerrors.Errorf("module download error: %w", err) } diff --git a/pkg/oci/artifact.go b/pkg/oci/artifact.go deleted file mode 100644 index 8ed7dcdad03d..000000000000 --- a/pkg/oci/artifact.go +++ /dev/null @@ -1,267 +0,0 @@ -package oci - -import ( - "context" - "errors" - "io" - "os" - "path/filepath" - "sync" - - "github.com/cheggaaa/pb/v3" - "github.com/google/go-containerregistry/pkg/name" - v1 "github.com/google/go-containerregistry/pkg/v1" - "github.com/google/go-containerregistry/pkg/v1/remote/transport" - "github.com/hashicorp/go-multierror" - "github.com/samber/lo" - "golang.org/x/xerrors" - - "github.com/aquasecurity/trivy/pkg/downloader" - "github.com/aquasecurity/trivy/pkg/fanal/types" - "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/remote" - "github.com/aquasecurity/trivy/pkg/version/doc" -) - -const ( - // Artifact types - CycloneDXArtifactType = "application/vnd.cyclonedx+json" - SPDXArtifactType = "application/spdx+json" - - // Media types - OCIImageManifest = "application/vnd.oci.image.manifest.v1+json" - - // Annotations - titleAnnotation = "org.opencontainers.image.title" -) - -var SupportedSBOMArtifactTypes = []string{ - CycloneDXArtifactType, - SPDXArtifactType, -} - -// Option is a functional option -type Option func(*Artifact) - -// WithImage takes an OCI v1 Image -func WithImage(img v1.Image) Option { - return func(a *Artifact) { - a.image = img - } -} - -// Artifact is used to download artifacts such as vulnerability database and policies from OCI registries. -type Artifact struct { - m sync.Mutex - repository string - - // For OCI registries - types.RegistryOptions - - image v1.Image // For testing -} - -// NewArtifact returns a new artifact -func NewArtifact(repo string, registryOpt types.RegistryOptions, opts ...Option) *Artifact { - art := &Artifact{ - repository: repo, - RegistryOptions: registryOpt, - } - - for _, o := range opts { - o(art) - } - return art -} - -func (a *Artifact) populate(ctx context.Context, opt types.RegistryOptions) error { - if a.image != nil { - return nil - } - - a.m.Lock() - defer a.m.Unlock() - - var nameOpts []name.Option - if opt.Insecure { - nameOpts = append(nameOpts, name.Insecure) - } - - ref, err := name.ParseReference(a.repository, nameOpts...) - if err != nil { - return xerrors.Errorf("repository name error (%s): %w", a.repository, err) - } - - a.image, err = remote.Image(ctx, ref, opt) - if err != nil { - return xerrors.Errorf("OCI repository error: %w", err) - } - return nil -} - -type DownloadOption struct { - MediaType string // Accept any media type if not specified - Filename string // Use the annotation if not specified - Quiet bool -} - -func (a *Artifact) Download(ctx context.Context, dir string, opt DownloadOption) error { - if err := a.populate(ctx, a.RegistryOptions); err != nil { - return err - } - - layers, err := a.image.Layers() - if err != nil { - return xerrors.Errorf("OCI layer error: %w", err) - } - - manifest, err := a.image.Manifest() - if err != nil { - return xerrors.Errorf("OCI manifest error: %w", err) - } - - // A single layer is only supported now. - if len(layers) != 1 || len(manifest.Layers) != 1 { - return xerrors.Errorf("OCI artifact must be a single layer") - } - - // Take the first layer - layer := layers[0] - - // Take the file name of the first layer if not specified - fileName := opt.Filename - if fileName == "" { - if v, ok := manifest.Layers[0].Annotations[titleAnnotation]; !ok { - return xerrors.Errorf("annotation %s is missing", titleAnnotation) - } else { - fileName = v - } - } - - layerMediaType, err := layer.MediaType() - if err != nil { - return xerrors.Errorf("media type error: %w", err) - } else if opt.MediaType != "" && opt.MediaType != string(layerMediaType) { - return xerrors.Errorf("unacceptable media type: %s", string(layerMediaType)) - } - - if err = a.download(ctx, layer, fileName, dir, opt.Quiet); err != nil { - return xerrors.Errorf("oci download error: %w", err) - } - - return nil -} - -func (a *Artifact) download(ctx context.Context, layer v1.Layer, fileName, dir string, quiet bool) error { - size, err := layer.Size() - if err != nil { - return xerrors.Errorf("size error: %w", err) - } - - rc, err := layer.Compressed() - if err != nil { - return xerrors.Errorf("failed to fetch the layer: %w", err) - } - defer rc.Close() - - // Show progress bar - bar := pb.Full.Start64(size) - if quiet { - bar.SetWriter(io.Discard) - } - pr := bar.NewProxyReader(rc) - defer bar.Finish() - - // https://github.com/hashicorp/go-getter/issues/326 - tempDir, err := os.MkdirTemp("", "trivy") - if err != nil { - return xerrors.Errorf("failed to create a temp dir: %w", err) - } - - f, err := os.Create(filepath.Join(tempDir, fileName)) - if err != nil { - return xerrors.Errorf("failed to create a temp file: %w", err) - } - defer func() { - _ = f.Close() - _ = os.RemoveAll(tempDir) - }() - - // Download the layer content into a temporal file - if _, err = io.Copy(f, pr); err != nil { - return xerrors.Errorf("copy error: %w", err) - } - - // Decompress the downloaded file if it is compressed and copy it into the dst - // NOTE: it's local copying, the insecure option doesn't matter. - if _, err = downloader.Download(ctx, f.Name(), dir, dir, downloader.Options{}); err != nil { - return xerrors.Errorf("download error: %w", err) - } - - return nil -} - -func (a *Artifact) Digest(ctx context.Context) (string, error) { - if err := a.populate(ctx, a.RegistryOptions); err != nil { - return "", err - } - - digest, err := a.image.Digest() - if err != nil { - return "", xerrors.Errorf("digest error: %w", err) - } - return digest.String(), nil -} - -type Artifacts []*Artifact - -// NewArtifacts returns a slice of artifacts. -func NewArtifacts(repos []name.Reference, opt types.RegistryOptions, opts ...Option) Artifacts { - return lo.Map(repos, func(r name.Reference, _ int) *Artifact { - return NewArtifact(r.String(), opt, opts...) - }) -} - -// Download downloads artifacts until one of them succeeds. -// Attempts to download next artifact if the first one fails due to a temporary error. -func (a Artifacts) Download(ctx context.Context, dst string, opt DownloadOption) error { - var errs error - for i, art := range a { - log.InfoContext(ctx, "Downloading artifact...", log.String("repo", art.repository)) - err := art.Download(ctx, dst, opt) - if err == nil { - log.InfoContext(ctx, "Artifact successfully downloaded", log.String("repo", art.repository)) - return nil - } - - if !shouldTryOtherRepo(err) { - return xerrors.Errorf("failed to download artifact from %s: %w", art.repository, err) - } - log.ErrorContext(ctx, "Failed to download artifact", log.String("repo", art.repository), log.Err(err)) - if i < len(a)-1 { - log.InfoContext(ctx, "Trying to download artifact from other repository...") - } - errs = multierror.Append(errs, err) - } - - return xerrors.Errorf("failed to download artifact from any source: %w", errs) -} - -func shouldTryOtherRepo(err error) bool { - var terr *transport.Error - if !errors.As(err, &terr) { - return false - } - - for _, diagnostic := range terr.Errors { - // For better user experience - if diagnostic.Code == transport.DeniedErrorCode || diagnostic.Code == transport.UnauthorizedErrorCode { - // e.g. https://aquasecurity.github.io/trivy/latest/docs/references/troubleshooting/#db - log.Warnf("See %s", doc.URL("/docs/references/troubleshooting/", "db")) - break - } - } - - // try the following artifact only if a temporary error occurs - return terr.Temporary() -} diff --git a/pkg/policy/policy.go b/pkg/policy/policy.go index 670588868dcb..7df8a0446d0f 100644 --- a/pkg/policy/policy.go +++ b/pkg/policy/policy.go @@ -12,9 +12,9 @@ import ( "golang.org/x/xerrors" "k8s.io/utils/clock" + "github.com/aquasecurity/trivy/pkg/asset" "github.com/aquasecurity/trivy/pkg/fanal/types" "github.com/aquasecurity/trivy/pkg/log" - "github.com/aquasecurity/trivy/pkg/oci" ) const ( @@ -25,12 +25,12 @@ const ( ) type options struct { - artifact *oci.Artifact + artifact *asset.OCI clock clock.Clock } // WithOCIArtifact takes an OCI artifact -func WithOCIArtifact(art *oci.Artifact) Option { +func WithOCIArtifact(art *asset.OCI) Option { return func(opts *options) { opts.artifact = art } @@ -92,7 +92,11 @@ func NewClient(cacheDir string, quiet bool, checkBundleRepo string, opts ...Opti func (c *Client) populateOCIArtifact(ctx context.Context, registryOpts types.RegistryOptions) { if c.artifact == nil { log.DebugContext(ctx, "Loading check bundle", log.String("repository", c.checkBundleRepo)) - c.artifact = oci.NewArtifact(c.checkBundleRepo, registryOpts) + c.artifact = asset.NewOCI(c.checkBundleRepo, asset.Options{ + MediaType: policyMediaType, + Quiet: c.quiet, + RegistryOptions: registryOpts, + }) } } @@ -101,11 +105,7 @@ func (c *Client) DownloadBuiltinChecks(ctx context.Context, registryOpts types.R c.populateOCIArtifact(ctx, registryOpts) dst := c.contentDir() - if err := c.artifact.Download(ctx, dst, oci.DownloadOption{ - MediaType: policyMediaType, - Quiet: c.quiet, - }, - ); err != nil { + if err := c.artifact.Download(ctx, dst); err != nil { return xerrors.Errorf("download error: %w", err) } diff --git a/pkg/policy/policy_test.go b/pkg/policy/policy_test.go index 4752fa4ce7fc..238c531d7af3 100644 --- a/pkg/policy/policy_test.go +++ b/pkg/policy/policy_test.go @@ -19,8 +19,8 @@ import ( "k8s.io/utils/clock" fake "k8s.io/utils/clock/testing" + "github.com/aquasecurity/trivy/pkg/asset" ftypes "github.com/aquasecurity/trivy/pkg/fanal/types" - "github.com/aquasecurity/trivy/pkg/oci" "github.com/aquasecurity/trivy/pkg/policy" ) @@ -116,14 +116,13 @@ func TestClient_LoadBuiltinPolicies(t *testing.T) { }, nil) // Mock OCI artifact - art := oci.NewArtifact("repo", ftypes.RegistryOptions{}, oci.WithImage(img)) + art := asset.NewOCI("repo", asset.Options{}, asset.WithImage(img)) c, err := policy.NewClient(tt.cacheDir, true, "", policy.WithOCIArtifact(art)) require.NoError(t, err) got, err := c.LoadBuiltinChecks() if tt.wantErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.wantErr) + assert.ErrorContains(t, err, tt.wantErr) return } require.NoError(t, err) @@ -255,7 +254,7 @@ func TestClient_NeedsUpdate(t *testing.T) { require.NoError(t, err) } - art := oci.NewArtifact("repo", ftypes.RegistryOptions{}, oci.WithImage(img)) + art := asset.NewOCI("repo", asset.Options{}, asset.WithImage(img)) c, err := policy.NewClient(tmpDir, true, "", policy.WithOCIArtifact(art), policy.WithClock(tt.clock)) require.NoError(t, err) @@ -357,7 +356,7 @@ func TestClient_DownloadBuiltinPolicies(t *testing.T) { }, nil) // Mock OCI artifact - art := oci.NewArtifact("repo", ftypes.RegistryOptions{}, oci.WithImage(img)) + art := asset.NewOCI("repo", asset.Options{}, asset.WithImage(img)) c, err := policy.NewClient(tempDir, true, "", policy.WithClock(tt.clock), policy.WithOCIArtifact(art)) require.NoError(t, err) diff --git a/pkg/rpc/server/listen.go b/pkg/rpc/server/listen.go index c1bc5033530c..58beca3359af 100644 --- a/pkg/rpc/server/listen.go +++ b/pkg/rpc/server/listen.go @@ -10,7 +10,6 @@ import ( "time" "github.com/NYTimes/gziphandler" - "github.com/google/go-containerregistry/pkg/name" "github.com/twitchtv/twirp" "golang.org/x/xerrors" @@ -29,21 +28,21 @@ const updateInterval = 1 * time.Hour // Server represents Trivy server type Server struct { - appVersion string - addr string - cacheDir string - dbDir string - token string - tokenHeader string - pathPrefix string - dbRepositories []name.Reference + appVersion string + addr string + cacheDir string + dbDir string + token string + tokenHeader string + pathPrefix string + dbLocations []string // For OCI registries types.RegistryOptions } // NewServer returns an instance of Server -func NewServer(appVersion, addr, cacheDir, token, tokenHeader, pathPrefix string, dbRepositories []name.Reference, opt types.RegistryOptions) Server { +func NewServer(appVersion, addr, cacheDir, token, tokenHeader, pathPrefix string, dbLocations []string, opt types.RegistryOptions) Server { return Server{ appVersion: appVersion, addr: addr, @@ -52,7 +51,7 @@ func NewServer(appVersion, addr, cacheDir, token, tokenHeader, pathPrefix string token: token, tokenHeader: tokenHeader, pathPrefix: pathPrefix, - dbRepositories: dbRepositories, + dbLocations: dbLocations, RegistryOptions: opt, } } @@ -63,7 +62,7 @@ func (s Server) ListenAndServe(ctx context.Context, serverCache cache.Cache, ski dbUpdateWg := &sync.WaitGroup{} go func() { - worker := newDBWorker(db.NewClient(s.dbDir, true, db.WithDBRepository(s.dbRepositories))) + worker := newDBWorker(db.NewClient(s.dbDir, true, db.WithDBRepository(s.dbLocations))) for { time.Sleep(updateInterval) if err := worker.update(ctx, s.appVersion, s.dbDir, skipDBUpdate, dbUpdateWg, requestWg, s.RegistryOptions); err != nil {