diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 7bb586693f6d..5deb0d2b5bf5 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -31,6 +31,8 @@ const ( InstanceTypesAndZonesTTL = 5 * time.Minute // InstanceProfileTTL is the time before we refresh checking instance profile existence at IAM InstanceProfileTTL = 15 * time.Minute + // InstanceTypesAndZonesTTL is the time before we remove subnets that have been refreshed + SubnetTTL = 10 * time.Minute ) const ( diff --git a/pkg/controllers/nodeclass/status/subnet_test.go b/pkg/controllers/nodeclass/status/subnet_test.go index ba37f1d3b11a..050b3b59367f 100644 --- a/pkg/controllers/nodeclass/status/subnet_test.go +++ b/pkg/controllers/nodeclass/status/subnet_test.go @@ -53,6 +53,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }) It("Should update EC2NodeClass status for Subnets", func() { ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -81,6 +82,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { {SubnetId: aws.String("subnet-test3"), AvailabilityZone: aws.String("test-zone-1c"), AvailableIpAddressCount: aws.Int64(50)}, }}) ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -108,6 +110,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }, } ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -128,6 +131,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }, } ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -139,6 +143,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }) It("Should update Subnet status when the Subnet selector gets updated by tags", func() { ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -173,6 +178,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }, } ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -188,6 +194,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }) It("Should update Subnet status when the Subnet selector gets updated by ids", func() { ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -215,6 +222,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }, } ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -231,12 +239,14 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }, } ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileFailed(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(BeNil()) }) It("Should not resolve a invalid selectors for an updated subnet selector", func() { ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -264,6 +274,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }, } ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileFailed(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(BeNil()) diff --git a/pkg/controllers/nodeclass/status/suite_test.go b/pkg/controllers/nodeclass/status/suite_test.go index 545c9ba6cfd5..a7f97fd9a88b 100644 --- a/pkg/controllers/nodeclass/status/suite_test.go +++ b/pkg/controllers/nodeclass/status/suite_test.go @@ -73,6 +73,7 @@ var _ = BeforeEach(func() { ctx = coreoptions.ToContext(ctx, coretest.Options()) nodeClass = test.EC2NodeClass() awsEnv.Reset() + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) }) var _ = AfterEach(func() { diff --git a/pkg/controllers/providers/subnet/controller.go b/pkg/controllers/providers/subnet/controller.go new file mode 100644 index 000000000000..fa866859ba80 --- /dev/null +++ b/pkg/controllers/providers/subnet/controller.go @@ -0,0 +1,59 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package subnet + +import ( + "context" + "time" + + controllerruntime "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/manager" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + corecontroller "sigs.k8s.io/karpenter/pkg/operator/controller" + + "sigs.k8s.io/controller-runtime/pkg/controller" + + "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1" + "github.com/aws/karpenter-provider-aws/pkg/providers/subnet" +) + +var _ corecontroller.TypedController[*v1beta1.EC2NodeClass] = (*Controller)(nil) + +type Controller struct { + subnetProvider subnet.Provider +} + +func NewController(subnetProvider subnet.Provider) *Controller { + return &Controller{ + subnetProvider: subnetProvider, + } +} + +func (c *Controller) Reconcile(ctx context.Context, nodeClass *v1beta1.EC2NodeClass) (reconcile.Result, error) { + return reconcile.Result{RequeueAfter: 5 * time.Minute}, c.subnetProvider.UpdateSubnets(ctx, nodeClass) +} + +func (c *Controller) Name() string { + return "subnet" +} + +func (c *Controller) Builder(_ context.Context, m manager.Manager) corecontroller.Builder { + return corecontroller.Adapt(controllerruntime. + NewControllerManagedBy(m). + For(&v1beta1.EC2NodeClass{}). + WithOptions(controller.Options{ + MaxConcurrentReconciles: 10, + })) +} diff --git a/pkg/controllers/providers/subnet/suite_test.go b/pkg/controllers/providers/subnet/suite_test.go new file mode 100644 index 000000000000..2424ae785025 --- /dev/null +++ b/pkg/controllers/providers/subnet/suite_test.go @@ -0,0 +1,59 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package subnet_test + +import ( + "context" + "testing" + + coreoptions "sigs.k8s.io/karpenter/pkg/operator/options" + "sigs.k8s.io/karpenter/pkg/operator/scheme" + coretest "sigs.k8s.io/karpenter/pkg/test" + + "github.com/aws/karpenter-provider-aws/pkg/apis" + controllerspricing "github.com/aws/karpenter-provider-aws/pkg/controllers/providers/pricing" + "github.com/aws/karpenter-provider-aws/pkg/operator/options" + "github.com/aws/karpenter-provider-aws/pkg/test" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + . "knative.dev/pkg/logging/testing" +) + +var ctx context.Context +var stop context.CancelFunc +var env *coretest.Environment +var awsEnv *test.Environment +var controller *controllerspricing.Controller + +func TestAWS(t *testing.T) { + ctx = TestContextWithLogger(t) + RegisterFailHandler(Fail) + RunSpecs(t, "Subnet") +} + +var _ = BeforeSuite(func() { + env = coretest.NewEnvironment(scheme.Scheme, coretest.WithCRDs(apis.CRDs...)) + ctx = coreoptions.ToContext(ctx, coretest.Options()) + ctx = options.ToContext(ctx, test.Options()) + ctx, stop = context.WithCancel(ctx) + awsEnv = test.NewEnvironment(ctx, env) + controller = controllerspricing.NewController(awsEnv.PricingProvider) +}) + +var _ = AfterSuite(func() { + stop() + Expect(env.Stop()).To(Succeed(), "Failed to stop environment") +}) diff --git a/pkg/operator/operator.go b/pkg/operator/operator.go index 39be55788322..c2dc090f805f 100644 --- a/pkg/operator/operator.go +++ b/pkg/operator/operator.go @@ -132,7 +132,7 @@ func NewOperator(ctx context.Context, operator *operator.Operator) (context.Cont } unavailableOfferingsCache := awscache.NewUnavailableOfferings() - subnetProvider := subnet.NewDefaultProvider(ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval)) + subnetProvider := subnet.NewDefaultProvider(ec2api, cache.New(awscache.SubnetTTL, awscache.DefaultCleanupInterval)) securityGroupProvider := securitygroup.NewDefaultProvider(ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval)) instanceProfileProvider := instanceprofile.NewDefaultProvider(*sess.Config.Region, iam.New(sess), cache.New(awscache.InstanceProfileTTL, awscache.DefaultCleanupInterval)) pricingProvider := pricing.NewDefaultProvider( diff --git a/pkg/providers/instance/suite_test.go b/pkg/providers/instance/suite_test.go index 45d3504bafb6..a4f9682d38ca 100644 --- a/pkg/providers/instance/suite_test.go +++ b/pkg/providers/instance/suite_test.go @@ -115,6 +115,7 @@ var _ = Describe("InstanceProvider", func() { {CapacityType: corev1beta1.CapacityTypeSpot, InstanceType: "m5.xlarge", Zone: "test-zone-1a"}, {CapacityType: corev1beta1.CapacityTypeSpot, InstanceType: "m5.xlarge", Zone: "test-zone-1b"}, }) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) instanceTypes, err := cloudProvider.GetInstanceTypes(ctx, nodePool) Expect(err).ToNot(HaveOccurred()) diff --git a/pkg/providers/instancetype/suite_test.go b/pkg/providers/instancetype/suite_test.go index 6b0d284d725b..0386569498e0 100644 --- a/pkg/providers/instancetype/suite_test.go +++ b/pkg/providers/instancetype/suite_test.go @@ -155,6 +155,8 @@ var _ = Describe("InstanceTypeProvider", func() { }, }, }) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, windowsNodeClass)).To(BeNil()) }) It("should support individual instance type labels", func() { diff --git a/pkg/providers/launchtemplate/suite_test.go b/pkg/providers/launchtemplate/suite_test.go index 17449addc88f..c0608e179dc4 100644 --- a/pkg/providers/launchtemplate/suite_test.go +++ b/pkg/providers/launchtemplate/suite_test.go @@ -146,6 +146,7 @@ var _ = Describe("LaunchTemplate Provider", func() { }, }, }) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) }) It("should create unique launch templates for multiple identical nodeClasses", func() { nodeClass2 := test.EC2NodeClass() diff --git a/pkg/providers/subnet/subnet.go b/pkg/providers/subnet/subnet.go index 37af40718b25..44493776548b 100644 --- a/pkg/providers/subnet/subnet.go +++ b/pkg/providers/subnet/subnet.go @@ -41,6 +41,7 @@ type Provider interface { CheckAnyPublicIPAssociations(context.Context, *v1beta1.EC2NodeClass) (bool, error) ZonalSubnetsForLaunch(context.Context, *v1beta1.EC2NodeClass, []*cloudprovider.InstanceType, string) (map[string]*ec2.Subnet, error) UpdateInflightIPs(*ec2.CreateFleetInput, *ec2.CreateFleetOutput, []*cloudprovider.InstanceType, []*ec2.Subnet, string) + UpdateSubnets(context.Context, *v1beta1.EC2NodeClass) error } type DefaultProvider struct { @@ -64,26 +65,42 @@ func NewDefaultProvider(ec2api ec2iface.EC2API, cache *cache.Cache) *DefaultProv } func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeClass) ([]*ec2.Subnet, error) { - p.Lock() - defer p.Unlock() + p.RLock() + defer p.RUnlock() + filterSets := getFilterSets(nodeClass.Spec.SubnetSelectorTerms) if len(filterSets) == 0 { return []*ec2.Subnet{}, nil } - hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) + hash, err := subnetHash(filterSets) if err != nil { - return nil, err + return []*ec2.Subnet{}, err } if subnets, ok := p.cache.Get(fmt.Sprint(hash)); ok { return subnets.([]*ec2.Subnet), nil } + return []*ec2.Subnet{}, nil +} + +func (p *DefaultProvider) UpdateSubnets(ctx context.Context, nodeClass *v1beta1.EC2NodeClass) error { + p.Lock() + defer p.Unlock() + + filterSets := getFilterSets(nodeClass.Spec.SubnetSelectorTerms) + if len(filterSets) == 0 { + return nil + } + hash, err := subnetHash(filterSets) + if err != nil { + return err + } // Ensure that all the subnets that are returned here are unique subnets := map[string]*ec2.Subnet{} for _, filters := range filterSets { output, err := p.ec2api.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{Filters: filters}) if err != nil { - return nil, fmt.Errorf("describing subnets %s, %w", pretty.Concise(filters), err) + return fmt.Errorf("describing subnets %s, %w", pretty.Concise(filters), err) } for i := range output.Subnets { subnets[lo.FromPtr(output.Subnets[i].SubnetId)] = output.Subnets[i] @@ -91,6 +108,7 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeCl } } p.cache.SetDefault(fmt.Sprint(hash), lo.Values(subnets)) + p.cache.DeleteExpired() // delete expired after we have successfully updated the cache if p.cm.HasChanged(fmt.Sprintf("subnets/%s", nodeClass.Name), subnets) { logging.FromContext(ctx). With("subnets", lo.Map(lo.Values(subnets), func(s *ec2.Subnet, _ int) string { @@ -98,7 +116,8 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeCl })). Debugf("discovered subnets") } - return lo.Values(subnets), nil + + return nil } // CheckAnyPublicIPAssociations returns a bool indicating whether all referenced subnets assign public IPv4 addresses to EC2 instances created therein @@ -266,3 +285,11 @@ func getFilterSets(terms []v1beta1.SubnetSelectorTerm) (res [][]*ec2.Filter) { } return res } + +func subnetHash(filterSets [][]*ec2.Filter) (string, error) { + hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) + if err != nil { + return "", err + } + return fmt.Sprintf("%d", hash), nil +} diff --git a/pkg/providers/subnet/suite_test.go b/pkg/providers/subnet/suite_test.go index 08b7354f7924..135c867caaba 100644 --- a/pkg/providers/subnet/suite_test.go +++ b/pkg/providers/subnet/suite_test.go @@ -99,6 +99,7 @@ var _ = Describe("SubnetProvider", func() { ID: "subnet-test1", }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) ExpectConsistsOfSubnets([]*ec2.Subnet{ @@ -118,6 +119,7 @@ var _ = Describe("SubnetProvider", func() { ID: "subnet-test2", }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) ExpectConsistsOfSubnets([]*ec2.Subnet{ @@ -144,6 +146,7 @@ var _ = Describe("SubnetProvider", func() { Tags: map[string]string{"foo": "bar"}, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) ExpectConsistsOfSubnets([]*ec2.Subnet{ @@ -165,6 +168,7 @@ var _ = Describe("SubnetProvider", func() { Tags: map[string]string{"Name": "test-subnet-1"}, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) ExpectConsistsOfSubnets([]*ec2.Subnet{ @@ -184,6 +188,7 @@ var _ = Describe("SubnetProvider", func() { Tags: map[string]string{"Name": "test-subnet-2"}, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) ExpectConsistsOfSubnets([]*ec2.Subnet{ @@ -206,6 +211,7 @@ var _ = Describe("SubnetProvider", func() { Tags: map[string]string{"foo": "bar"}, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) ExpectConsistsOfSubnets([]*ec2.Subnet{ @@ -225,6 +231,7 @@ var _ = Describe("SubnetProvider", func() { Tags: map[string]string{"foo": "bar"}, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) onlyPrivate, err := awsEnv.SubnetProvider.CheckAnyPublicIPAssociations(ctx, nodeClass) Expect(err).To(BeNil()) Expect(onlyPrivate).To(BeFalse()) @@ -235,6 +242,7 @@ var _ = Describe("SubnetProvider", func() { ID: "subnet-test2", }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) onlyPrivate, err := awsEnv.SubnetProvider.CheckAnyPublicIPAssociations(ctx, nodeClass) Expect(err).To(BeNil()) Expect(onlyPrivate).To(BeTrue()) @@ -242,13 +250,22 @@ var _ = Describe("SubnetProvider", func() { }) Context("Provider Cache", func() { It("should resolve subnets from cache that are filtered by id", func() { - expectedSubnets := awsEnv.EC2API.DescribeSubnetsOutput.Clone().Subnets - for _, subnet := range expectedSubnets { + expectedSubnets, err := awsEnv.EC2API.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{ + Filters: []*ec2.Filter{ + { + Name: lo.ToPtr("tag-key"), + Values: []*string{lo.ToPtr("*")}, + }, + }, + }) + Expect(err).To(BeNil()) + for _, subnet := range expectedSubnets.Subnets { nodeClass.Spec.SubnetSelectorTerms = []v1beta1.SubnetSelectorTerm{ { ID: *subnet.SubnetId, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) // Call list to request from aws and store in the cache _, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) @@ -257,12 +274,20 @@ var _ = Describe("SubnetProvider", func() { for _, cachedObject := range awsEnv.SubnetCache.Items() { cachedSubnet := cachedObject.Object.([]*ec2.Subnet) Expect(cachedSubnet).To(HaveLen(1)) - lo.Contains(expectedSubnets, cachedSubnet[0]) + lo.Contains(expectedSubnets.Subnets, cachedSubnet[0]) } }) It("should resolve subnets from cache that are filtered by tags", func() { - expectedSubnets := awsEnv.EC2API.DescribeSubnetsOutput.Clone().Subnets - tagSet := lo.Map(expectedSubnets, func(subnet *ec2.Subnet, _ int) map[string]string { + expectedSubnets, err := awsEnv.EC2API.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{ + Filters: []*ec2.Filter{ + { + Name: lo.ToPtr("tag-key"), + Values: []*string{lo.ToPtr("*")}, + }, + }, + }) + Expect(err).To(BeNil()) + tagSet := lo.Map(expectedSubnets.Subnets, func(subnet *ec2.Subnet, _ int) map[string]string { tag, _ := lo.Find(subnet.Tags, func(tag *ec2.Tag) bool { return lo.FromPtr(tag.Key) == "Name" }) @@ -274,6 +299,7 @@ var _ = Describe("SubnetProvider", func() { Tags: tag, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) // Call list to request from aws and store in the cache _, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) @@ -282,7 +308,7 @@ var _ = Describe("SubnetProvider", func() { for _, cachedObject := range awsEnv.SubnetCache.Items() { cachedSubnet := cachedObject.Object.([]*ec2.Subnet) Expect(cachedSubnet).To(HaveLen(1)) - lo.Contains(expectedSubnets, cachedSubnet[0]) + lo.Contains(expectedSubnets.Subnets, cachedSubnet[0]) } }) })