diff --git a/pkg/cloudprovider/cloudprovider.go b/pkg/cloudprovider/cloudprovider.go index 82c7f31d1790..e75af9c5fa5f 100644 --- a/pkg/cloudprovider/cloudprovider.go +++ b/pkg/cloudprovider/cloudprovider.go @@ -189,22 +189,22 @@ func (c *CloudProvider) Delete(ctx context.Context, machine *v1alpha5.Machine) e return c.instanceProvider.Delete(ctx, id) } -func (c *CloudProvider) IsMachineDrifted(ctx context.Context, machine *v1alpha5.Machine) (bool, error) { +func (c *CloudProvider) IsMachineDrifted(ctx context.Context, machine *v1alpha5.Machine) (cloudprovider.DriftReason, error) { // Not needed when GetInstanceTypes removes provisioner dependency provisioner := &v1alpha5.Provisioner{} if err := c.kubeClient.Get(ctx, types.NamespacedName{Name: machine.Labels[v1alpha5.ProvisionerNameLabelKey]}, provisioner); err != nil { - return false, client.IgnoreNotFound(fmt.Errorf("getting provisioner, %w", err)) + return "", client.IgnoreNotFound(fmt.Errorf("getting provisioner, %w", err)) } if provisioner.Spec.ProviderRef == nil { - return false, nil + return "", nil } nodeTemplate, err := c.resolveNodeTemplate(ctx, nil, provisioner.Spec.ProviderRef) if err != nil { - return false, client.IgnoreNotFound(fmt.Errorf("resolving node template, %w", err)) + return "", client.IgnoreNotFound(fmt.Errorf("resolving node template, %w", err)) } drifted, err := c.isNodeTemplateDrifted(ctx, machine, provisioner, nodeTemplate) if err != nil { - return false, err + return "", err } return drifted, nil } diff --git a/pkg/cloudprovider/drift.go b/pkg/cloudprovider/drift.go index 02afdf7a401f..5f081611933d 100644 --- a/pkg/cloudprovider/drift.go +++ b/pkg/cloudprovider/drift.go @@ -30,90 +30,111 @@ import ( "github.com/aws/karpenter/pkg/utils" ) -func (c *CloudProvider) isNodeTemplateDrifted(ctx context.Context, machine *v1alpha5.Machine, provisioner *v1alpha5.Provisioner, nodeTemplate *v1alpha1.AWSNodeTemplate) (bool, error) { +const ( + StaticDrift cloudprovider.DriftReason = "StaticDrift" + AMIDrift cloudprovider.DriftReason = "AMIDrift" + SubnetDrift cloudprovider.DriftReason = "SubnetDrift" + SecurityGroupDrift cloudprovider.DriftReason = "SecurityGroupDrift" +) + +func (c *CloudProvider) isNodeTemplateDrifted(ctx context.Context, machine *v1alpha5.Machine, provisioner *v1alpha5.Provisioner, nodeTemplate *v1alpha1.AWSNodeTemplate) (cloudprovider.DriftReason, error) { instance, err := c.getInstance(ctx, machine.Status.ProviderID) if err != nil { - return false, err + return "", err } amiDrifted, err := c.isAMIDrifted(ctx, machine, provisioner, instance, nodeTemplate) if err != nil { - return false, fmt.Errorf("calculating ami drift, %w", err) + return "", fmt.Errorf("calculating ami drift, %w", err) } securitygroupDrifted, err := c.areSecurityGroupsDrifted(instance, nodeTemplate) if err != nil { - return false, fmt.Errorf("calculating securitygroup drift, %w", err) + return "", fmt.Errorf("calculating securitygroup drift, %w", err) } subnetDrifted, err := c.isSubnetDrifted(instance, nodeTemplate) if err != nil { - return false, fmt.Errorf("calculating subnet drift, %w", err) + return "", fmt.Errorf("calculating subnet drift, %w", err) } - - return amiDrifted || securitygroupDrifted || subnetDrifted || c.areStaticFieldsDrifted(machine, nodeTemplate), nil + drifted := lo.FindOrElse([]cloudprovider.DriftReason{amiDrifted, securitygroupDrifted, subnetDrifted, c.areStaticFieldsDrifted(machine, nodeTemplate)}, "", func(i cloudprovider.DriftReason) bool { + return string(i) != "" + }) + return drifted, nil } func (c *CloudProvider) isAMIDrifted(ctx context.Context, machine *v1alpha5.Machine, provisioner *v1alpha5.Provisioner, - instance *instance.Instance, nodeTemplate *v1alpha1.AWSNodeTemplate) (bool, error) { + instance *instance.Instance, nodeTemplate *v1alpha1.AWSNodeTemplate) (cloudprovider.DriftReason, error) { instanceTypes, err := c.GetInstanceTypes(ctx, provisioner) if err != nil { - return false, fmt.Errorf("getting instanceTypes, %w", err) + return "", fmt.Errorf("getting instanceTypes, %w", err) } nodeInstanceType, found := lo.Find(instanceTypes, func(instType *cloudprovider.InstanceType) bool { return instType.Name == machine.Labels[v1.LabelInstanceTypeStable] }) if !found { - return false, fmt.Errorf(`finding node instance type "%s"`, machine.Labels[v1.LabelInstanceTypeStable]) + return "", fmt.Errorf(`finding node instance type "%s"`, machine.Labels[v1.LabelInstanceTypeStable]) } if nodeTemplate.Spec.LaunchTemplateName != nil { - return false, nil + return "", nil } amis, err := c.amiProvider.Get(ctx, nodeTemplate, &amifamily.Options{}) if err != nil { - return false, fmt.Errorf("getting amis, %w", err) + return "", fmt.Errorf("getting amis, %w", err) } if len(amis) == 0 { - return false, fmt.Errorf("no amis exist given constraints") + return "", fmt.Errorf("no amis exist given constraints") } mappedAMIs := amifamily.MapInstanceTypes(amis, []*cloudprovider.InstanceType{nodeInstanceType}) if len(mappedAMIs) == 0 { - return false, fmt.Errorf("no instance types satisfy requirements of amis %v,", amis) + return "", fmt.Errorf("no instance types satisfy requirements of amis %v,", amis) + } + if !lo.Contains(lo.Keys(mappedAMIs), instance.ImageID) { + return AMIDrift, nil } - return !lo.Contains(lo.Keys(mappedAMIs), instance.ImageID), nil + return "", nil } -func (c *CloudProvider) isSubnetDrifted(instance *instance.Instance, nodeTemplate *v1alpha1.AWSNodeTemplate) (bool, error) { +func (c *CloudProvider) isSubnetDrifted(instance *instance.Instance, nodeTemplate *v1alpha1.AWSNodeTemplate) (cloudprovider.DriftReason, error) { // If the node template status does not have subnets, wait for the subnets to be populated before continuing if nodeTemplate.Status.Subnets == nil { - return false, fmt.Errorf("AWSNodeTemplate has no subnets") + return "", fmt.Errorf("AWSNodeTemplate has no subnets") } _, found := lo.Find(nodeTemplate.Status.Subnets, func(subnet v1alpha1.Subnet) bool { return subnet.ID == instance.SubnetID }) - return !found, nil + if !found { + return SubnetDrift, nil + } + return "", nil } // Checks if the security groups are drifted, by comparing the AWSNodeTemplate.Status.SecurityGroups // to the ec2 instance security groups -func (c *CloudProvider) areSecurityGroupsDrifted(ec2Instance *instance.Instance, nodeTemplate *v1alpha1.AWSNodeTemplate) (bool, error) { +func (c *CloudProvider) areSecurityGroupsDrifted(ec2Instance *instance.Instance, nodeTemplate *v1alpha1.AWSNodeTemplate) (cloudprovider.DriftReason, error) { // nodeTemplate.Spec.SecurityGroupSelector can be nil if the user is using a launchTemplateName to define SecurityGroups // Karpenter will not drift on changes to securitygroup in the launchTemplateName if nodeTemplate.Spec.LaunchTemplateName != nil { - return false, nil + return "", nil } securityGroupIds := sets.New(lo.Map(nodeTemplate.Status.SecurityGroups, func(sg v1alpha1.SecurityGroup, _ int) string { return sg.ID })...) if len(securityGroupIds) == 0 { - return false, fmt.Errorf("no security groups exist in the AWSNodeTemplate Status") + return "", fmt.Errorf("no security groups exist in the AWSNodeTemplate Status") + } + + if !securityGroupIds.Equal(sets.New(ec2Instance.SecurityGroupIDs...)) { + return SecurityGroupDrift, nil } - return !securityGroupIds.Equal(sets.New(ec2Instance.SecurityGroupIDs...)), nil + return "", nil } -func (c *CloudProvider) areStaticFieldsDrifted(machine *v1alpha5.Machine, nodeTemplate *v1alpha1.AWSNodeTemplate) bool { +func (c *CloudProvider) areStaticFieldsDrifted(machine *v1alpha5.Machine, nodeTemplate *v1alpha1.AWSNodeTemplate) cloudprovider.DriftReason { nodeTemplateHash, foundHashNodeTemplate := nodeTemplate.ObjectMeta.Annotations[v1alpha1.AnnotationNodeTemplateHash] machineHash, foundHashMachine := machine.ObjectMeta.Annotations[v1alpha1.AnnotationNodeTemplateHash] if !foundHashNodeTemplate || !foundHashMachine { - return false + return "" } - - return nodeTemplateHash != machineHash + if nodeTemplateHash != machineHash { + return StaticDrift + } + return "" } func (c *CloudProvider) getInstance(ctx context.Context, providerID string) (*instance.Instance, error) { diff --git a/pkg/cloudprovider/suite_test.go b/pkg/cloudprovider/suite_test.go index c806db39a257..1af1f9629879 100644 --- a/pkg/cloudprovider/suite_test.go +++ b/pkg/cloudprovider/suite_test.go @@ -59,6 +59,8 @@ import ( . "github.com/aws/karpenter-core/pkg/test/expectations" ) +const NoDrift corecloudproivder.DriftReason = "" + var ctx context.Context var stop context.CancelFunc var opts options.Options @@ -307,33 +309,33 @@ var _ = Describe("CloudProvider", func() { ExpectDeleted(ctx, env.Client, nodeTemplate) drifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(drifted).To(BeFalse()) + Expect(drifted).To(Equal(NoDrift)) }) It("should return false if providerRef is not defined", func() { provisioner.Spec.ProviderRef = nil ExpectApplied(ctx, env.Client, provisioner) drifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(drifted).To(BeFalse()) + Expect(drifted).To(Equal(NoDrift)) }) It("should not fail if provisioner does not exist", func() { ExpectDeleted(ctx, env.Client, provisioner) drifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(drifted).To(BeFalse()) + Expect(drifted).To(Equal(NoDrift)) }) It("should return drifted if the AMI is not valid", func() { // Instance is a reference to what we return in the GetInstances call instance.ImageId = aws.String(fake.ImageID()) isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeTrue()) + Expect(isDrifted).To(Equal(cloudprovider.AMIDrift)) }) It("should return drifted if the subnet is not valid", func() { instance.SubnetId = aws.String(fake.SubnetID()) isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeTrue()) + Expect(isDrifted).To(Equal(cloudprovider.SubnetDrift)) }) It("should return an error if AWSNodeTemplate subnets are empty", func() { nodeTemplate.Status.Subnets = []v1alpha1.Subnet{} @@ -344,7 +346,7 @@ var _ = Describe("CloudProvider", func() { It("should not return drifted if the machine is valid", func() { isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(NoDrift)) }) It("should return an error if the AWSNodeTemplate securitygroup are empty", func() { nodeTemplate.Status.SecurityGroups = []v1alpha1.SecurityGroup{} @@ -359,14 +361,14 @@ var _ = Describe("CloudProvider", func() { instance.SecurityGroups = []*ec2.GroupIdentifier{{GroupId: aws.String(fake.SecurityGroupID())}} isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeTrue()) + Expect(isDrifted).To(Equal(cloudprovider.SecurityGroupDrift)) }) It("should return drifted if there are more instance securitygroups are present than AWSNodeTemplate Status", func() { // Instance is a reference to what we return in the GetInstances call instance.SecurityGroups = []*ec2.GroupIdentifier{{GroupId: aws.String(fake.SecurityGroupID())}, {GroupId: aws.String(validSecurityGroup)}} isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeTrue()) + Expect(isDrifted).To(Equal(cloudprovider.SecurityGroupDrift)) }) It("should return drifted if more AWSNodeTemplate securitygroups are present than instance securitygroups", func() { nodeTemplate.Status.SecurityGroups = []v1alpha1.SecurityGroup{ @@ -382,7 +384,7 @@ var _ = Describe("CloudProvider", func() { ExpectApplied(ctx, env.Client, nodeTemplate) isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeTrue()) + Expect(isDrifted).To(Equal(cloudprovider.SecurityGroupDrift)) }) It("should not return drifted if launchTemplateName is defined", func() { nodeTemplate.Spec.LaunchTemplateName = aws.String("validLaunchTemplateName") @@ -390,12 +392,12 @@ var _ = Describe("CloudProvider", func() { nodeTemplate.Status.SecurityGroups = nil isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(NoDrift)) }) It("should not return drifted if the securitygroups match", func() { isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(NoDrift)) }) It("should error if the machine doesn't have the instance-type label", func() { machine.Labels = map[string]string{ @@ -408,7 +410,7 @@ var _ = Describe("CloudProvider", func() { machine.Status = v1alpha5.MachineStatus{} isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).To(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(NoDrift)) }) It("should error drift if the underlying machine does not exist", func() { awsEnv.EC2API.DescribeInstancesBehavior.Output.Set(&ec2.DescribeInstancesOutput{ @@ -431,7 +433,7 @@ var _ = Describe("CloudProvider", func() { ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).NotTo(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(NoDrift)) updatedAWSNodeTemplate := test.AWSNodeTemplate(*nodeTemplate.Spec.DeepCopy(), awsnodetemplatespec) updatedAWSNodeTemplate.ObjectMeta = nodeTemplate.ObjectMeta @@ -441,7 +443,7 @@ var _ = Describe("CloudProvider", func() { ExpectApplied(ctx, env.Client, updatedAWSNodeTemplate) isDrifted, err = cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).NotTo(HaveOccurred()) - Expect(isDrifted).To(BeTrue()) + Expect(isDrifted).To(Equal(cloudprovider.StaticDrift)) }, Entry("InstanceProfile Drift", v1alpha1.AWSNodeTemplateSpec{AWS: v1alpha1.AWS{InstanceProfile: aws.String("profile-2")}}), Entry("UserData Drift", v1alpha1.AWSNodeTemplateSpec{UserData: aws.String("userdata-test-2")}), @@ -457,7 +459,7 @@ var _ = Describe("CloudProvider", func() { ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).NotTo(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(NoDrift)) updatedAWSNodeTemplate := test.AWSNodeTemplate(*nodeTemplate.Spec.DeepCopy(), awsnodetemplatespec) updatedAWSNodeTemplate.ObjectMeta = nodeTemplate.ObjectMeta @@ -467,7 +469,7 @@ var _ = Describe("CloudProvider", func() { ExpectApplied(ctx, env.Client, updatedAWSNodeTemplate) isDrifted, err = cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).NotTo(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(NoDrift)) }, Entry("AMISelector Drift", v1alpha1.AWSNodeTemplateSpec{AMISelector: map[string]string{"aws::ids": validAMI}}), Entry("SubnetSelector Drift", v1alpha1.AWSNodeTemplateSpec{AWS: v1alpha1.AWS{SubnetSelector: map[string]string{"aws-ids": "subnet-test1"}}}), @@ -481,7 +483,7 @@ var _ = Describe("CloudProvider", func() { ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).NotTo(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(NoDrift)) }) }) }) diff --git a/pkg/fake/cloudprovider.go b/pkg/fake/cloudprovider.go index 51b12e5b77da..e4204dad3c16 100644 --- a/pkg/fake/cloudprovider.go +++ b/pkg/fake/cloudprovider.go @@ -20,6 +20,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/aws/karpenter-core/pkg/apis/v1alpha5" + "github.com/aws/karpenter-core/pkg/cloudprovider" corecloudprovider "github.com/aws/karpenter-core/pkg/cloudprovider" "github.com/aws/karpenter-core/pkg/test" "github.com/aws/karpenter/pkg/apis/v1alpha1" @@ -57,14 +58,14 @@ func (c *CloudProvider) GetInstanceTypes(_ context.Context, _ *v1alpha5.Provisio }, nil } -func (c *CloudProvider) IsMachineDrifted(_ context.Context, machine *v1alpha5.Machine) (bool, error) { +func (c *CloudProvider) IsMachineDrifted(_ context.Context, machine *v1alpha5.Machine) (cloudprovider.DriftReason, error) { nodeAMI := machine.Labels[v1alpha1.LabelInstanceAMIID] for _, ami := range c.ValidAMIs { if nodeAMI == ami { - return false, nil + return "", nil } } - return true, nil + return "drifted", nil } func (c *CloudProvider) Get(context.Context, string) (*v1alpha5.Machine, error) {