diff --git a/lib/utils/set.go b/lib/utils/set.go index b16d1102247b2..685a0b28a5de2 100644 --- a/lib/utils/set.go +++ b/lib/utils/set.go @@ -84,6 +84,15 @@ func (s Set[T]) Subtract(other Set[T]) Set[T] { return s } +// Intersection updates the set to contain the similarity between the set and `other` +func (s Set[T]) Intersection(other Set[T]) { + for b := range s { + if !other.Contains(b) { + s.Remove(b) + } + } +} + // Elements returns the elements in the set. Order of the elements is undefined. // // NOTE: Due to the underlying map type, a set can be naturally ranged over like diff --git a/lib/utils/set_test.go b/lib/utils/set_test.go index e2f3357a5e9bb..dd3a2382f8146 100644 --- a/lib/utils/set_test.go +++ b/lib/utils/set_test.go @@ -279,6 +279,54 @@ func TestSet(t *testing.T) { }) } }) + + t.Run("intersection", func(t *testing.T) { + testCases := []struct { + name string + a []string + b []string + expected []string + }{ + { + name: "empty intersection empty", + expected: []string{}, + }, + { + name: "empty intersection populated", + b: []string{"alpha", "beta"}, + expected: []string{}, + }, + { + name: "populated intersection empty", + a: []string{"alpha", "beta"}, + expected: []string{}, + }, + { + name: "populated intersection populated", + a: []string{"alpha", "beta", "gamma", "delta", "epsilon"}, + b: []string{"beta", "eta", "zeta", "epsilon"}, + expected: []string{"beta", "epsilon"}, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + // GIVEN a pair of sets + a := NewSet(test.a...) + b := NewSet(test.b...) + bItems := b.Elements() + + // WHEN I take the intersection of both sets + a.Intersection(b) + + // EXPECT that the target set is updated with the intersection of both sets. + require.ElementsMatch(t, a.Elements(), test.expected) + + // EXPECT also that the second set is unchanged + require.ElementsMatch(t, b.Elements(), bItems) + }) + } + }) } func TestSetTransform(t *testing.T) { diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go index ce20aaff25c6c..e85e05e8ed99b 100644 --- a/lib/web/integrations_awsoidc.go +++ b/lib/web/integrations_awsoidc.go @@ -280,13 +280,18 @@ func (h *Handler) awsOIDCListDeployedDatabaseService(w http.ResponseWriter, r *h return nil, trace.Wrap(err) } - services, err := listDeployedDatabaseServices(ctx, h.logger, integrationName, regions, clt.IntegrationAWSOIDCClient()) + if len(regions) == 0 { + // return an empty list if there are no relevant regions in which to fetch database services + return ui.AWSOIDCListDeployedDatabaseServiceResponse{}, nil + } + + s, err := listDeployedDatabaseServices(ctx, h.logger, integrationName, regions, clt.IntegrationAWSOIDCClient()) if err != nil { return nil, trace.Wrap(err) } return ui.AWSOIDCListDeployedDatabaseServiceResponse{ - Services: services, + Services: s, }, nil } @@ -302,21 +307,33 @@ func extractAWSRegionsFromQuery(r *http.Request) ([]string, error) { return ret, nil } +// regionsForListingDeployedDatabaseService fetches relevant aws regions, parse the regions query param, and returns the overlap. +// If no query params are present, relevant regions are returned. +// ex: relevant = ["us-west-1"]; params = []; returns ["us-west-1"] +// ex: relevant = []; params = ["us-west-1"]; returns [] +// ex: relevant = ["us-west-1"]; params = ["us-west-1"]; returns ["us-west-1"] +// ex: relevant = ["us-west-1"]; params = ["us-west-2"]; returns [] func regionsForListingDeployedDatabaseService(ctx context.Context, r *http.Request, authClient databaseGetter, discoveryConfigsClient discoveryConfigLister) ([]string, error) { + relevant, err := fetchRelevantAWSRegions(ctx, authClient, discoveryConfigsClient) + if err != nil { + return nil, trace.Wrap(err) + } + if r.URL.Query().Has("regions") { - regions, err := extractAWSRegionsFromQuery(r) + params, err := extractAWSRegionsFromQuery(r) if err != nil { return nil, trace.Wrap(err) } - return regions, err - } - regions, err := fetchRelevantAWSRegions(ctx, authClient, discoveryConfigsClient) - if err != nil { - return nil, trace.Wrap(err) + if len(params) > 0 { + a := libutils.NewSet(relevant...) + b := libutils.NewSet(params...) + a.Intersection(b) + return a.Elements(), nil + } } - return regions, nil + return relevant, nil } type databaseGetter interface { diff --git a/lib/web/integrations_awsoidc_test.go b/lib/web/integrations_awsoidc_test.go index 62c3ca8ce692b..dca45e1888d5e 100644 --- a/lib/web/integrations_awsoidc_test.go +++ b/lib/web/integrations_awsoidc_test.go @@ -1330,10 +1330,16 @@ func dummyDeployedDatabaseServices(count int, command []string) []*integrationv1 func TestRegionsForListingDeployedDatabaseService(t *testing.T) { ctx := context.Background() - t.Run("regions query param is used instead of parsing internal resources", func(t *testing.T) { + t.Run("regions query param, returns nil if no internal resources match", func(t *testing.T) { clt := &mockRelevantAWSRegionsClient{ databaseServices: &proto.ListResourcesResponse{ - Resources: []*proto.PaginatedResource{}, + Resources: []*proto.PaginatedResource{{Resource: &proto.PaginatedResource_DatabaseService{ + DatabaseService: &types.DatabaseServiceV1{Spec: types.DatabaseServiceSpecV1{ + ResourceMatchers: []*types.DatabaseResourceMatcher{ + {Labels: &types.Labels{"region": []string{"af-south-1"}}}, + }, + }}, + }}}, }, databases: make([]types.Database, 0), discoveryConfigs: make([]*discoveryconfig.DiscoveryConfig, 0), @@ -1343,7 +1349,28 @@ func TestRegionsForListingDeployedDatabaseService(t *testing.T) { } gotRegions, err := regionsForListingDeployedDatabaseService(ctx, &r, clt, clt) require.NoError(t, err) - require.ElementsMatch(t, []string{"us-east-1", "us-east-2"}, gotRegions) + require.ElementsMatch(t, nil, gotRegions) + }) + + t.Run("regions query param, returns matches in internal resources", func(t *testing.T) { + clt := &mockRelevantAWSRegionsClient{ + databaseServices: &proto.ListResourcesResponse{ + Resources: []*proto.PaginatedResource{{Resource: &proto.PaginatedResource_DatabaseService{ + DatabaseService: &types.DatabaseServiceV1{Spec: types.DatabaseServiceSpecV1{ + ResourceMatchers: []*types.DatabaseResourceMatcher{ + {Labels: &types.Labels{"region": []string{"af-south-1"}}}}, + }}, + }}}, + }, + databases: make([]types.Database, 0), + discoveryConfigs: make([]*discoveryconfig.DiscoveryConfig, 0), + } + r := http.Request{ + URL: &url.URL{RawQuery: "regions=af-south-1®ions=us-east-2"}, + } + gotRegions, err := regionsForListingDeployedDatabaseService(ctx, &r, clt, clt) + require.NoError(t, err) + require.ElementsMatch(t, []string{"af-south-1"}, gotRegions) }) t.Run("fallbacks to internal resources when query param is not present", func(t *testing.T) { @@ -1362,13 +1389,14 @@ func TestRegionsForListingDeployedDatabaseService(t *testing.T) { discoveryConfigs: make([]*discoveryconfig.DiscoveryConfig, 0), } r := http.Request{ - URL: &url.URL{}, + URL: &url.URL{RawQuery: "regions="}, } gotRegions, err := regionsForListingDeployedDatabaseService(ctx, &r, clt, clt) require.NoError(t, err) require.ElementsMatch(t, []string{"us-east-1", "us-east-2"}, gotRegions) }) } + func TestFetchRelevantAWSRegions(t *testing.T) { ctx := context.Background()