diff --git a/internal/flavors/assetinventory/strategy_aws.go b/internal/flavors/assetinventory/strategy_aws.go index a06add2825..e9fa56c160 100644 --- a/internal/flavors/assetinventory/strategy_aws.go +++ b/internal/flavors/assetinventory/strategy_aws.go @@ -54,7 +54,7 @@ func (s *strategy) initAwsFetchers(ctx context.Context) ([]inventory.AssetFetche // Early exit if we're scanning the entire account. if s.cfg.CloudConfig.Aws.AccountType == config.SingleAccount { - return awsfetcher.New(s.logger, awsIdentity, *awsConfig), nil + return awsfetcher.New(ctx, s.logger, awsIdentity, *awsConfig), nil } // Assume audit roles per selected account and generate fetchers for them @@ -81,7 +81,7 @@ func (s *strategy) initAwsFetchers(ctx context.Context) ([]inventory.AssetFetche s.logger.Infof("Skipping identity on purpose %+v", identity) continue } - accountFetchers := awsfetcher.New(s.logger, &identity, assumedRoleConfig) + accountFetchers := awsfetcher.New(ctx, s.logger, &identity, assumedRoleConfig) fetchers = append(fetchers, accountFetchers...) } diff --git a/internal/flavors/benchmark/aws.go b/internal/flavors/benchmark/aws.go index c6e0aa54ab..d815b6b3e9 100644 --- a/internal/flavors/benchmark/aws.go +++ b/internal/flavors/benchmark/aws.go @@ -79,7 +79,7 @@ func (a *AWS) initialize(ctx context.Context, log *logp.Logger, cfg *config.Conf return registry.NewRegistry( log, - registry.WithFetchersMap(preset.NewCisAwsFetchers(log, *awsConfig, ch, awsIdentity)), + registry.WithFetchersMap(preset.NewCisAwsFetchers(ctx, log, *awsConfig, ch, awsIdentity)), ), cloud.NewDataProvider(cloud.WithAccount(*awsIdentity)), nil, nil } diff --git a/internal/flavors/benchmark/aws_org.go b/internal/flavors/benchmark/aws_org.go index 61062227c4..06d4f14fcb 100644 --- a/internal/flavors/benchmark/aws_org.go +++ b/internal/flavors/benchmark/aws_org.go @@ -88,7 +88,7 @@ func (a *AWSOrg) initialize(ctx context.Context, log *logp.Logger, cfg *config.C } log.Info("successfully retrieved AWS Identity") - a.IAMProvider = iam.NewIAMProvider(log, *awsConfig, nil) + a.IAMProvider = iam.NewIAMProvider(ctx, log, *awsConfig, nil) cache := make(map[string]registry.FetchersMap) reg := registry.NewRegistry(log, registry.WithUpdater( diff --git a/internal/flavors/benchmark/eks.go b/internal/flavors/benchmark/eks.go index 80f263244d..8e1838d387 100644 --- a/internal/flavors/benchmark/eks.go +++ b/internal/flavors/benchmark/eks.go @@ -97,7 +97,7 @@ func (k *EKS) initialize(ctx context.Context, log *logp.Logger, cfg *config.Conf return registry.NewRegistry( log, - registry.WithFetchersMap(preset.NewCisEksFetchers(log, awsConfig, ch, k.leaderElector, kubeClient, awsIdentity)), + registry.WithFetchersMap(preset.NewCisEksFetchers(ctx, log, awsConfig, ch, k.leaderElector, kubeClient, awsIdentity)), ), dp, idp, nil } diff --git a/internal/flavors/vulnerability.go b/internal/flavors/vulnerability.go index 416ffc1151..cdf19d606b 100644 --- a/internal/flavors/vulnerability.go +++ b/internal/flavors/vulnerability.go @@ -127,7 +127,7 @@ func (bt *vulnerability) Run(*beat.Beat) error { } func (bt *vulnerability) runIteration() error { - worker, err := vuln.NewVulnerabilityWorker(bt.log, bt.config, bt.bdp, bt.cdp) + worker, err := vuln.NewVulnerabilityWorker(bt.ctx, bt.log, bt.config, bt.bdp, bt.cdp) if err != nil { bt.log.Warn("vulnerability.runIteration worker creation failed") bt.cancel() diff --git a/internal/inventory/awsfetcher/awsfetchers.go b/internal/inventory/awsfetcher/awsfetchers.go index 9405af50d8..b5ce8336bc 100644 --- a/internal/inventory/awsfetcher/awsfetchers.go +++ b/internal/inventory/awsfetcher/awsfetchers.go @@ -18,6 +18,8 @@ package awsfetcher import ( + "context" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/elastic/elastic-agent-libs/logp" @@ -34,15 +36,15 @@ import ( "github.com/elastic/cloudbeat/internal/resources/providers/awslib/sns" ) -func New(logger *logp.Logger, identity *cloud.Identity, cfg aws.Config) []inventory.AssetFetcher { - ec2Provider := ec2.NewEC2Provider(logger, identity.Account, cfg, &awslib.MultiRegionClientFactory[ec2.Client]{}) - elbProvider := elb.NewElbProvider(logger, identity.Account, cfg, &awslib.MultiRegionClientFactory[elb.Client]{}) - elbv2Provider := elbv2.NewElbV2Provider(logger, cfg, &awslib.MultiRegionClientFactory[elbv2.Client]{}) - iamProvider := iam.NewIAMProvider(logger, cfg, &awslib.MultiRegionClientFactory[iam.AccessAnalyzerClient]{}) - lambdaProvider := lambda.NewLambdaProvider(logger, cfg, &awslib.MultiRegionClientFactory[lambda.Client]{}) - rdsProvider := rds.NewProvider(logger, cfg, &awslib.MultiRegionClientFactory[rds.Client]{}, ec2Provider) - s3Provider := s3.NewProvider(logger, cfg, &awslib.MultiRegionClientFactory[s3.Client]{}, identity.Account) - snsProvider := sns.NewSNSProvider(logger, cfg, &awslib.MultiRegionClientFactory[sns.Client]{}) +func New(ctx context.Context, logger *logp.Logger, identity *cloud.Identity, cfg aws.Config) []inventory.AssetFetcher { + ec2Provider := ec2.NewEC2Provider(ctx, logger, identity.Account, cfg, &awslib.MultiRegionClientFactory[ec2.Client]{}) + elbProvider := elb.NewElbProvider(ctx, logger, identity.Account, cfg, &awslib.MultiRegionClientFactory[elb.Client]{}) + elbv2Provider := elbv2.NewElbV2Provider(ctx, logger, cfg, &awslib.MultiRegionClientFactory[elbv2.Client]{}) + iamProvider := iam.NewIAMProvider(ctx, logger, cfg, &awslib.MultiRegionClientFactory[iam.AccessAnalyzerClient]{}) + lambdaProvider := lambda.NewLambdaProvider(ctx, logger, cfg, &awslib.MultiRegionClientFactory[lambda.Client]{}) + rdsProvider := rds.NewProvider(ctx, logger, cfg, &awslib.MultiRegionClientFactory[rds.Client]{}, ec2Provider) + s3Provider := s3.NewProvider(ctx, logger, cfg, &awslib.MultiRegionClientFactory[s3.Client]{}, identity.Account) + snsProvider := sns.NewSNSProvider(ctx, logger, cfg, &awslib.MultiRegionClientFactory[sns.Client]{}) return []inventory.AssetFetcher{ newEc2InstancesFetcher(logger, identity, ec2Provider), diff --git a/internal/resources/fetching/preset/aws_org_preset.go b/internal/resources/fetching/preset/aws_org_preset.go index 7e63924c2f..469f79dd54 100644 --- a/internal/resources/fetching/preset/aws_org_preset.go +++ b/internal/resources/fetching/preset/aws_org_preset.go @@ -65,7 +65,7 @@ func NewCisAwsOrganizationFetchers(ctx context.Context, log *logp.Logger, rootCh } // awsFactory is the same function type as NewCisAwsFetchers, and it's used to mock the function in tests -type awsFactory func(*logp.Logger, aws.Config, chan fetching.ResourceInfo, *cloud.Identity) registry.FetchersMap +type awsFactory func(context.Context, *logp.Logger, aws.Config, chan fetching.ResourceInfo, *cloud.Identity) registry.FetchersMap func newCisAwsOrganizationFetchers( ctx context.Context, @@ -117,6 +117,7 @@ func newCisAwsOrganizationFetchers( }(account.Identity) f := factory( + ctx, log.Named("aws").WithOptions(zap.Fields(zap.String("cloud.account.id", account.Identity.Account))), account.Config, ch, diff --git a/internal/resources/fetching/preset/aws_org_preset_test.go b/internal/resources/fetching/preset/aws_org_preset_test.go index 81b1cbcbc7..d2484cfee4 100644 --- a/internal/resources/fetching/preset/aws_org_preset_test.go +++ b/internal/resources/fetching/preset/aws_org_preset_test.go @@ -70,7 +70,7 @@ func subtest(t *testing.T, drain bool) { //revive:disable-line:flag-parameter ctx, cancel := context.WithCancel(context.Background()) factory := mockFactory(nAccounts, - func(_ *logp.Logger, _ aws.Config, ch chan fetching.ResourceInfo, _ *cloud.Identity) registry.FetchersMap { + func(_ context.Context, _ *logp.Logger, _ aws.Config, ch chan fetching.ResourceInfo, _ *cloud.Identity) registry.FetchersMap { if drain { // create some resources if we are testing for that go func() { @@ -152,7 +152,7 @@ func TestNewCisAwsOrganizationFetchers_LeakContextDone(t *testing.T) { }}, nil, mockFactory(1, - func(_ *logp.Logger, _ aws.Config, ch chan fetching.ResourceInfo, _ *cloud.Identity) registry.FetchersMap { + func(_ context.Context, _ *logp.Logger, _ aws.Config, ch chan fetching.ResourceInfo, _ *cloud.Identity) registry.FetchersMap { ch <- fetching.ResourceInfo{ Resource: mockResource(), CycleMetadata: cycle.Metadata{Sequence: 1}, @@ -181,7 +181,7 @@ func TestNewCisAwsOrganizationFetchers_CloseChannel(t *testing.T) { }}, nil, mockFactory(1, - func(_ *logp.Logger, _ aws.Config, ch chan fetching.ResourceInfo, _ *cloud.Identity) registry.FetchersMap { + func(_ context.Context, _ *logp.Logger, _ aws.Config, ch chan fetching.ResourceInfo, _ *cloud.Identity) registry.FetchersMap { defer close(ch) return registry.FetchersMap{"fetcher": registry.RegisteredFetcher{}} }, @@ -214,7 +214,7 @@ func TestNewCisAwsOrganizationFetchers_Cache(t *testing.T) { }, cache, mockFactory(1, - func(_ *logp.Logger, _ aws.Config, _ chan fetching.ResourceInfo, identity *cloud.Identity) registry.FetchersMap { + func(_ context.Context, _ *logp.Logger, _ aws.Config, _ chan fetching.ResourceInfo, identity *cloud.Identity) registry.FetchersMap { assert.Equal(t, "2", identity.Account) return registry.FetchersMap{"fetcher": registry.RegisteredFetcher{}} }, @@ -241,6 +241,6 @@ func mockResource() *fetching.MockResource { func mockFactory(times int, f awsFactory) awsFactory { factory := mockAwsFactory{} - factory.EXPECT().Execute(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f).Times(times) + factory.EXPECT().Execute(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f).Times(times) return factory.Execute } diff --git a/internal/resources/fetching/preset/aws_preset.go b/internal/resources/fetching/preset/aws_preset.go index b0fdb1ba5f..978510000c 100644 --- a/internal/resources/fetching/preset/aws_preset.go +++ b/internal/resources/fetching/preset/aws_preset.go @@ -18,6 +18,8 @@ package preset import ( + "context" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/elastic/elastic-agent-libs/logp" @@ -41,24 +43,25 @@ import ( "github.com/elastic/cloudbeat/internal/resources/providers/awslib/sns" ) -func NewCisAwsFetchers(log *logp.Logger, cfg aws.Config, ch chan fetching.ResourceInfo, identity *cloud.Identity) registry.FetchersMap { +func NewCisAwsFetchers(ctx context.Context, log *logp.Logger, cfg aws.Config, ch chan fetching.ResourceInfo, identity *cloud.Identity) registry.FetchersMap { log.Infof("Initializing AWS fetchers for account: '%s'", identity.Account) m := make(registry.FetchersMap) - iamProvider := iam.NewIAMProvider(log, cfg, &awslib.MultiRegionClientFactory[iam.AccessAnalyzerClient]{}) + iamProvider := iam.NewIAMProvider(ctx, log, cfg, &awslib.MultiRegionClientFactory[iam.AccessAnalyzerClient]{}) iamFetcher := fetchers.NewIAMFetcher(log, iamProvider, ch, identity) m[fetching.IAMType] = registry.RegisteredFetcher{Fetcher: iamFetcher} - kmsProvider := kms.NewKMSProvider(log, cfg, &awslib.MultiRegionClientFactory[kms.Client]{}) + kmsProvider := kms.NewKMSProvider(ctx, log, cfg, &awslib.MultiRegionClientFactory[kms.Client]{}) kmsFetcher := fetchers.NewKMSFetcher(log, kmsProvider, ch) m[fetching.KmsType] = registry.RegisteredFetcher{Fetcher: kmsFetcher} - loggingProvider := logging.NewProvider(log, cfg, &awslib.MultiRegionClientFactory[cloudtrail.Client]{}, &awslib.MultiRegionClientFactory[s3.Client]{}, identity.Account) - configserviceProvider := configservice.NewProvider(log, cfg, &awslib.MultiRegionClientFactory[configservice.Client]{}, identity.Account) + loggingProvider := logging.NewProvider(ctx, log, cfg, &awslib.MultiRegionClientFactory[cloudtrail.Client]{}, &awslib.MultiRegionClientFactory[s3.Client]{}, identity.Account) + configserviceProvider := configservice.NewProvider(ctx, log, cfg, &awslib.MultiRegionClientFactory[configservice.Client]{}, identity.Account) loggingFetcher := fetchers.NewLoggingFetcher(log, loggingProvider, configserviceProvider, ch, identity) m[fetching.TrailType] = registry.RegisteredFetcher{Fetcher: loggingFetcher} monitoringProvider := monitoring.NewProvider( + ctx, log, cfg, &awslib.MultiRegionClientFactory[cloudtrail.Client]{}, @@ -67,19 +70,19 @@ func NewCisAwsFetchers(log *logp.Logger, cfg aws.Config, ch chan fetching.Resour &awslib.MultiRegionClientFactory[sns.Client]{}, ) - securityHubProvider := securityhub.NewProvider(log, cfg, &awslib.MultiRegionClientFactory[securityhub.Client]{}, identity.Account) + securityHubProvider := securityhub.NewProvider(ctx, log, cfg, &awslib.MultiRegionClientFactory[securityhub.Client]{}, identity.Account) monitoringFetcher := fetchers.NewMonitoringFetcher(log, monitoringProvider, securityHubProvider, ch, identity) m[fetching.AwsMonitoringType] = registry.RegisteredFetcher{Fetcher: monitoringFetcher} - ec2Provider := ec2.NewEC2Provider(log, identity.Account, cfg, &awslib.MultiRegionClientFactory[ec2.Client]{}) + ec2Provider := ec2.NewEC2Provider(ctx, log, identity.Account, cfg, &awslib.MultiRegionClientFactory[ec2.Client]{}) networkFetcher := fetchers.NewNetworkFetcher(log, ec2Provider, ch) m[fetching.EC2NetworkingType] = registry.RegisteredFetcher{Fetcher: networkFetcher} - rdsProvider := rds.NewProvider(log, cfg, &awslib.MultiRegionClientFactory[rds.Client]{}, ec2Provider) + rdsProvider := rds.NewProvider(ctx, log, cfg, &awslib.MultiRegionClientFactory[rds.Client]{}, ec2Provider) rdsFetcher := fetchers.NewRdsFetcher(log, rdsProvider, ch) m[fetching.RdsType] = registry.RegisteredFetcher{Fetcher: rdsFetcher} - s3Provider := s3.NewProvider(log, cfg, &awslib.MultiRegionClientFactory[s3.Client]{}, identity.Account) + s3Provider := s3.NewProvider(ctx, log, cfg, &awslib.MultiRegionClientFactory[s3.Client]{}, identity.Account) s3Fetcher := fetchers.NewS3Fetcher(log, s3Provider, ch) m[fetching.S3Type] = registry.RegisteredFetcher{Fetcher: s3Fetcher} diff --git a/internal/resources/fetching/preset/eks_preset.go b/internal/resources/fetching/preset/eks_preset.go index 80325cb16e..fb7e06ff8a 100644 --- a/internal/resources/fetching/preset/eks_preset.go +++ b/internal/resources/fetching/preset/eks_preset.go @@ -18,6 +18,7 @@ package preset import ( + "context" "fmt" "regexp" @@ -45,16 +46,17 @@ var ( eksRequiredProcesses = k8sfetchers.ProcessesConfigMap{"kubelet": {ConfigFileArguments: []string{"config"}}} eksFsPatterns = []string{ "/hostfs/etc/kubernetes/kubelet/kubelet-config.json", - "/hostfs/var/lib/kubelet/kubeconfig"} + "/hostfs/var/lib/kubelet/kubeconfig", + } ) -func NewCisEksFetchers(log *logp.Logger, awsConfig aws.Config, ch chan fetching.ResourceInfo, le uniqueness.Manager, k8sClient k8s.Interface, identity *cloud.Identity) registry.FetchersMap { +func NewCisEksFetchers(ctx context.Context, log *logp.Logger, awsConfig aws.Config, ch chan fetching.ResourceInfo, le uniqueness.Manager, k8sClient k8s.Interface, identity *cloud.Identity) registry.FetchersMap { log.Infof("Initializing EKS fetchers") m := make(registry.FetchersMap) if identity != nil { log.Info("Initialize aws-related fetchers") - ecrPrivateProvider := ecr.NewEcrProvider(log, awsConfig, &awslib.MultiRegionClientFactory[ecr.Client]{}) + ecrPrivateProvider := ecr.NewEcrProvider(ctx, log, awsConfig, &awslib.MultiRegionClientFactory[ecr.Client]{}) privateRepoRegex := fmt.Sprintf(awsfetchers.PrivateRepoRegexTemplate, identity.Account) ecrPodDescriber := awsfetchers.PodDescriber{ @@ -65,7 +67,7 @@ func NewCisEksFetchers(log *logp.Logger, awsConfig aws.Config, ch chan fetching. ecrFetcher := awsfetchers.NewEcrFetcher(log, ch, k8sClient, ecrPodDescriber) m[fetching.EcrType] = registry.RegisteredFetcher{Fetcher: ecrFetcher, Condition: []fetching.Condition{condition.NewIsLeader(le)}} - elbProvider := elb.NewElbProvider(log, identity.Account, awsConfig, &awslib.MultiRegionClientFactory[elb.Client]{}) + elbProvider := elb.NewElbProvider(ctx, log, identity.Account, awsConfig, &awslib.MultiRegionClientFactory[elb.Client]{}) loadBalancerRegex := fmt.Sprintf(elbRegexTemplate, awsConfig.Region) elbFetcher := awsfetchers.NewElbFetcher(log, ch, k8sClient, elbProvider, identity, loadBalancerRegex) m[fetching.ElbType] = registry.RegisteredFetcher{Fetcher: elbFetcher, Condition: []fetching.Condition{condition.NewIsLeader(le)}} diff --git a/internal/resources/fetching/preset/mock_aws_factory.go b/internal/resources/fetching/preset/mock_aws_factory.go index 204934bcbb..663883b9ae 100644 --- a/internal/resources/fetching/preset/mock_aws_factory.go +++ b/internal/resources/fetching/preset/mock_aws_factory.go @@ -20,8 +20,11 @@ package preset import ( + context "context" + aws "github.com/aws/aws-sdk-go-v2/aws" cloud "github.com/elastic/cloudbeat/internal/dataprovider/providers/cloud" + fetching "github.com/elastic/cloudbeat/internal/resources/fetching" logp "github.com/elastic/elastic-agent-libs/logp" @@ -44,13 +47,13 @@ func (_m *mockAwsFactory) EXPECT() *mockAwsFactory_Expecter { return &mockAwsFactory_Expecter{mock: &_m.Mock} } -// Execute provides a mock function with given fields: _a0, _a1, _a2, _a3 -func (_m *mockAwsFactory) Execute(_a0 *logp.Logger, _a1 aws.Config, _a2 chan fetching.ResourceInfo, _a3 *cloud.Identity) registry.FetchersMap { - ret := _m.Called(_a0, _a1, _a2, _a3) +// Execute provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4 +func (_m *mockAwsFactory) Execute(_a0 context.Context, _a1 *logp.Logger, _a2 aws.Config, _a3 chan fetching.ResourceInfo, _a4 *cloud.Identity) registry.FetchersMap { + ret := _m.Called(_a0, _a1, _a2, _a3, _a4) var r0 registry.FetchersMap - if rf, ok := ret.Get(0).(func(*logp.Logger, aws.Config, chan fetching.ResourceInfo, *cloud.Identity) registry.FetchersMap); ok { - r0 = rf(_a0, _a1, _a2, _a3) + if rf, ok := ret.Get(0).(func(context.Context, *logp.Logger, aws.Config, chan fetching.ResourceInfo, *cloud.Identity) registry.FetchersMap); ok { + r0 = rf(_a0, _a1, _a2, _a3, _a4) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(registry.FetchersMap) @@ -66,17 +69,18 @@ type mockAwsFactory_Execute_Call struct { } // Execute is a helper method to define mock.On call -// - _a0 *logp.Logger -// - _a1 aws.Config -// - _a2 chan fetching.ResourceInfo -// - _a3 *cloud.Identity -func (_e *mockAwsFactory_Expecter) Execute(_a0 interface{}, _a1 interface{}, _a2 interface{}, _a3 interface{}) *mockAwsFactory_Execute_Call { - return &mockAwsFactory_Execute_Call{Call: _e.mock.On("Execute", _a0, _a1, _a2, _a3)} +// - _a0 context.Context +// - _a1 *logp.Logger +// - _a2 aws.Config +// - _a3 chan fetching.ResourceInfo +// - _a4 *cloud.Identity +func (_e *mockAwsFactory_Expecter) Execute(_a0 interface{}, _a1 interface{}, _a2 interface{}, _a3 interface{}, _a4 interface{}) *mockAwsFactory_Execute_Call { + return &mockAwsFactory_Execute_Call{Call: _e.mock.On("Execute", _a0, _a1, _a2, _a3, _a4)} } -func (_c *mockAwsFactory_Execute_Call) Run(run func(_a0 *logp.Logger, _a1 aws.Config, _a2 chan fetching.ResourceInfo, _a3 *cloud.Identity)) *mockAwsFactory_Execute_Call { +func (_c *mockAwsFactory_Execute_Call) Run(run func(_a0 context.Context, _a1 *logp.Logger, _a2 aws.Config, _a3 chan fetching.ResourceInfo, _a4 *cloud.Identity)) *mockAwsFactory_Execute_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*logp.Logger), args[1].(aws.Config), args[2].(chan fetching.ResourceInfo), args[3].(*cloud.Identity)) + run(args[0].(context.Context), args[1].(*logp.Logger), args[2].(aws.Config), args[3].(chan fetching.ResourceInfo), args[4].(*cloud.Identity)) }) return _c } @@ -86,7 +90,7 @@ func (_c *mockAwsFactory_Execute_Call) Return(_a0 registry.FetchersMap) *mockAws return _c } -func (_c *mockAwsFactory_Execute_Call) RunAndReturn(run func(*logp.Logger, aws.Config, chan fetching.ResourceInfo, *cloud.Identity) registry.FetchersMap) *mockAwsFactory_Execute_Call { +func (_c *mockAwsFactory_Execute_Call) RunAndReturn(run func(context.Context, *logp.Logger, aws.Config, chan fetching.ResourceInfo, *cloud.Identity) registry.FetchersMap) *mockAwsFactory_Execute_Call { _c.Call.Return(run) return _c } diff --git a/internal/resources/providers/aws_cis/logging/logging.go b/internal/resources/providers/aws_cis/logging/logging.go index 05661d893f..06108fb939 100644 --- a/internal/resources/providers/aws_cis/logging/logging.go +++ b/internal/resources/providers/aws_cis/logging/logging.go @@ -39,6 +39,7 @@ type Provider struct { } func NewProvider( + ctx context.Context, log *logp.Logger, cfg aws.Config, multiRegionTrailFactory awslib.CrossRegionFactory[cloudtrail.Client], @@ -47,7 +48,7 @@ func NewProvider( ) *Provider { return &Provider{ log: log, - s3Provider: s3.NewProvider(log, cfg, multiRegionS3Factory, accountId), - trailProvider: cloudtrail.NewProvider(log, cfg, multiRegionTrailFactory), + s3Provider: s3.NewProvider(ctx, log, cfg, multiRegionS3Factory, accountId), + trailProvider: cloudtrail.NewProvider(ctx, log, cfg, multiRegionTrailFactory), } } diff --git a/internal/resources/providers/aws_cis/monitoring/monitoring.go b/internal/resources/providers/aws_cis/monitoring/monitoring.go index 2d6524b62d..edd6043ad9 100644 --- a/internal/resources/providers/aws_cis/monitoring/monitoring.go +++ b/internal/resources/providers/aws_cis/monitoring/monitoring.go @@ -64,12 +64,12 @@ type ( } ) -func NewProvider(log *logp.Logger, awsConfig aws.Config, trailCrossRegionFactory awslib.CrossRegionFactory[cloudtrail.Client], cloudwatchCrossResignFactory awslib.CrossRegionFactory[cloudwatch.Client], cloudwatchlogsCrossRegionFactory awslib.CrossRegionFactory[logs.Client], snsCrossRegionFactory awslib.CrossRegionFactory[sns.Client]) *Provider { +func NewProvider(ctx context.Context, log *logp.Logger, awsConfig aws.Config, trailCrossRegionFactory awslib.CrossRegionFactory[cloudtrail.Client], cloudwatchCrossResignFactory awslib.CrossRegionFactory[cloudwatch.Client], cloudwatchlogsCrossRegionFactory awslib.CrossRegionFactory[logs.Client], snsCrossRegionFactory awslib.CrossRegionFactory[sns.Client]) *Provider { return &Provider{ - Cloudtrail: cloudtrail.NewProvider(log, awsConfig, trailCrossRegionFactory), - Cloudwatch: cloudwatch.NewProvider(log, awsConfig, cloudwatchCrossResignFactory), - Cloudwatchlogs: logs.NewCloudwatchLogsProvider(log, awsConfig, cloudwatchlogsCrossRegionFactory), - Sns: sns.NewSNSProvider(log, awsConfig, snsCrossRegionFactory), + Cloudtrail: cloudtrail.NewProvider(ctx, log, awsConfig, trailCrossRegionFactory), + Cloudwatch: cloudwatch.NewProvider(ctx, log, awsConfig, cloudwatchCrossResignFactory), + Cloudwatchlogs: logs.NewCloudwatchLogsProvider(ctx, log, awsConfig, cloudwatchlogsCrossRegionFactory), + Sns: sns.NewSNSProvider(ctx, log, awsConfig, snsCrossRegionFactory), Log: log, } } diff --git a/internal/resources/providers/awslib/cloudtrail/cloudtrail.go b/internal/resources/providers/awslib/cloudtrail/cloudtrail.go index 84d80b7740..ca6a180feb 100644 --- a/internal/resources/providers/awslib/cloudtrail/cloudtrail.go +++ b/internal/resources/providers/awslib/cloudtrail/cloudtrail.go @@ -31,12 +31,12 @@ type TrailService interface { DescribeTrails(ctx context.Context) ([]TrailInfo, error) } -func NewProvider(log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { +func NewProvider(ctx context.Context, log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { f := func(cfg aws.Config) Client { return trailClient.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) return &Provider{ log: log, clients: m.GetMultiRegionsClientMap(), diff --git a/internal/resources/providers/awslib/cloudwatch/cloudwatch.go b/internal/resources/providers/awslib/cloudwatch/cloudwatch.go index 76347c8e6a..98ce1900f9 100644 --- a/internal/resources/providers/awslib/cloudwatch/cloudwatch.go +++ b/internal/resources/providers/awslib/cloudwatch/cloudwatch.go @@ -32,11 +32,11 @@ type Cloudwatch interface { DescribeAlarms(ctx context.Context, region *string, filters []string) ([]types.MetricAlarm, error) } -func NewProvider(log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { +func NewProvider(ctx context.Context, log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { f := func(cfg aws.Config) Client { return cloudwatch.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) return &Provider{ clients: m.GetMultiRegionsClientMap(), } diff --git a/internal/resources/providers/awslib/cloudwatch/logs/logs.go b/internal/resources/providers/awslib/cloudwatch/logs/logs.go index 0e1bca0090..53a55bd234 100644 --- a/internal/resources/providers/awslib/cloudwatch/logs/logs.go +++ b/internal/resources/providers/awslib/cloudwatch/logs/logs.go @@ -32,11 +32,11 @@ type CloudwatchLogs interface { DescribeMetricFilters(ctx context.Context, region *string, logGroup string) ([]types.MetricFilter, error) } -func NewCloudwatchLogsProvider(log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { +func NewCloudwatchLogsProvider(ctx context.Context, log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { f := func(cfg aws.Config) Client { return cloudwatchlogs.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) return &Provider{ clients: m.GetMultiRegionsClientMap(), } diff --git a/internal/resources/providers/awslib/configservice/configservice.go b/internal/resources/providers/awslib/configservice/configservice.go index 6ecad10241..e40c7d8847 100644 --- a/internal/resources/providers/awslib/configservice/configservice.go +++ b/internal/resources/providers/awslib/configservice/configservice.go @@ -57,12 +57,12 @@ type Recorder struct { Status []types.ConfigurationRecorderStatus `json:"statuses"` } -func NewProvider(log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client], accountId string) *Provider { +func NewProvider(ctx context.Context, log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client], accountId string) *Provider { f := func(cfg aws.Config) Client { return configSDK.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) return &Provider{ log: log, clients: m.GetMultiRegionsClientMap(), diff --git a/internal/resources/providers/awslib/ec2/ec2.go b/internal/resources/providers/awslib/ec2/ec2.go index 7a68cbd981..54267dd8b7 100644 --- a/internal/resources/providers/awslib/ec2/ec2.go +++ b/internal/resources/providers/awslib/ec2/ec2.go @@ -36,11 +36,11 @@ type ElasticCompute interface { GetRouteTableForSubnet(ctx context.Context, region string, subnetId string, vpcId string) (types.RouteTable, error) } -func NewEC2Provider(log *logp.Logger, awsAccountID string, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { +func NewEC2Provider(ctx context.Context, log *logp.Logger, awsAccountID string, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { f := func(cfg aws.Config) Client { return ec2.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) return &Provider{ log: log, clients: m.GetMultiRegionsClientMap(), @@ -48,11 +48,11 @@ func NewEC2Provider(log *logp.Logger, awsAccountID string, cfg aws.Config, facto } } -func NewCurrentRegionEC2Provider(log *logp.Logger, awsAccountID string, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { +func NewCurrentRegionEC2Provider(ctx context.Context, log *logp.Logger, awsAccountID string, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { f := func(cfg aws.Config) Client { return ec2.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.CurrentRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.CurrentRegionSelector(), cfg, f, log) return &Provider{ log: log, clients: m.GetMultiRegionsClientMap(), diff --git a/internal/resources/providers/awslib/ecr/ecr.go b/internal/resources/providers/awslib/ecr/ecr.go index 7707bac074..39c72aeee1 100644 --- a/internal/resources/providers/awslib/ecr/ecr.go +++ b/internal/resources/providers/awslib/ecr/ecr.go @@ -41,11 +41,11 @@ type Client interface { DescribeRepositories(ctx context.Context, params *ecrClient.DescribeRepositoriesInput, optFns ...func(*ecrClient.Options)) (*ecrClient.DescribeRepositoriesOutput, error) } -func NewEcrProvider(log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { +func NewEcrProvider(ctx context.Context, log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { f := func(cfg aws.Config) Client { return ecrClient.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) return &Provider{clients: m.GetMultiRegionsClientMap()} } diff --git a/internal/resources/providers/awslib/elb/elb.go b/internal/resources/providers/awslib/elb/elb.go index ba140bb34a..06d96b1fa5 100644 --- a/internal/resources/providers/awslib/elb/elb.go +++ b/internal/resources/providers/awslib/elb/elb.go @@ -44,11 +44,11 @@ type Provider struct { awsAccountID string } -func NewElbProvider(log *logp.Logger, awsAccountID string, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { +func NewElbProvider(ctx context.Context, log *logp.Logger, awsAccountID string, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { f := func(cfg aws.Config) Client { return elb.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) return &Provider{ log: log, client: elb.NewFromConfig(cfg), diff --git a/internal/resources/providers/awslib/elb_v2/elb_v2.go b/internal/resources/providers/awslib/elb_v2/elb_v2.go index c6fe0da423..3ccd1abf68 100644 --- a/internal/resources/providers/awslib/elb_v2/elb_v2.go +++ b/internal/resources/providers/awslib/elb_v2/elb_v2.go @@ -41,11 +41,11 @@ type Provider struct { clients map[string]Client } -func NewElbV2Provider(log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { +func NewElbV2Provider(ctx context.Context, log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { f := func(cfg aws.Config) Client { return elb.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) return &Provider{ log: log, clients: m.GetMultiRegionsClientMap(), diff --git a/internal/resources/providers/awslib/iam/iam.go b/internal/resources/providers/awslib/iam/iam.go index 8c3439f078..c6f3706047 100644 --- a/internal/resources/providers/awslib/iam/iam.go +++ b/internal/resources/providers/awslib/iam/iam.go @@ -153,13 +153,13 @@ type PolicyDocument struct { Policy string `json:"policy,omitempty"` } -func NewIAMProvider(log *logp.Logger, cfg aws.Config, crossRegionFactory awslib.CrossRegionFactory[AccessAnalyzerClient]) *Provider { +func NewIAMProvider(ctx context.Context, log *logp.Logger, cfg aws.Config, crossRegionFactory awslib.CrossRegionFactory[AccessAnalyzerClient]) *Provider { provider := Provider{ log: log, client: iamsdk.NewFromConfig(cfg), } if crossRegionFactory != nil { - m := crossRegionFactory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, func(cfg aws.Config) AccessAnalyzerClient { + m := crossRegionFactory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, func(cfg aws.Config) AccessAnalyzerClient { return accessanalyzer.NewFromConfig(cfg) }, log) provider.accessAnalyzerClients = m.GetMultiRegionsClientMap() diff --git a/internal/resources/providers/awslib/kms/kms.go b/internal/resources/providers/awslib/kms/kms.go index b77f69ac81..ee5436d622 100644 --- a/internal/resources/providers/awslib/kms/kms.go +++ b/internal/resources/providers/awslib/kms/kms.go @@ -39,11 +39,11 @@ type KMS interface { DescribeSymmetricKeys(ctx context.Context) ([]awslib.AwsResource, error) } -func NewKMSProvider(log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { +func NewKMSProvider(ctx context.Context, log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { f := func(cfg aws.Config) Client { return kms.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.CurrentRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.CurrentRegionSelector(), cfg, f, log) return &Provider{ log: log, diff --git a/internal/resources/providers/awslib/lambda/lambda.go b/internal/resources/providers/awslib/lambda/lambda.go index 3dd4261e9f..56ee7ad7ee 100644 --- a/internal/resources/providers/awslib/lambda/lambda.go +++ b/internal/resources/providers/awslib/lambda/lambda.go @@ -34,11 +34,11 @@ type Lambda interface { ListLayers(context.Context, string, string) ([]awslib.AwsResource, error) } -func NewLambdaProvider(log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { +func NewLambdaProvider(ctx context.Context, log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { f := func(cfg aws.Config) Client { return lambda.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) return &Provider{ log: log, clients: m.GetMultiRegionsClientMap(), diff --git a/internal/resources/providers/awslib/mock_cross_region_factory.go b/internal/resources/providers/awslib/mock_cross_region_factory.go index d09f27ffab..a1f876ec6a 100644 --- a/internal/resources/providers/awslib/mock_cross_region_factory.go +++ b/internal/resources/providers/awslib/mock_cross_region_factory.go @@ -20,7 +20,10 @@ package awslib import ( + context "context" + aws "github.com/aws/aws-sdk-go-v2/aws" + logp "github.com/elastic/elastic-agent-libs/logp" mock "github.com/stretchr/testify/mock" @@ -39,13 +42,13 @@ func (_m *MockCrossRegionFactory[T]) EXPECT() *MockCrossRegionFactory_Expecter[T return &MockCrossRegionFactory_Expecter[T]{mock: &_m.Mock} } -// NewMultiRegionClients provides a mock function with given fields: selector, cfg, factory, log -func (_m *MockCrossRegionFactory[T]) NewMultiRegionClients(selector RegionsSelector, cfg aws.Config, factory func(aws.Config) T, log *logp.Logger) CrossRegionFetcher[T] { - ret := _m.Called(selector, cfg, factory, log) +// NewMultiRegionClients provides a mock function with given fields: ctx, selector, cfg, factory, log +func (_m *MockCrossRegionFactory[T]) NewMultiRegionClients(ctx context.Context, selector RegionsSelector, cfg aws.Config, factory func(aws.Config) T, log *logp.Logger) CrossRegionFetcher[T] { + ret := _m.Called(ctx, selector, cfg, factory, log) var r0 CrossRegionFetcher[T] - if rf, ok := ret.Get(0).(func(RegionsSelector, aws.Config, func(aws.Config) T, *logp.Logger) CrossRegionFetcher[T]); ok { - r0 = rf(selector, cfg, factory, log) + if rf, ok := ret.Get(0).(func(context.Context, RegionsSelector, aws.Config, func(aws.Config) T, *logp.Logger) CrossRegionFetcher[T]); ok { + r0 = rf(ctx, selector, cfg, factory, log) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(CrossRegionFetcher[T]) @@ -61,17 +64,18 @@ type MockCrossRegionFactory_NewMultiRegionClients_Call[T interface{}] struct { } // NewMultiRegionClients is a helper method to define mock.On call +// - ctx context.Context // - selector RegionsSelector // - cfg aws.Config // - factory func(aws.Config) T // - log *logp.Logger -func (_e *MockCrossRegionFactory_Expecter[T]) NewMultiRegionClients(selector interface{}, cfg interface{}, factory interface{}, log interface{}) *MockCrossRegionFactory_NewMultiRegionClients_Call[T] { - return &MockCrossRegionFactory_NewMultiRegionClients_Call[T]{Call: _e.mock.On("NewMultiRegionClients", selector, cfg, factory, log)} +func (_e *MockCrossRegionFactory_Expecter[T]) NewMultiRegionClients(ctx interface{}, selector interface{}, cfg interface{}, factory interface{}, log interface{}) *MockCrossRegionFactory_NewMultiRegionClients_Call[T] { + return &MockCrossRegionFactory_NewMultiRegionClients_Call[T]{Call: _e.mock.On("NewMultiRegionClients", ctx, selector, cfg, factory, log)} } -func (_c *MockCrossRegionFactory_NewMultiRegionClients_Call[T]) Run(run func(selector RegionsSelector, cfg aws.Config, factory func(aws.Config) T, log *logp.Logger)) *MockCrossRegionFactory_NewMultiRegionClients_Call[T] { +func (_c *MockCrossRegionFactory_NewMultiRegionClients_Call[T]) Run(run func(ctx context.Context, selector RegionsSelector, cfg aws.Config, factory func(aws.Config) T, log *logp.Logger)) *MockCrossRegionFactory_NewMultiRegionClients_Call[T] { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(RegionsSelector), args[1].(aws.Config), args[2].(func(aws.Config) T), args[3].(*logp.Logger)) + run(args[0].(context.Context), args[1].(RegionsSelector), args[2].(aws.Config), args[3].(func(aws.Config) T), args[4].(*logp.Logger)) }) return _c } @@ -81,7 +85,7 @@ func (_c *MockCrossRegionFactory_NewMultiRegionClients_Call[T]) Return(_a0 Cross return _c } -func (_c *MockCrossRegionFactory_NewMultiRegionClients_Call[T]) RunAndReturn(run func(RegionsSelector, aws.Config, func(aws.Config) T, *logp.Logger) CrossRegionFetcher[T]) *MockCrossRegionFactory_NewMultiRegionClients_Call[T] { +func (_c *MockCrossRegionFactory_NewMultiRegionClients_Call[T]) RunAndReturn(run func(context.Context, RegionsSelector, aws.Config, func(aws.Config) T, *logp.Logger) CrossRegionFetcher[T]) *MockCrossRegionFactory_NewMultiRegionClients_Call[T] { _c.Call.Return(run) return _c } diff --git a/internal/resources/providers/awslib/multi_region.go b/internal/resources/providers/awslib/multi_region.go index 14dc5831ff..b4c9f75b60 100644 --- a/internal/resources/providers/awslib/multi_region.go +++ b/internal/resources/providers/awslib/multi_region.go @@ -37,7 +37,7 @@ type CrossRegionFetcher[T any] interface { } type CrossRegionFactory[T any] interface { - NewMultiRegionClients(selector RegionsSelector, cfg aws.Config, factory func(cfg aws.Config) T, log *logp.Logger) CrossRegionFetcher[T] + NewMultiRegionClients(ctx context.Context, selector RegionsSelector, cfg aws.Config, factory func(cfg aws.Config) T, log *logp.Logger) CrossRegionFetcher[T] } type ( @@ -48,9 +48,9 @@ type ( ) // NewMultiRegionClients is a utility function that is used to create a map of client instances of a given type T for multiple regions. -func (w *MultiRegionClientFactory[T]) NewMultiRegionClients(selector RegionsSelector, cfg aws.Config, factory func(cfg aws.Config) T, log *logp.Logger) CrossRegionFetcher[T] { +func (w *MultiRegionClientFactory[T]) NewMultiRegionClients(ctx context.Context, selector RegionsSelector, cfg aws.Config, factory func(cfg aws.Config) T, log *logp.Logger) CrossRegionFetcher[T] { clientsMap := make(map[string]T, 0) - regionList, err := selector.Regions(context.TODO(), cfg) + regionList, err := selector.Regions(ctx, cfg) if err != nil { log.Errorf("Region '%s' selected after failure to retrieve aws regions: %v", cfg.Region, err) regionList = []string{cfg.Region} diff --git a/internal/resources/providers/awslib/multi_region_test.go b/internal/resources/providers/awslib/multi_region_test.go index 3bcea0cc0b..0bc34921ed 100644 --- a/internal/resources/providers/awslib/multi_region_test.go +++ b/internal/resources/providers/awslib/multi_region_test.go @@ -92,7 +92,7 @@ func TestMultiRegionWrapper_NewMultiRegionClients(t *testing.T) { } t.Run(tt.name, func(t *testing.T) { - multiRegionClients := wrapper.NewMultiRegionClients(tt.args.selector(), tt.args.cfg, factory, tt.args.log) + multiRegionClients := wrapper.NewMultiRegionClients(context.Background(), tt.args.selector(), tt.args.cfg, factory, tt.args.log) clients := multiRegionClients.GetMultiRegionsClientMap() if !reflect.DeepEqual(clients, tt.want) { t.Errorf("GetRegions() got = %v, want %v", clients, tt.want) diff --git a/internal/resources/providers/awslib/rds/provider.go b/internal/resources/providers/awslib/rds/provider.go index 4da7137779..3c8d92472e 100644 --- a/internal/resources/providers/awslib/rds/provider.go +++ b/internal/resources/providers/awslib/rds/provider.go @@ -31,11 +31,11 @@ import ( ec2Provider "github.com/elastic/cloudbeat/internal/resources/providers/awslib/ec2" ) -func NewProvider(log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client], ec2Provider ec2Provider.ElasticCompute) *Provider { +func NewProvider(ctx context.Context, log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client], ec2Provider ec2Provider.ElasticCompute) *Provider { f := func(cfg aws.Config) Client { return rds.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) return &Provider{ log: log, clients: m.GetMultiRegionsClientMap(), diff --git a/internal/resources/providers/awslib/s3/provider.go b/internal/resources/providers/awslib/s3/provider.go index d75f290abe..3b667aa997 100644 --- a/internal/resources/providers/awslib/s3/provider.go +++ b/internal/resources/providers/awslib/s3/provider.go @@ -42,11 +42,11 @@ const ( NoPublicAccessBlockConfigurationCode = "NoSuchPublicAccessBlockConfiguration" ) -func NewProvider(log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client], accountId string) *Provider { +func NewProvider(ctx context.Context, log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client], accountId string) *Provider { f := func(cfg aws.Config) Client { return s3Client.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) controlClient := s3control.NewFromConfig(cfg) diff --git a/internal/resources/providers/awslib/securityhub/provider.go b/internal/resources/providers/awslib/securityhub/provider.go index 5ab367ad5e..c2e8a7911d 100644 --- a/internal/resources/providers/awslib/securityhub/provider.go +++ b/internal/resources/providers/awslib/securityhub/provider.go @@ -39,11 +39,11 @@ type ( } ) -func NewProvider(log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client], accountId string) *Provider { +func NewProvider(ctx context.Context, log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client], accountId string) *Provider { f := func(cfg aws.Config) Client { return securityhub.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) return &Provider{ accountId: accountId, clients: m.GetMultiRegionsClientMap(), diff --git a/internal/resources/providers/awslib/sns/sns.go b/internal/resources/providers/awslib/sns/sns.go index 83763a325c..c2201ce21b 100644 --- a/internal/resources/providers/awslib/sns/sns.go +++ b/internal/resources/providers/awslib/sns/sns.go @@ -34,11 +34,11 @@ type SNS interface { ListTopicsWithSubscriptions(ctx context.Context) ([]awslib.AwsResource, error) } -func NewSNSProvider(log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { +func NewSNSProvider(ctx context.Context, log *logp.Logger, cfg aws.Config, factory awslib.CrossRegionFactory[Client]) *Provider { f := func(cfg aws.Config) Client { return sns.NewFromConfig(cfg) } - m := factory.NewMultiRegionClients(awslib.AllRegionSelector(), cfg, f, log) + m := factory.NewMultiRegionClients(ctx, awslib.AllRegionSelector(), cfg, f, log) return &Provider{ log: log, clients: m.GetMultiRegionsClientMap(), diff --git a/internal/vulnerability/runner.go b/internal/vulnerability/runner.go index 63475c5c1f..b8fe2dae3e 100644 --- a/internal/vulnerability/runner.go +++ b/internal/vulnerability/runner.go @@ -35,8 +35,7 @@ type VulnerabilityRunner struct { runner artifact.Runner } -func NewVulnerabilityRunner(log *logp.Logger) (VulnerabilityRunner, error) { - ctx := context.Background() +func NewVulnerabilityRunner(ctx context.Context, log *logp.Logger) (VulnerabilityRunner, error) { log.Debug("NewVulnerabilityRunner: New") if err := clearTrivyCache(ctx, log); err != nil { diff --git a/internal/vulnerability/runner_test.go b/internal/vulnerability/runner_test.go index 32fdb48ba7..6a6b4d8faa 100644 --- a/internal/vulnerability/runner_test.go +++ b/internal/vulnerability/runner_test.go @@ -32,7 +32,7 @@ func TestNewVulnerabilityRunner(t *testing.T) { testhelper.SkipLong(t) log := testhelper.NewLogger(t) - runner, err := NewVulnerabilityRunner(log) + runner, err := NewVulnerabilityRunner(context.Background(), log) defer func() { require.NoError(t, runner.runner.Close(context.Background())) }() @@ -46,7 +46,7 @@ func TestGetRunner(t *testing.T) { testhelper.SkipLong(t) log := testhelper.NewLogger(t) - runner, err := NewVulnerabilityRunner(log) + runner, err := NewVulnerabilityRunner(context.Background(), log) defer func() { require.NoError(t, runner.runner.Close(context.Background())) }() diff --git a/internal/vulnerability/worker.go b/internal/vulnerability/worker.go index 27325a78df..494b113d84 100644 --- a/internal/vulnerability/worker.go +++ b/internal/vulnerability/worker.go @@ -54,18 +54,18 @@ type workerProvider interface { DeleteSnapshot(ctx context.Context, snapshot ec2.EBSSnapshot) error } -func NewVulnerabilityWorker(log *logp.Logger, c *config.Config, bdp dataprovider.CommonDataProvider, cdp dataprovider.ElasticCommonDataProvider) (*VulnerabilityWorker, error) { +func NewVulnerabilityWorker(ctx context.Context, log *logp.Logger, c *config.Config, bdp dataprovider.CommonDataProvider, cdp dataprovider.ElasticCommonDataProvider) (*VulnerabilityWorker, error) { log.Debug("VulnerabilityWorker: New") awsConfig, err := awslib.InitializeAWSConfig(c.CloudConfig.Aws.Cred) if err != nil { return nil, fmt.Errorf("VulnerabilityWorker: failed to initialize AWS credentials: %w", err) } - provider := ec2.NewCurrentRegionEC2Provider(log, "", *awsConfig, &awslib.MultiRegionClientFactory[ec2.Client]{}) + provider := ec2.NewCurrentRegionEC2Provider(ctx, log, "", *awsConfig, &awslib.MultiRegionClientFactory[ec2.Client]{}) fetcher := NewVulnerabilityFetcher(log, provider) replicator := NewVulnerabilityReplicator(log, provider) verifier := NewVulnerabilityVerifier(log, provider) - runner, err := NewVulnerabilityRunner(log) + runner, err := NewVulnerabilityRunner(ctx, log) if err != nil { return nil, fmt.Errorf("VulnerabilityWorker: could not get init NewVulnerabilityRunner: %w", err) } diff --git a/internal/vulnerability/worker_test.go b/internal/vulnerability/worker_test.go index 50fa737bf7..2847bd0322 100644 --- a/internal/vulnerability/worker_test.go +++ b/internal/vulnerability/worker_test.go @@ -86,7 +86,7 @@ func TestNewVulnerabilityWorker(t *testing.T) { bdp := &dataprovider.MockCommonDataProvider{} cdp := &dataprovider.MockElasticCommonDataProvider{} - worker, err := NewVulnerabilityWorker(log, c, bdp, cdp) + worker, err := NewVulnerabilityWorker(context.Background(), log, c, bdp, cdp) defer goleak.VerifyNone(t, goleak.IgnoreCurrent(), goleak.Cleanup(func(_ int) { worker.runner.GetRunner().Close(context.Background()) @@ -129,7 +129,7 @@ func TestVulnerabilityWorker_Run(t *testing.T) { } // Not used runner, just to increase coverage - runner, err := NewVulnerabilityRunner(log) + runner, err := NewVulnerabilityRunner(context.Background(), log) require.NoError(t, err) runner.GetRunner().Close(context.Background())