Skip to content

Commit

Permalink
Fetch relevant regions for agents endpoint (#51814)
Browse files Browse the repository at this point in the history
  • Loading branch information
michellescripts committed Feb 7, 2025
1 parent b65f52f commit 3668763
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 13 deletions.
9 changes: 9 additions & 0 deletions lib/utils/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ func (s Set[T]) Subtract(other Set[T]) Set[T] {
return s
}

// Intersection updates the set to contain the similarity between the set and `other`
func (s Set[T]) Intersection(other Set[T]) {
for b := range s {
if !other.Contains(b) {
s.Remove(b)
}
}
}

// Elements returns the elements in the set. Order of the elements is undefined.
//
// NOTE: Due to the underlying map type, a set can be naturally ranged over like
Expand Down
48 changes: 48 additions & 0 deletions lib/utils/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,54 @@ func TestSet(t *testing.T) {
})
}
})

t.Run("intersection", func(t *testing.T) {
testCases := []struct {
name string
a []string
b []string
expected []string
}{
{
name: "empty intersection empty",
expected: []string{},
},
{
name: "empty intersection populated",
b: []string{"alpha", "beta"},
expected: []string{},
},
{
name: "populated intersection empty",
a: []string{"alpha", "beta"},
expected: []string{},
},
{
name: "populated intersection populated",
a: []string{"alpha", "beta", "gamma", "delta", "epsilon"},
b: []string{"beta", "eta", "zeta", "epsilon"},
expected: []string{"beta", "epsilon"},
},
}

for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
// GIVEN a pair of sets
a := NewSet(test.a...)
b := NewSet(test.b...)
bItems := b.Elements()

// WHEN I take the intersection of both sets
a.Intersection(b)

// EXPECT that the target set is updated with the intersection of both sets.
require.ElementsMatch(t, a.Elements(), test.expected)

// EXPECT also that the second set is unchanged
require.ElementsMatch(t, b.Elements(), bItems)
})
}
})
}

func TestSetTransform(t *testing.T) {
Expand Down
43 changes: 34 additions & 9 deletions lib/web/integrations_awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,19 +280,29 @@ func (h *Handler) awsOIDCListDeployedDatabaseService(w http.ResponseWriter, r *h
return nil, trace.Wrap(err)
}

services, err := listDeployedDatabaseServices(ctx, h.logger, integrationName, regions, clt.IntegrationAWSOIDCClient())
if len(regions) == 0 {
// return an empty list if there are no relevant regions in which to fetch database services
return ui.AWSOIDCListDeployedDatabaseServiceResponse{}, nil
}

s, err := listDeployedDatabaseServices(ctx, h.logger, integrationName, regions, clt.IntegrationAWSOIDCClient())
if err != nil {
return nil, trace.Wrap(err)
}

return ui.AWSOIDCListDeployedDatabaseServiceResponse{
Services: services,
Services: s,
}, nil
}

func extractAWSRegionsFromQuery(r *http.Request) ([]string, error) {
var ret []string
for _, region := range r.URL.Query()["regions"] {
if region == "" {
// no regions passed in params, empty key
return ret, nil
}

if err := aws.IsValidRegion(region); err != nil {
return nil, trace.BadParameter("invalid region %s", region)
}
Expand All @@ -302,21 +312,36 @@ func extractAWSRegionsFromQuery(r *http.Request) ([]string, error) {
return ret, nil
}

// regionsForListingDeployedDatabaseService fetches relevant AWS regions and parses the regions query param.
// If no query params are present, relevant regions are returned.
// If query params are present, we take the intersection of relevant regions and filter regions to avoid requesting
// services which have not been set up which would result in an error.
// ex: relevant = ["us-west-1"]; params = []; returns ["us-west-1"]
// ex: relevant = []; params = ["us-west-1"]; returns []
// ex: relevant = ["us-west-1"]; params = ["us-west-1"]; returns ["us-west-1"]
// ex: relevant = ["us-west-1"]; params = ["us-west-2"]; returns []
func regionsForListingDeployedDatabaseService(ctx context.Context, r *http.Request, authClient databaseGetter, discoveryConfigsClient discoveryConfigLister) ([]string, error) {
// use the auth client & discover configs to collect a list of relevant AWS regions
relevant, err := fetchRelevantAWSRegions(ctx, authClient, discoveryConfigsClient)
if err != nil {
return nil, trace.Wrap(err)
}

if r.URL.Query().Has("regions") {
regions, err := extractAWSRegionsFromQuery(r)
params, err := extractAWSRegionsFromQuery(r)
if err != nil {
return nil, trace.Wrap(err)
}
return regions, err
}

regions, err := fetchRelevantAWSRegions(ctx, authClient, discoveryConfigsClient)
if err != nil {
return nil, trace.Wrap(err)
if len(params) > 0 {
a := libutils.NewSet(relevant...)
b := libutils.NewSet(params...)
a.Intersection(b)
return a.Elements(), nil
}
}

return regions, nil
return relevant, nil
}

type databaseGetter interface {
Expand Down
36 changes: 32 additions & 4 deletions lib/web/integrations_awsoidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1465,10 +1465,16 @@ func dummyDeployedDatabaseServices(count int, command []string) []*integrationv1
func TestRegionsForListingDeployedDatabaseService(t *testing.T) {
ctx := context.Background()

t.Run("regions query param is used instead of parsing internal resources", func(t *testing.T) {
t.Run("regions query param, returns nil if no internal resources match", func(t *testing.T) {
clt := &mockRelevantAWSRegionsClient{
databaseServices: &proto.ListResourcesResponse{
Resources: []*proto.PaginatedResource{},
Resources: []*proto.PaginatedResource{{Resource: &proto.PaginatedResource_DatabaseService{
DatabaseService: &types.DatabaseServiceV1{Spec: types.DatabaseServiceSpecV1{
ResourceMatchers: []*types.DatabaseResourceMatcher{
{Labels: &types.Labels{"region": []string{"af-south-1"}}},
},
}},
}}},
},
databases: make([]types.Database, 0),
discoveryConfigs: make([]*discoveryconfig.DiscoveryConfig, 0),
Expand All @@ -1478,7 +1484,28 @@ func TestRegionsForListingDeployedDatabaseService(t *testing.T) {
}
gotRegions, err := regionsForListingDeployedDatabaseService(ctx, &r, clt, clt)
require.NoError(t, err)
require.ElementsMatch(t, []string{"us-east-1", "us-east-2"}, gotRegions)
require.ElementsMatch(t, nil, gotRegions)
})

t.Run("regions query param, returns matches in internal resources", func(t *testing.T) {
clt := &mockRelevantAWSRegionsClient{
databaseServices: &proto.ListResourcesResponse{
Resources: []*proto.PaginatedResource{{Resource: &proto.PaginatedResource_DatabaseService{
DatabaseService: &types.DatabaseServiceV1{Spec: types.DatabaseServiceSpecV1{
ResourceMatchers: []*types.DatabaseResourceMatcher{
{Labels: &types.Labels{"region": []string{"af-south-1"}}}},
}},
}}},
},
databases: make([]types.Database, 0),
discoveryConfigs: make([]*discoveryconfig.DiscoveryConfig, 0),
}
r := http.Request{
URL: &url.URL{RawQuery: "regions=af-south-1&regions=us-east-2"},
}
gotRegions, err := regionsForListingDeployedDatabaseService(ctx, &r, clt, clt)
require.NoError(t, err)
require.ElementsMatch(t, []string{"af-south-1"}, gotRegions)
})

t.Run("fallbacks to internal resources when query param is not present", func(t *testing.T) {
Expand All @@ -1497,13 +1524,14 @@ func TestRegionsForListingDeployedDatabaseService(t *testing.T) {
discoveryConfigs: make([]*discoveryconfig.DiscoveryConfig, 0),
}
r := http.Request{
URL: &url.URL{},
URL: &url.URL{RawQuery: "regions="},
}
gotRegions, err := regionsForListingDeployedDatabaseService(ctx, &r, clt, clt)
require.NoError(t, err)
require.ElementsMatch(t, []string{"us-east-1", "us-east-2"}, gotRegions)
})
}

func TestFetchRelevantAWSRegions(t *testing.T) {
ctx := context.Background()

Expand Down

0 comments on commit 3668763

Please sign in to comment.