diff --git a/pkg/cloudprovider/drift.go b/pkg/cloudprovider/drift.go index 69dc5aad1783..7c743ec6638a 100644 --- a/pkg/cloudprovider/drift.go +++ b/pkg/cloudprovider/drift.go @@ -54,7 +54,7 @@ func (c *CloudProvider) isNodeClassDrifted(ctx context.Context, nodeClaim *corev if err != nil { return "", fmt.Errorf("calculating ami drift, %w", err) } - securitygroupDrifted, err := c.areSecurityGroupsDrifted(ctx, instance, nodeClass) + securitygroupDrifted, err := c.areSecurityGroupsDrifted(instance, nodeClass) if err != nil { return "", fmt.Errorf("calculating securitygroup drift, %w", err) } @@ -118,12 +118,8 @@ func (c *CloudProvider) isSubnetDrifted(ctx context.Context, instance *instance. // Checks if the security groups are drifted, by comparing the security groups returned from the SecurityGroupProvider // to the ec2 instance security groups -func (c *CloudProvider) areSecurityGroupsDrifted(ctx context.Context, ec2Instance *instance.Instance, nodeClass *v1beta1.EC2NodeClass) (cloudprovider.DriftReason, error) { - securitygroup, err := c.securityGroupProvider.List(ctx, nodeClass) - if err != nil { - return "", err - } - securityGroupIds := sets.New(lo.Map(securitygroup, func(sg *ec2.SecurityGroup, _ int) string { return aws.StringValue(sg.GroupId) })...) +func (c *CloudProvider) areSecurityGroupsDrifted(ec2Instance *instance.Instance, nodeClass *v1beta1.EC2NodeClass) (cloudprovider.DriftReason, error) { + securityGroupIds := sets.New(lo.Map(nodeClass.Status.SecurityGroups, func(sg v1beta1.SecurityGroup, _ int) string { return sg.ID })...) if len(securityGroupIds) == 0 { return "", fmt.Errorf("no security groups are discovered") } diff --git a/pkg/cloudprovider/suite_test.go b/pkg/cloudprovider/suite_test.go index b9d255ed3808..7298573b30f5 100644 --- a/pkg/cloudprovider/suite_test.go +++ b/pkg/cloudprovider/suite_test.go @@ -39,15 +39,19 @@ import ( "github.com/aws/karpenter-provider-aws/pkg/apis" "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1" "github.com/aws/karpenter-provider-aws/pkg/cloudprovider" + "github.com/aws/karpenter-provider-aws/pkg/controllers/nodeclass/status" "github.com/aws/karpenter-provider-aws/pkg/fake" "github.com/aws/karpenter-provider-aws/pkg/operator/options" "github.com/aws/karpenter-provider-aws/pkg/test" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" corev1beta1 "sigs.k8s.io/karpenter/pkg/apis/v1beta1" corecloudproivder "sigs.k8s.io/karpenter/pkg/cloudprovider" "sigs.k8s.io/karpenter/pkg/controllers/provisioning" "sigs.k8s.io/karpenter/pkg/controllers/state" "sigs.k8s.io/karpenter/pkg/events" + "sigs.k8s.io/karpenter/pkg/operator/controller" coreoptions "sigs.k8s.io/karpenter/pkg/operator/options" "sigs.k8s.io/karpenter/pkg/operator/scheme" coretest "sigs.k8s.io/karpenter/pkg/test" @@ -67,6 +71,7 @@ var cloudProvider *cloudprovider.CloudProvider var cluster *state.Cluster var fakeClock *clock.FakeClock var recorder events.Recorder +var statusController controller.Controller func TestAWS(t *testing.T) { ctx = TestContextWithLogger(t) @@ -86,6 +91,14 @@ var _ = BeforeSuite(func() { env.Client, awsEnv.AMIProvider, awsEnv.SecurityGroupProvider, awsEnv.SubnetProvider) cluster = state.NewCluster(fakeClock, env.Client, cloudProvider) prov = provisioning.NewProvisioner(env.Client, recorder, cloudProvider, cluster) + statusController = status.NewController( + env.Client, + awsEnv.SubnetProvider, + awsEnv.SecurityGroupProvider, + awsEnv.AMIProvider, + awsEnv.InstanceProfileProvider, + awsEnv.LaunchTemplateProvider, + ) }) var _ = AfterSuite(func() { @@ -565,8 +578,10 @@ var _ = Describe("CloudProvider", func() { awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{ Subnets: []*ec2.Subnet{ { - SubnetId: aws.String(validSubnet1), - AvailabilityZone: aws.String("zone-1"), + SubnetId: aws.String(validSubnet1), + AvailabilityZone: aws.String("zone-1"), + AvailableIpAddressCount: aws.Int64(100), + MapPublicIpOnLaunch: aws.Bool(false), Tags: []*ec2.Tag{ { Key: aws.String("sn-key-1"), @@ -575,8 +590,10 @@ var _ = Describe("CloudProvider", func() { }, }, { - SubnetId: aws.String(validSubnet2), - AvailabilityZone: aws.String("zone-2"), + SubnetId: aws.String(validSubnet2), + AvailabilityZone: aws.String("zone-2"), + AvailableIpAddressCount: aws.Int64(100), + MapPublicIpOnLaunch: aws.Bool(false), Tags: []*ec2.Tag{ { Key: aws.String("sn-key-2"), @@ -587,6 +604,7 @@ var _ = Describe("CloudProvider", func() { }, }) ExpectApplied(ctx, env.Client, nodePool, nodeClass) + ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) instanceTypes, err := cloudProvider.GetInstanceTypes(ctx, nodePool) Expect(err).ToNot(HaveOccurred()) selectedInstanceType = instanceTypes[0] @@ -621,6 +639,8 @@ var _ = Describe("CloudProvider", func() { nodeClaim.Labels = lo.Assign(nodeClaim.Labels, map[string]string{v1.LabelInstanceTypeStable: selectedInstanceType.Name}) }) It("should not fail if NodeClass does not exist", func() { + controllerutil.RemoveFinalizer(nodeClass, v1beta1.TerminationFinalizer) + ExpectApplied(ctx, env.Client, nodeClass) ExpectDeleted(ctx, env.Client, nodeClass) drifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).ToNot(HaveOccurred()) @@ -674,6 +694,8 @@ var _ = Describe("CloudProvider", func() { awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{}}) // Instance is a reference to what we return in the GetInstances call instance.SecurityGroups = []*ec2.GroupIdentifier{{GroupId: aws.String(fake.SecurityGroupID())}} + awsEnv.SecurityGroupCache.Flush() + ExpectReconcileFailed(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) _, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).To(HaveOccurred()) }) @@ -704,6 +726,8 @@ var _ = Describe("CloudProvider", func() { }, }, }) + awsEnv.SecurityGroupCache.Flush() + ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).ToNot(HaveOccurred()) Expect(isDrifted).To(Equal(cloudprovider.SecurityGroupDrift)) @@ -734,6 +758,7 @@ var _ = Describe("CloudProvider", func() { It("should return drifted if the AMI no longer matches the existing NodeClaims instance type", func() { nodeClass.Spec.AMISelectorTerms = []v1beta1.AMISelectorTerm{{ID: amdAMIID}} ExpectApplied(ctx, env.Client, nodeClass) + ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).ToNot(HaveOccurred()) Expect(isDrifted).To(Equal(cloudprovider.AMIDrift)) @@ -784,6 +809,7 @@ var _ = Describe("CloudProvider", func() { DescribeTable("should return drifted if a statically drifted EC2NodeClass.Spec field is updated", func(changes v1beta1.EC2NodeClass) { ExpectApplied(ctx, env.Client, nodePool, nodeClass) + ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -792,6 +818,7 @@ var _ = Describe("CloudProvider", func() { nodeClass.Annotations = lo.Assign(nodeClass.Annotations, map[string]string{v1beta1.AnnotationEC2NodeClassHash: nodeClass.Hash()}) ExpectApplied(ctx, env.Client, nodeClass) + ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) isDrifted, err = cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(Equal(cloudprovider.NodeClassDrift)) @@ -831,6 +858,7 @@ var _ = Describe("CloudProvider", func() { DescribeTable("should not return drifted if dynamic fields are updated", func(changes v1beta1.EC2NodeClass) { ExpectApplied(ctx, env.Client, nodePool, nodeClass) + ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -839,6 +867,7 @@ var _ = Describe("CloudProvider", func() { nodeClass.Annotations = lo.Assign(nodeClass.Annotations, map[string]string{v1beta1.AnnotationEC2NodeClassHash: nodeClass.Hash()}) ExpectApplied(ctx, env.Client, nodeClass) + ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) isDrifted, err = cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -855,6 +884,7 @@ var _ = Describe("CloudProvider", func() { "Test Key": "Test Value", } ExpectApplied(ctx, env.Client, nodePool, nodeClass) + ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -869,6 +899,7 @@ var _ = Describe("CloudProvider", func() { v1beta1.AnnotationEC2NodeClassHashVersion: "test-hash-version-2", } ExpectApplied(ctx, env.Client, nodePool, nodeClass) + ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -886,6 +917,7 @@ var _ = Describe("CloudProvider", func() { "Test Key": "Test Value", } ExpectApplied(ctx, env.Client, nodePool, nodeClass) + ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) @@ -903,6 +935,7 @@ var _ = Describe("CloudProvider", func() { "Test Key": "Test Value", } ExpectApplied(ctx, env.Client, nodePool, nodeClass) + ExpectReconcileSucceeded(ctx, statusController, client.ObjectKeyFromObject(nodeClass)) isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim) Expect(err).NotTo(HaveOccurred()) Expect(isDrifted).To(BeEmpty()) diff --git a/pkg/providers/launchtemplate/launchtemplate.go b/pkg/providers/launchtemplate/launchtemplate.go index f7d692c6f0de..7a30609887c4 100644 --- a/pkg/providers/launchtemplate/launchtemplate.go +++ b/pkg/providers/launchtemplate/launchtemplate.go @@ -177,14 +177,12 @@ func (p *DefaultProvider) createAMIOptions(ctx context.Context, nodeClass *v1bet ClusterCIDR: p.ClusterCIDR.Load(), InstanceProfile: instanceProfile, InstanceStorePolicy: nodeClass.Spec.InstanceStorePolicy, - SecurityGroups: lo.Map(securityGroups, func(s *ec2.SecurityGroup, _ int) v1beta1.SecurityGroup { - return v1beta1.SecurityGroup{ID: aws.StringValue(s.GroupId), Name: aws.StringValue(s.GroupName)} - }), - Tags: tags, - Labels: labels, - CABundle: p.CABundle, - KubeDNSIP: p.KubeDNSIP, - NodeClassName: nodeClass.Name, + SecurityGroups: nodeClass.Status.SecurityGroups, + Tags: tags, + Labels: labels, + CABundle: p.CABundle, + KubeDNSIP: p.KubeDNSIP, + NodeClassName: nodeClass.Name, } if nodeClass.Spec.AssociatePublicIPAddress != nil { options.AssociatePublicIPAddress = nodeClass.Spec.AssociatePublicIPAddress