diff --git a/pkg/k8s/k8s.go b/pkg/k8s/k8s.go index 58cc49c..c298cec 100644 --- a/pkg/k8s/k8s.go +++ b/pkg/k8s/k8s.go @@ -403,6 +403,10 @@ func IsBuiltInWorkload(resource *metav1.OwnerReference) bool { resource.Kind == string(KindJob)) } +func GetAllResources() []string { + return append(getClusterResources(), getNamespaceResources()...) +} + func getClusterResources() []string { return []string{ ClusterRoles, diff --git a/pkg/trivyk8s/trivyk8s.go b/pkg/trivyk8s/trivyk8s.go index a7b37bf..2a8266b 100644 --- a/pkg/trivyk8s/trivyk8s.go +++ b/pkg/trivyk8s/trivyk8s.go @@ -161,8 +161,66 @@ func (c client) GetIncludeKinds() []string { return c.includeKinds } +// initResourceList collects scannable resources. +func (c *client) initResourceList() { + // skip if resources are already created + if len(c.resources) > 0 { + return + } + + // collect only included kinds + if len(c.includeKinds) != 0 { + // `includeKinds` are already low cased. + c.resources = c.includeKinds + return + } + // if there are no included and excluded kinds - don't collect resources + if len(c.excludeKinds) == 0 { + return + } + // skip excluded resources + for _, kind := range k8s.GetAllResources() { + if slices.Contains(c.excludeKinds, kind) { + continue + } + c.resources = append(c.resources, kind) + } +} + +// getNamespaces collects scannable namespaces +func (c *client) getNamespaces() []string { + if len(c.includeNamespaces) > 0 { + return c.includeNamespaces + } + if len(c.excludeNamespaces) == 0 { + return nil + } + // ToDo: get all namespaces and skip excluded namespaces + return []string{} +} + // ListArtifacts returns kubernetes scannable artifacs. func (c *client) ListArtifacts(ctx context.Context) ([]*artifacts.Artifact, error) { + c.initResourceList() + namespaces := c.getNamespaces() + if len(namespaces) == 0 { + return c.ListSpecificArtifacts(ctx) + } + artifactList := make([]*artifacts.Artifact, 0) + + for _, namespace := range namespaces { + c.namespace = namespace + arts, err := c.ListSpecificArtifacts(ctx) + if err != nil { + return nil, err + } + artifactList = append(artifactList, arts...) + } + return artifactList, nil +} + +// ListSpecificArtifacts returns kubernetes scannable artifacs for a specific namespace or a cluster +func (c *client) ListSpecificArtifacts(ctx context.Context) ([]*artifacts.Artifact, error) { artifactList := make([]*artifacts.Artifact, 0) namespaced := isNamespaced(c.namespace, c.allNamespaces) @@ -195,15 +253,6 @@ func (c *client) ListArtifacts(ctx context.Context) ([]*artifacts.Artifact, erro if c.excludeOwned && c.hasOwner(resource) { continue } - // filter resources by kind - if FilterResources(c.includeKinds, c.excludeKinds, resource.GetKind()) { - continue - } - - // filter resources by namespace - if FilterResources(c.includeNamespaces, c.excludeNamespaces, resource.GetNamespace()) { - continue - } lastAppliedResource := resource if jsonManifest, ok := resource.GetAnnotations()["kubectl.kubernetes.io/last-applied-configuration"]; ok { // required for outdated-api when k8s convert resources @@ -470,7 +519,7 @@ func rawResource(resource interface{}) (map[string]interface{}, error) { func (c *client) getDynamicClient(gvr schema.GroupVersionResource) dynamic.ResourceInterface { dclient := c.cluster.GetDynamicClient() - // don't use namespace if it is a cluster levle resource, + // don't use namespace if it is a cluster level resource, // or namespace is empty if k8s.IsClusterResource(gvr) || len(c.namespace) == 0 { return dclient.Resource(gvr) diff --git a/pkg/trivyk8s/trivyk8s_test.go b/pkg/trivyk8s/trivyk8s_test.go index f57ede5..eb2fb43 100644 --- a/pkg/trivyk8s/trivyk8s_test.go +++ b/pkg/trivyk8s/trivyk8s_test.go @@ -3,8 +3,10 @@ package trivyk8s import ( "testing" - "github.com/aquasecurity/trivy-kubernetes/pkg/artifacts" "github.com/stretchr/testify/assert" + + "github.com/aquasecurity/trivy-kubernetes/pkg/artifacts" + "github.com/aquasecurity/trivy-kubernetes/pkg/k8s" ) func TestIgnoreNodeByLabel(t *testing.T) { @@ -87,3 +89,50 @@ func TestFilterResource(t *testing.T) { }) } } + +func TestInitResources(t *testing.T) { + tests := []struct { + name string + includeKinds []string + excludeKinds []string + want []string + }{ + { + "scan only pods", + []string{"pods"}, + nil, + []string{k8s.Pods}, + }, + { + "skip ClusterRoles, Deployments and Ingresses", + nil, + []string{"deployments", "ingresses", "clusterroles"}, + []string{ + k8s.ClusterRoleBindings, + k8s.Nodes, + k8s.Pods, + k8s.ReplicaSets, + k8s.ReplicationControllers, + k8s.StatefulSets, + k8s.DaemonSets, + k8s.CronJobs, + k8s.Jobs, + k8s.Services, + k8s.ServiceAccounts, + k8s.ConfigMaps, + k8s.Roles, + k8s.RoleBindings, + k8s.NetworkPolicies, + k8s.ResourceQuotas, + k8s.LimitRanges, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &client{excludeKinds: tt.excludeKinds, includeKinds: tt.includeKinds} + c.initResourceList() + assert.Equal(t, tt.want, c.resources) + }) + } +}