Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add generic Set implementation #8149

Merged
merged 16 commits into from
Dec 24, 2024
5 changes: 5 additions & 0 deletions misc/lint/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ func errorsJoin(m dsl.Matcher) {
m.Match(`errors.Join($*args)`).
Report("use github.com/hashicorp/go-multierror.Append instead of errors.Join.")
}

func mapSet(m dsl.Matcher) {
m.Match(`map[$x]struct{}`).
Report("use github.com/aquasecurity/trivy/pkg/set.Set instead of map.")
}
8 changes: 4 additions & 4 deletions pkg/compliance/spec/compliance.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import (
"path/filepath"
"strings"

"github.com/samber/lo"
"golang.org/x/xerrors"
"gopkg.in/yaml.v3"

sp "github.com/aquasecurity/trivy-checks/pkg/spec"
iacTypes "github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/types"
)

Expand All @@ -31,17 +31,17 @@ const (

// Scanners reads spec control and determines the scanners by check ID prefix
func (cs *ComplianceSpec) Scanners() (types.Scanners, error) {
scannerTypes := make(map[types.Scanner]struct{})
scannerTypes := set.New[types.Scanner]()
for _, control := range cs.Spec.Controls {
for _, check := range control.Checks {
scannerType := scannerByCheckID(check.ID)
if scannerType == types.UnknownScanner {
return nil, xerrors.Errorf("unsupported check ID: %s", check.ID)
}
scannerTypes[scannerType] = struct{}{}
scannerTypes.Append(scannerType)
}
}
return lo.Keys(scannerTypes), nil
return scannerTypes.Items(), nil
}

// CheckIDs return list of compliance check IDs
Expand Down
3 changes: 2 additions & 1 deletion pkg/dependency/parser/java/pom/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/version/doc"
)

Expand All @@ -30,7 +31,7 @@ type artifact struct {
Version version
Licenses []string

Exclusions map[string]struct{}
Exclusions set.Set[string]

Module bool
Relationship ftypes.Relationship
Expand Down
30 changes: 17 additions & 13 deletions pkg/dependency/parser/java/pom/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)

Expand Down Expand Up @@ -118,11 +119,11 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
rootArt := root.artifact()
rootArt.Relationship = ftypes.RelationshipRoot

return p.parseRoot(rootArt, make(map[string]struct{}))
return p.parseRoot(rootArt, set.New[string]())
}

// nolint: gocyclo
func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ftypes.Package, []ftypes.Dependency, error) {
func (p *Parser) parseRoot(root artifact, uniqModules set.Set[string]) ([]ftypes.Package, []ftypes.Dependency, error) {
// Prepare a queue for dependencies
queue := newArtifactQueue()

Expand All @@ -145,10 +146,10 @@ func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ft
// Modules should be handled separately so that they can have independent dependencies.
// It means multi-module allows for duplicate dependencies.
if art.Module {
if _, ok := uniqModules[art.String()]; ok {
if uniqModules.Contains(art.String()) {
continue
}
uniqModules[art.String()] = struct{}{}
uniqModules.Append(art.String())

modulePkgs, moduleDeps, err := p.parseRoot(art, uniqModules)
if err != nil {
Expand Down Expand Up @@ -251,7 +252,7 @@ func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ft
// `mvn` shows modules separately from the root package and does not show module nesting.
// So we can add all modules as dependencies of root package.
if art.Relationship == ftypes.RelationshipRoot {
dependsOn = append(dependsOn, lo.Keys(uniqModules)...)
dependsOn = append(dependsOn, uniqModules.Items()...)
}

sort.Strings(dependsOn)
Expand Down Expand Up @@ -340,14 +341,17 @@ type analysisResult struct {
}

type analysisOptions struct {
exclusions map[string]struct{}
exclusions set.Set[string]
depManagement []pomDependency // from the root POM
}

func (p *Parser) analyze(pom *pom, opts analysisOptions) (analysisResult, error) {
if pom.nil() {
return analysisResult{}, nil
}
if opts.exclusions == nil {
opts.exclusions = set.New[string]()
}
// Update remoteRepositories
pomReleaseRemoteRepos, pomSnapshotRemoteRepos := pom.repositories(p.servers)
p.releaseRemoteRepos = lo.Uniq(append(pomReleaseRemoteRepos, p.releaseRemoteRepos...))
Expand Down Expand Up @@ -408,16 +412,16 @@ func (p *Parser) resolveParent(pom *pom) error {
}

func (p *Parser) mergeDependencyManagements(depManagements ...[]pomDependency) []pomDependency {
uniq := make(map[string]struct{})
uniq := set.New[string]()
var depManagement []pomDependency
// The preceding argument takes precedence.
for _, dm := range depManagements {
for _, dep := range dm {
if _, ok := uniq[dep.Name()]; ok {
if uniq.Contains(dep.Name()) {
continue
}
depManagement = append(depManagement, dep)
uniq[dep.Name()] = struct{}{}
uniq.Append(dep.Name())
}
}
return depManagement
Expand Down Expand Up @@ -492,19 +496,19 @@ func (p *Parser) mergeDependencies(child, parent []pomDependency) []pomDependenc
})
}

func (p *Parser) filterDependencies(artifacts []artifact, exclusions map[string]struct{}) []artifact {
func (p *Parser) filterDependencies(artifacts []artifact, exclusions set.Set[string]) []artifact {
return lo.Filter(artifacts, func(art artifact, _ int) bool {
return !excludeDep(exclusions, art)
})
}

func excludeDep(exclusions map[string]struct{}, art artifact) bool {
if _, ok := exclusions[art.Name()]; ok {
func excludeDep(exclusions set.Set[string], art artifact) bool {
if exclusions.Contains(art.Name()) {
return true
}
// Maven can use "*" in GroupID and ArtifactID fields to exclude dependencies
// https://maven.apache.org/pom.html#exclusions
for exlusion := range exclusions {
for exlusion := range exclusions.Iter() {
// exclusion format - "<groupID>:<artifactID>"
e := strings.Split(exlusion, ":")
if (e[0] == art.GroupID || e[0] == "*") && (e[1] == art.ArtifactID || e[1] == "*") {
Expand Down
8 changes: 4 additions & 4 deletions pkg/dependency/parser/java/pom/pom.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/xml"
"fmt"
"io"
"maps"
"net/url"
"reflect"
"strings"
Expand All @@ -15,6 +14,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/x/slices"
)

Expand Down Expand Up @@ -287,12 +287,12 @@ func (d pomDependency) ToArtifact(opts analysisOptions) artifact {
// To avoid shadow adding exclusions to top pom's,
// we need to initialize a new map for each new artifact
// See `exclusions in child` test for more information
exclusions := make(map[string]struct{})
exclusions := set.New[string]()
if opts.exclusions != nil {
exclusions = maps.Clone(opts.exclusions)
exclusions = opts.exclusions.Clone()
}
for _, e := range d.Exclusions.Exclusion {
exclusions[fmt.Sprintf("%s:%s", e.GroupID, e.ArtifactID)] = struct{}{}
exclusions.Append(fmt.Sprintf("%s:%s", e.GroupID, e.ArtifactID))
}

var locations ftypes.Locations
Expand Down
15 changes: 8 additions & 7 deletions pkg/dependency/parser/nodejs/npm/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)

Expand Down Expand Up @@ -91,7 +92,7 @@ func (p *Parser) parseV2(packages map[string]Package) ([]ftypes.Package, []ftype
// https://docs.npmjs.com/cli/v9/configuring-npm/package-lock-json#packages
p.resolveLinks(packages)

directDeps := make(map[string]struct{})
directDeps := set.New[string]()
for name, version := range lo.Assign(packages[""].Dependencies, packages[""].OptionalDependencies, packages[""].DevDependencies, packages[""].PeerDependencies) {
pkgPath := joinPaths(nodeModulesDir, name)
if _, ok := packages[pkgPath]; !ok {
Expand All @@ -101,7 +102,7 @@ func (p *Parser) parseV2(packages map[string]Package) ([]ftypes.Package, []ftype
}
// Store the package paths of direct dependencies
// e.g. node_modules/body-parser
directDeps[pkgPath] = struct{}{}
directDeps.Append(pkgPath)
}

for pkgPath, pkg := range packages {
Expand Down Expand Up @@ -366,13 +367,13 @@ func (p *Parser) pkgNameFromPath(pkgPath string) string {

func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency {
var uniqDeps ftypes.Dependencies
unique := make(map[string]struct{})
unique := set.New[string]()

for _, dep := range deps {
sort.Strings(dep.DependsOn)
depKey := fmt.Sprintf("%s:%s", dep.ID, strings.Join(dep.DependsOn, ","))
if _, ok := unique[depKey]; !ok {
unique[depKey] = struct{}{}
if !unique.Contains(depKey) {
unique.Append(depKey)
uniqDeps = append(uniqDeps, dep)
}
}
Expand All @@ -381,11 +382,11 @@ func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency {
return uniqDeps
}

func isIndirectPkg(pkgPath string, directDeps map[string]struct{}) bool {
func isIndirectPkg(pkgPath string, directDeps set.Set[string]) bool {
// A project can contain 2 different versions of the same dependency.
// e.g. `node_modules/string-width/node_modules/strip-ansi` and `node_modules/string-ansi`
// direct dependencies always have root path (`node_modules/<pkg_name>`)
if _, ok := directDeps[pkgPath]; ok {
if directDeps.Contains(pkgPath) {
return false
}
return true
Expand Down
9 changes: 5 additions & 4 deletions pkg/dependency/parser/nodejs/pnpm/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)

Expand Down Expand Up @@ -215,7 +216,7 @@ func (p *Parser) parseV9(lockFile LockFile) ([]ftypes.Package, []ftypes.Dependen
}
}

visited := make(map[string]struct{})
visited := set.New[string]()
// Overwrite the `Dev` field for dev deps and their child dependencies.
for _, pkg := range resolvedPkgs {
if !pkg.Dev {
Expand All @@ -227,8 +228,8 @@ func (p *Parser) parseV9(lockFile LockFile) ([]ftypes.Package, []ftypes.Dependen
}

// markRootPkgs sets `Dev` to false for non dev dependency.
func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps map[string]ftypes.Dependency, visited map[string]struct{}) {
if _, ok := visited[id]; ok {
func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps map[string]ftypes.Dependency, visited set.Set[string]) {
if visited.Contains(id) {
return
}
pkg, ok := pkgs[id]
Expand All @@ -238,7 +239,7 @@ func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps ma

pkg.Dev = false
pkgs[id] = pkg
visited[id] = struct{}{}
visited.Append(id)

// Update child deps
for _, depID := range deps[id].DependsOn {
Expand Down
2 changes: 1 addition & 1 deletion pkg/dependency/parser/nuget/lock/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
}

if savedDependsOn, ok := depsMap[depId]; ok {
dependsOn = utils.UniqueStrings(append(dependsOn, savedDependsOn...))
dependsOn = lo.Uniq(append(dependsOn, savedDependsOn...))
}

if len(dependsOn) > 0 {
Expand Down
17 changes: 10 additions & 7 deletions pkg/dependency/parser/python/pyproject/pyproject.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"golang.org/x/xerrors"

"github.com/aquasecurity/trivy/pkg/dependency/parser/python"
"github.com/aquasecurity/trivy/pkg/set"
)

type PyProject struct {
Expand All @@ -19,25 +20,27 @@ type Tool struct {
}

type Poetry struct {
Dependencies dependencies `toml:"dependencies"`
Dependencies Dependencies `toml:"dependencies"`
Groups map[string]Group `toml:"group"`
}

type Group struct {
Dependencies dependencies `toml:"dependencies"`
Dependencies Dependencies `toml:"dependencies"`
}

type dependencies map[string]struct{}
type Dependencies struct {
set.Set[string]
}

func (d *dependencies) UnmarshalTOML(data any) error {
func (d *Dependencies) UnmarshalTOML(data any) error {
m, ok := data.(map[string]any)
if !ok {
return xerrors.Errorf("dependencies must be map, but got: %T", data)
}

*d = lo.MapEntries(m, func(pkgName string, _ any) (string, struct{}) {
return python.NormalizePkgName(pkgName), struct{}{}
})
d.Set = set.New[string](lo.MapToSlice(m, func(pkgName string, _ any) string {
return python.NormalizePkgName(pkgName)
})...)
return nil
}

Expand Down
16 changes: 7 additions & 9 deletions pkg/dependency/parser/python/pyproject/pyproject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/aquasecurity/trivy/pkg/dependency/parser/python/pyproject"
"github.com/aquasecurity/trivy/pkg/set"
)

func TestParser_Parse(t *testing.T) {
Expand All @@ -24,21 +25,18 @@ func TestParser_Parse(t *testing.T) {
want: pyproject.PyProject{
Tool: pyproject.Tool{
Poetry: pyproject.Poetry{
Dependencies: map[string]struct{}{
"flask": {},
"python": {},
"requests": {},
"virtualenv": {},
Dependencies: pyproject.Dependencies{
Set: set.New[string]("flask", "python", "requests", "virtualenv"),
},
Groups: map[string]pyproject.Group{
"dev": {
Dependencies: map[string]struct{}{
"pytest": {},
Dependencies: pyproject.Dependencies{
Set: set.New[string]("pytest"),
},
},
"lint": {
Dependencies: map[string]struct{}{
"ruff": {},
Dependencies: pyproject.Dependencies{
Set: set.New[string]("ruff"),
},
},
},
Expand Down
Loading
Loading