From 4ba17dd33146c84ee5f4eb5ca680e71311f7c8a5 Mon Sep 17 00:00:00 2001 From: Amanuel Engeda Date: Wed, 17 Apr 2024 17:59:36 -0700 Subject: [PATCH] Create Subnet controller --- pkg/cache/cache.go | 2 + pkg/cloudprovider/suite_test.go | 18 ++++ pkg/controllers/controllers.go | 2 + .../nodeclass/status/subnet_test.go | 11 ++ .../nodeclass/status/suite_test.go | 1 + .../providers/subnet/controller.go | 60 +++++++++++ .../providers/subnet/suite_test.go | 102 ++++++++++++++++++ pkg/operator/operator.go | 2 +- pkg/providers/instance/suite_test.go | 1 + pkg/providers/instancetype/suite_test.go | 2 + pkg/providers/launchtemplate/suite_test.go | 3 + pkg/providers/subnet/subnet.go | 39 +++++-- pkg/providers/subnet/suite_test.go | 38 +++++-- 13 files changed, 268 insertions(+), 13 deletions(-) create mode 100644 pkg/controllers/providers/subnet/controller.go create mode 100644 pkg/controllers/providers/subnet/suite_test.go diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 7bb586693f6d..5deb0d2b5bf5 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -31,6 +31,8 @@ const ( InstanceTypesAndZonesTTL = 5 * time.Minute // InstanceProfileTTL is the time before we refresh checking instance profile existence at IAM InstanceProfileTTL = 15 * time.Minute + // InstanceTypesAndZonesTTL is the time before we remove subnets that have been refreshed + SubnetTTL = 10 * time.Minute ) const ( diff --git a/pkg/cloudprovider/suite_test.go b/pkg/cloudprovider/suite_test.go index 556767b6d109..40bfe5721a0e 100644 --- a/pkg/cloudprovider/suite_test.go +++ b/pkg/cloudprovider/suite_test.go @@ -138,6 +138,7 @@ var _ = Describe("CloudProvider", func() { }, }, }) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) }) It("should return an ICE error when there are no instance types to launch", func() { // Specify no instance types and expect to receive a capacity error @@ -658,6 +659,7 @@ var _ = Describe("CloudProvider", func() { Expect(err).To(HaveOccurred()) }) It("should not return drifted if the NodeClaim is valid", func() { + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).ToNot(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -696,11 +698,13 @@ var _ = Describe("CloudProvider", func() { }, }, }) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).ToNot(HaveOccurred()) Expect(isDrifted).To(Equal(cloudprovider.SecurityGroupDrift)) }) It("should not return drifted if the security groups match", func() { + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).ToNot(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -776,6 +780,7 @@ var _ = Describe("CloudProvider", func() { DescribeTable("should return drifted if a statically drifted EC2NodeClass.Spec field is updated", func(changes v1beta1.EC2NodeClass) { ExpectApplied(ctx, env.Client, nodePool, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -816,6 +821,7 @@ var _ = Describe("CloudProvider", func() { nodeClass.Annotations = lo.Assign(nodeClass.Annotations, map[string]string{v1beta1.AnnotationEC2NodeClassHash: nodeClass.Hash()}) ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(Equal(cloudprovider.NodeClassDrift)) @@ -823,6 +829,7 @@ var _ = Describe("CloudProvider", func() { DescribeTable("should not return drifted if dynamic fields are updated", func(changes v1beta1.EC2NodeClass) { ExpectApplied(ctx, env.Client, nodePool, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -831,6 +838,7 @@ var _ = Describe("CloudProvider", func() { nodeClass.Annotations = lo.Assign(nodeClass.Annotations, map[string]string{v1beta1.AnnotationEC2NodeClassHash: nodeClass.Hash()}) ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) isDrifted, err = cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -847,6 +855,7 @@ var _ = Describe("CloudProvider", func() { "Test Key": "Test Value", } ExpectApplied(ctx, env.Client, nodePool, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -861,6 +870,7 @@ var _ = Describe("CloudProvider", func() { v1beta1.AnnotationEC2NodeClassHashVersion: "test-hash-version-2", } ExpectApplied(ctx, env.Client, nodePool, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -878,6 +888,7 @@ var _ = Describe("CloudProvider", func() { "Test Key": "Test Value", } ExpectApplied(ctx, env.Client, nodePool, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -895,6 +906,7 @@ var _ = Describe("CloudProvider", func() { "Test Key": "Test Value", } ExpectApplied(ctx, env.Client, nodePool, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -937,6 +949,7 @@ var _ = Describe("CloudProvider", func() { Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-2")}}}, }}) ExpectApplied(ctx, env.Client, nodePool, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) pod := coretest.UnschedulablePod(coretest.PodOptions{NodeSelector: map[string]string{v1.LabelTopologyZone: "test-zone-1a"}}) ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, prov, pod) ExpectScheduled(ctx, env.Client, pod) @@ -952,6 +965,7 @@ var _ = Describe("CloudProvider", func() { }}) nodePool.Spec.Template.Spec.Kubelet = &corev1beta1.KubeletConfiguration{MaxPods: aws.Int32(1)} ExpectApplied(ctx, env.Client, nodePool, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) pod1 := coretest.UnschedulablePod(coretest.PodOptions{NodeSelector: map[string]string{v1.LabelTopologyZone: "test-zone-1a"}}) pod2 := coretest.UnschedulablePod(coretest.PodOptions{NodeSelector: map[string]string{v1.LabelTopologyZone: "test-zone-1a"}}) ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, prov, pod1, pod2) @@ -971,6 +985,7 @@ var _ = Describe("CloudProvider", func() { {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int64(10), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, }}) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) pod1 := coretest.UnschedulablePod(coretest.PodOptions{NodeSelector: map[string]string{v1.LabelTopologyZone: "test-zone-1a"}}) ExpectApplied(ctx, env.Client, nodePool, nodeClass, pod1) awsEnv.EC2API.CreateFleetBehavior.Error.Set(fmt.Errorf("CreateFleet synthetic error")) @@ -986,6 +1001,7 @@ var _ = Describe("CloudProvider", func() { }}) nodeClass.Spec.SubnetSelectorTerms = []v1beta1.SubnetSelectorTerm{{Tags: map[string]string{"Name": "test-subnet-1"}}} ExpectApplied(ctx, env.Client, nodePool, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) podSubnet1 := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, prov, podSubnet1) ExpectScheduled(ctx, env.Client, podSubnet1) @@ -1018,6 +1034,7 @@ var _ = Describe("CloudProvider", func() { }, }) ExpectApplied(ctx, env.Client, nodePool2, nodeClass2) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass2)).To(BeNil()) podSubnet2 := coretest.UnschedulablePod(coretest.PodOptions{NodeSelector: map[string]string{corev1beta1.NodePoolLabelKey: nodePool2.Name}}) ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, prov, podSubnet2) ExpectScheduled(ctx, env.Client, podSubnet2) @@ -1059,6 +1076,7 @@ var _ = Describe("CloudProvider", func() { }, }) ExpectApplied(ctx, env.Client, nodePool, nodePool2, nodeClass, misconfiguredNodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, prov, pod) ExpectScheduled(ctx, env.Client, pod) diff --git a/pkg/controllers/controllers.go b/pkg/controllers/controllers.go index e4c8c9388829..1a24d95e79b2 100644 --- a/pkg/controllers/controllers.go +++ b/pkg/controllers/controllers.go @@ -23,6 +23,7 @@ import ( nodeclassstatus "github.com/aws/karpenter-provider-aws/pkg/controllers/nodeclass/status" nodeclasstermination "github.com/aws/karpenter-provider-aws/pkg/controllers/nodeclass/termination" controllerspricing "github.com/aws/karpenter-provider-aws/pkg/controllers/providers/pricing" + controllerssubnets "github.com/aws/karpenter-provider-aws/pkg/controllers/providers/subnet" "github.com/aws/karpenter-provider-aws/pkg/providers/launchtemplate" "github.com/aws/aws-sdk-go/aws/session" @@ -60,6 +61,7 @@ func NewControllers(ctx context.Context, sess *session.Session, clk clock.Clock, nodeclaimgarbagecollection.NewController(kubeClient, cloudProvider), nodeclaimtagging.NewController(kubeClient, instanceProvider), controllerspricing.NewController(pricingProvider), + controllerssubnets.NewController(kubeClient, subnetProvider), } if options.FromContext(ctx).InterruptionQueue != "" { sqsapi := servicesqs.New(sess) diff --git a/pkg/controllers/nodeclass/status/subnet_test.go b/pkg/controllers/nodeclass/status/subnet_test.go index ba37f1d3b11a..050b3b59367f 100644 --- a/pkg/controllers/nodeclass/status/subnet_test.go +++ b/pkg/controllers/nodeclass/status/subnet_test.go @@ -53,6 +53,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }) It("Should update EC2NodeClass status for Subnets", func() { ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -81,6 +82,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { {SubnetId: aws.String("subnet-test3"), AvailabilityZone: aws.String("test-zone-1c"), AvailableIpAddressCount: aws.Int64(50)}, }}) ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -108,6 +110,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }, } ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -128,6 +131,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }, } ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -139,6 +143,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }) It("Should update Subnet status when the Subnet selector gets updated by tags", func() { ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -173,6 +178,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }, } ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -188,6 +194,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }) It("Should update Subnet status when the Subnet selector gets updated by ids", func() { ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -215,6 +222,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }, } ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -231,12 +239,14 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }, } ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileFailed(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(BeNil()) }) It("Should not resolve a invalid selectors for an updated subnet selector", func() { ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(Equal([]v1beta1.Subnet{ @@ -264,6 +274,7 @@ var _ = Describe("NodeClass Subnet Status Controller", func() { }, } ExpectApplied(ctx, env.Client, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) ExpectReconcileFailed(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) nodeClass = ExpectExists(ctx, env.Client, nodeClass) Expect(nodeClass.Status.Subnets).To(BeNil()) diff --git a/pkg/controllers/nodeclass/status/suite_test.go b/pkg/controllers/nodeclass/status/suite_test.go index 545c9ba6cfd5..a7f97fd9a88b 100644 --- a/pkg/controllers/nodeclass/status/suite_test.go +++ b/pkg/controllers/nodeclass/status/suite_test.go @@ -73,6 +73,7 @@ var _ = BeforeEach(func() { ctx = coreoptions.ToContext(ctx, coretest.Options()) nodeClass = test.EC2NodeClass() awsEnv.Reset() + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) }) var _ = AfterEach(func() { diff --git a/pkg/controllers/providers/subnet/controller.go b/pkg/controllers/providers/subnet/controller.go new file mode 100644 index 000000000000..765941e1ce83 --- /dev/null +++ b/pkg/controllers/providers/subnet/controller.go @@ -0,0 +1,60 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package subnet + +import ( + "context" + "time" + + controllerruntime "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/manager" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + corecontroller "sigs.k8s.io/karpenter/pkg/operator/controller" + + "sigs.k8s.io/controller-runtime/pkg/controller" + + "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1" + "github.com/aws/karpenter-provider-aws/pkg/providers/subnet" +) + +var _ corecontroller.TypedController[*v1beta1.EC2NodeClass] = (*Controller)(nil) + +type Controller struct { + subnetProvider subnet.Provider +} + +func NewController(kubeClient client.Client, subnetProvider subnet.Provider) corecontroller.Controller { + return corecontroller.Typed[*v1beta1.EC2NodeClass](kubeClient, &Controller{ + subnetProvider: subnetProvider, + }) +} + +func (c *Controller) Reconcile(ctx context.Context, nodeClass *v1beta1.EC2NodeClass) (reconcile.Result, error) { + return reconcile.Result{RequeueAfter: 5 * time.Minute}, c.subnetProvider.UpdateSubnets(ctx, nodeClass) +} + +func (c *Controller) Name() string { + return "subnet" +} + +func (c *Controller) Builder(_ context.Context, m manager.Manager) corecontroller.Builder { + return corecontroller.Adapt(controllerruntime. + NewControllerManagedBy(m). + For(&v1beta1.EC2NodeClass{}). + WithOptions(controller.Options{ + MaxConcurrentReconciles: 10, + })) +} diff --git a/pkg/controllers/providers/subnet/suite_test.go b/pkg/controllers/providers/subnet/suite_test.go new file mode 100644 index 000000000000..3b2f10bd7c9a --- /dev/null +++ b/pkg/controllers/providers/subnet/suite_test.go @@ -0,0 +1,102 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package subnet_test + +import ( + "context" + "reflect" + "sort" + "testing" + + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/samber/lo" + "sigs.k8s.io/controller-runtime/pkg/client" + coreoptions "sigs.k8s.io/karpenter/pkg/operator/options" + "sigs.k8s.io/karpenter/pkg/operator/scheme" + coretest "sigs.k8s.io/karpenter/pkg/test" + + corecontroller "sigs.k8s.io/karpenter/pkg/operator/controller" + + "github.com/aws/karpenter-provider-aws/pkg/apis" + "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1" + controllerssubnets "github.com/aws/karpenter-provider-aws/pkg/controllers/providers/subnet" + "github.com/aws/karpenter-provider-aws/pkg/operator/options" + "github.com/aws/karpenter-provider-aws/pkg/test" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + . "knative.dev/pkg/logging/testing" + . "sigs.k8s.io/karpenter/pkg/test/expectations" +) + +var ctx context.Context +var stop context.CancelFunc +var env *coretest.Environment +var awsEnv *test.Environment +var nodeClass *v1beta1.EC2NodeClass +var controller corecontroller.Controller + +func TestAWS(t *testing.T) { + ctx = TestContextWithLogger(t) + RegisterFailHandler(Fail) + RunSpecs(t, "Subnet") +} + +var _ = BeforeSuite(func() { + env = coretest.NewEnvironment(scheme.Scheme, coretest.WithCRDs(apis.CRDs...), coretest.WithFieldIndexers(test.EC2NodeClassFieldIndexer(ctx))) + ctx = coreoptions.ToContext(ctx, coretest.Options()) + ctx = options.ToContext(ctx, test.Options()) + ctx, stop = context.WithCancel(ctx) + awsEnv = test.NewEnvironment(ctx, env) + controller = controllerssubnets.NewController(env.Client, awsEnv.SubnetProvider) + nodeClass = test.EC2NodeClass() +}) + +var _ = AfterEach(func() { + ExpectCleanedUp(ctx, env.Client) +}) + +var _ = Describe("Subnet", func() { + It("should update subnet date with response from the DescribeSubnets API", func() { + awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{}) + ExpectApplied(ctx, env.Client, nodeClass) + ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeClass)) + subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) + Expect(err).To(BeNil()) + Expect(subnets).To(HaveLen(0)) + + awsEnv.Reset() + expectedSubnets, err := awsEnv.EC2API.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{ + Filters: []*ec2.Filter{ + { + Name: lo.ToPtr("tag-key"), + Values: []*string{lo.ToPtr("*")}, + }, + }, + }) + Expect(err).To(BeNil()) + + ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeClass)) + subnets, err = awsEnv.SubnetProvider.List(ctx, nodeClass) + sort.Slice(subnets, func(i, j int) bool { + return lo.FromPtr(subnets[i].AvailabilityZone) < lo.FromPtr(subnets[j].AvailabilityZone) + }) + sort.Slice(expectedSubnets.Subnets, func(i, j int) bool { + return lo.FromPtr(expectedSubnets.Subnets[i].AvailabilityZone) < lo.FromPtr(expectedSubnets.Subnets[j].AvailabilityZone) + }) + Expect(err).To(BeNil()) + Expect(reflect.DeepEqual(subnets, expectedSubnets.Subnets)).To(BeTrue()) + }) +}) diff --git a/pkg/operator/operator.go b/pkg/operator/operator.go index 39be55788322..069fd469212e 100644 --- a/pkg/operator/operator.go +++ b/pkg/operator/operator.go @@ -132,7 +132,7 @@ func NewOperator(ctx context.Context, operator *operator.Operator) (context.Cont } unavailableOfferingsCache := awscache.NewUnavailableOfferings() - subnetProvider := subnet.NewDefaultProvider(ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval)) + subnetProvider := subnet.NewDefaultProvider(ec2api, cache.New(awscache.SubnetTTL, 0)) securityGroupProvider := securitygroup.NewDefaultProvider(ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval)) instanceProfileProvider := instanceprofile.NewDefaultProvider(*sess.Config.Region, iam.New(sess), cache.New(awscache.InstanceProfileTTL, awscache.DefaultCleanupInterval)) pricingProvider := pricing.NewDefaultProvider( diff --git a/pkg/providers/instance/suite_test.go b/pkg/providers/instance/suite_test.go index 45d3504bafb6..a4f9682d38ca 100644 --- a/pkg/providers/instance/suite_test.go +++ b/pkg/providers/instance/suite_test.go @@ -115,6 +115,7 @@ var _ = Describe("InstanceProvider", func() { {CapacityType: corev1beta1.CapacityTypeSpot, InstanceType: "m5.xlarge", Zone: "test-zone-1a"}, {CapacityType: corev1beta1.CapacityTypeSpot, InstanceType: "m5.xlarge", Zone: "test-zone-1b"}, }) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) instanceTypes, err := cloudProvider.GetInstanceTypes(ctx, nodePool) Expect(err).ToNot(HaveOccurred()) diff --git a/pkg/providers/instancetype/suite_test.go b/pkg/providers/instancetype/suite_test.go index 6b0d284d725b..0386569498e0 100644 --- a/pkg/providers/instancetype/suite_test.go +++ b/pkg/providers/instancetype/suite_test.go @@ -155,6 +155,8 @@ var _ = Describe("InstanceTypeProvider", func() { }, }, }) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, windowsNodeClass)).To(BeNil()) }) It("should support individual instance type labels", func() { diff --git a/pkg/providers/launchtemplate/suite_test.go b/pkg/providers/launchtemplate/suite_test.go index 17449addc88f..9d89f093e387 100644 --- a/pkg/providers/launchtemplate/suite_test.go +++ b/pkg/providers/launchtemplate/suite_test.go @@ -146,6 +146,7 @@ var _ = Describe("LaunchTemplate Provider", func() { }, }, }) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) }) It("should create unique launch templates for multiple identical nodeClasses", func() { nodeClass2 := test.EC2NodeClass() @@ -1949,6 +1950,7 @@ var _ = Describe("LaunchTemplate Provider", func() { {Tags: map[string]string{"Name": "test-subnet-3"}}, } ExpectApplied(ctx, env.Client, nodePool, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, prov, pod) ExpectScheduled(ctx, env.Client, pod) @@ -1960,6 +1962,7 @@ var _ = Describe("LaunchTemplate Provider", func() { {Tags: map[string]string{"Name": "test-subnet-2"}}, } ExpectApplied(ctx, env.Client, nodePool, nodeClass) + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, cloudProvider, prov, pod) ExpectScheduled(ctx, env.Client, pod) diff --git a/pkg/providers/subnet/subnet.go b/pkg/providers/subnet/subnet.go index 37af40718b25..44493776548b 100644 --- a/pkg/providers/subnet/subnet.go +++ b/pkg/providers/subnet/subnet.go @@ -41,6 +41,7 @@ type Provider interface { CheckAnyPublicIPAssociations(context.Context, *v1beta1.EC2NodeClass) (bool, error) ZonalSubnetsForLaunch(context.Context, *v1beta1.EC2NodeClass, []*cloudprovider.InstanceType, string) (map[string]*ec2.Subnet, error) UpdateInflightIPs(*ec2.CreateFleetInput, *ec2.CreateFleetOutput, []*cloudprovider.InstanceType, []*ec2.Subnet, string) + UpdateSubnets(context.Context, *v1beta1.EC2NodeClass) error } type DefaultProvider struct { @@ -64,26 +65,42 @@ func NewDefaultProvider(ec2api ec2iface.EC2API, cache *cache.Cache) *DefaultProv } func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeClass) ([]*ec2.Subnet, error) { - p.Lock() - defer p.Unlock() + p.RLock() + defer p.RUnlock() + filterSets := getFilterSets(nodeClass.Spec.SubnetSelectorTerms) if len(filterSets) == 0 { return []*ec2.Subnet{}, nil } - hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) + hash, err := subnetHash(filterSets) if err != nil { - return nil, err + return []*ec2.Subnet{}, err } if subnets, ok := p.cache.Get(fmt.Sprint(hash)); ok { return subnets.([]*ec2.Subnet), nil } + return []*ec2.Subnet{}, nil +} + +func (p *DefaultProvider) UpdateSubnets(ctx context.Context, nodeClass *v1beta1.EC2NodeClass) error { + p.Lock() + defer p.Unlock() + + filterSets := getFilterSets(nodeClass.Spec.SubnetSelectorTerms) + if len(filterSets) == 0 { + return nil + } + hash, err := subnetHash(filterSets) + if err != nil { + return err + } // Ensure that all the subnets that are returned here are unique subnets := map[string]*ec2.Subnet{} for _, filters := range filterSets { output, err := p.ec2api.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{Filters: filters}) if err != nil { - return nil, fmt.Errorf("describing subnets %s, %w", pretty.Concise(filters), err) + return fmt.Errorf("describing subnets %s, %w", pretty.Concise(filters), err) } for i := range output.Subnets { subnets[lo.FromPtr(output.Subnets[i].SubnetId)] = output.Subnets[i] @@ -91,6 +108,7 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeCl } } p.cache.SetDefault(fmt.Sprint(hash), lo.Values(subnets)) + p.cache.DeleteExpired() // delete expired after we have successfully updated the cache if p.cm.HasChanged(fmt.Sprintf("subnets/%s", nodeClass.Name), subnets) { logging.FromContext(ctx). With("subnets", lo.Map(lo.Values(subnets), func(s *ec2.Subnet, _ int) string { @@ -98,7 +116,8 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeCl })). Debugf("discovered subnets") } - return lo.Values(subnets), nil + + return nil } // CheckAnyPublicIPAssociations returns a bool indicating whether all referenced subnets assign public IPv4 addresses to EC2 instances created therein @@ -266,3 +285,11 @@ func getFilterSets(terms []v1beta1.SubnetSelectorTerm) (res [][]*ec2.Filter) { } return res } + +func subnetHash(filterSets [][]*ec2.Filter) (string, error) { + hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true}) + if err != nil { + return "", err + } + return fmt.Sprintf("%d", hash), nil +} diff --git a/pkg/providers/subnet/suite_test.go b/pkg/providers/subnet/suite_test.go index 08b7354f7924..135c867caaba 100644 --- a/pkg/providers/subnet/suite_test.go +++ b/pkg/providers/subnet/suite_test.go @@ -99,6 +99,7 @@ var _ = Describe("SubnetProvider", func() { ID: "subnet-test1", }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) ExpectConsistsOfSubnets([]*ec2.Subnet{ @@ -118,6 +119,7 @@ var _ = Describe("SubnetProvider", func() { ID: "subnet-test2", }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) ExpectConsistsOfSubnets([]*ec2.Subnet{ @@ -144,6 +146,7 @@ var _ = Describe("SubnetProvider", func() { Tags: map[string]string{"foo": "bar"}, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) ExpectConsistsOfSubnets([]*ec2.Subnet{ @@ -165,6 +168,7 @@ var _ = Describe("SubnetProvider", func() { Tags: map[string]string{"Name": "test-subnet-1"}, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) ExpectConsistsOfSubnets([]*ec2.Subnet{ @@ -184,6 +188,7 @@ var _ = Describe("SubnetProvider", func() { Tags: map[string]string{"Name": "test-subnet-2"}, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) ExpectConsistsOfSubnets([]*ec2.Subnet{ @@ -206,6 +211,7 @@ var _ = Describe("SubnetProvider", func() { Tags: map[string]string{"foo": "bar"}, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) ExpectConsistsOfSubnets([]*ec2.Subnet{ @@ -225,6 +231,7 @@ var _ = Describe("SubnetProvider", func() { Tags: map[string]string{"foo": "bar"}, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) onlyPrivate, err := awsEnv.SubnetProvider.CheckAnyPublicIPAssociations(ctx, nodeClass) Expect(err).To(BeNil()) Expect(onlyPrivate).To(BeFalse()) @@ -235,6 +242,7 @@ var _ = Describe("SubnetProvider", func() { ID: "subnet-test2", }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) onlyPrivate, err := awsEnv.SubnetProvider.CheckAnyPublicIPAssociations(ctx, nodeClass) Expect(err).To(BeNil()) Expect(onlyPrivate).To(BeTrue()) @@ -242,13 +250,22 @@ var _ = Describe("SubnetProvider", func() { }) Context("Provider Cache", func() { It("should resolve subnets from cache that are filtered by id", func() { - expectedSubnets := awsEnv.EC2API.DescribeSubnetsOutput.Clone().Subnets - for _, subnet := range expectedSubnets { + expectedSubnets, err := awsEnv.EC2API.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{ + Filters: []*ec2.Filter{ + { + Name: lo.ToPtr("tag-key"), + Values: []*string{lo.ToPtr("*")}, + }, + }, + }) + Expect(err).To(BeNil()) + for _, subnet := range expectedSubnets.Subnets { nodeClass.Spec.SubnetSelectorTerms = []v1beta1.SubnetSelectorTerm{ { ID: *subnet.SubnetId, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) // Call list to request from aws and store in the cache _, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) @@ -257,12 +274,20 @@ var _ = Describe("SubnetProvider", func() { for _, cachedObject := range awsEnv.SubnetCache.Items() { cachedSubnet := cachedObject.Object.([]*ec2.Subnet) Expect(cachedSubnet).To(HaveLen(1)) - lo.Contains(expectedSubnets, cachedSubnet[0]) + lo.Contains(expectedSubnets.Subnets, cachedSubnet[0]) } }) It("should resolve subnets from cache that are filtered by tags", func() { - expectedSubnets := awsEnv.EC2API.DescribeSubnetsOutput.Clone().Subnets - tagSet := lo.Map(expectedSubnets, func(subnet *ec2.Subnet, _ int) map[string]string { + expectedSubnets, err := awsEnv.EC2API.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{ + Filters: []*ec2.Filter{ + { + Name: lo.ToPtr("tag-key"), + Values: []*string{lo.ToPtr("*")}, + }, + }, + }) + Expect(err).To(BeNil()) + tagSet := lo.Map(expectedSubnets.Subnets, func(subnet *ec2.Subnet, _ int) map[string]string { tag, _ := lo.Find(subnet.Tags, func(tag *ec2.Tag) bool { return lo.FromPtr(tag.Key) == "Name" }) @@ -274,6 +299,7 @@ var _ = Describe("SubnetProvider", func() { Tags: tag, }, } + Expect(awsEnv.SubnetProvider.UpdateSubnets(ctx, nodeClass)).To(BeNil()) // Call list to request from aws and store in the cache _, err := awsEnv.SubnetProvider.List(ctx, nodeClass) Expect(err).To(BeNil()) @@ -282,7 +308,7 @@ var _ = Describe("SubnetProvider", func() { for _, cachedObject := range awsEnv.SubnetCache.Items() { cachedSubnet := cachedObject.Object.([]*ec2.Subnet) Expect(cachedSubnet).To(HaveLen(1)) - lo.Contains(expectedSubnets, cachedSubnet[0]) + lo.Contains(expectedSubnets.Subnets, cachedSubnet[0]) } }) })