diff --git a/go.mod b/go.mod index 28e42086d23a..0380b1117315 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/PuerkitoBio/goquery v1.8.1 github.com/avast/retry-go v3.0.0+incompatible github.com/aws/aws-sdk-go v1.44.294 - github.com/aws/karpenter-core v0.29.2-0.20230803235302-95bd9f61a18b + github.com/aws/karpenter-core v0.29.2-0.20230808175334-44f8af74b472 github.com/go-playground/validator/v10 v10.13.0 github.com/imdario/mergo v0.3.16 github.com/mitchellh/hashstructure/v2 v2.0.2 diff --git a/go.sum b/go.sum index 8eabb69f0b6a..c6cd264bc2f1 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,8 @@ github.com/avast/retry-go v3.0.0+incompatible h1:4SOWQ7Qs+oroOTQOYnAHqelpCO0biHS github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY= github.com/aws/aws-sdk-go v1.44.294 h1:3x7GaEth+pDU9HwFcAU0awZlEix5CEdyIZvV08SlHa8= github.com/aws/aws-sdk-go v1.44.294/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= -github.com/aws/karpenter-core v0.29.2-0.20230803235302-95bd9f61a18b h1:88YeEA65jQNCj4/AdH0qixJeuGuIOi5QxOzsRCXsJQA= -github.com/aws/karpenter-core v0.29.2-0.20230803235302-95bd9f61a18b/go.mod h1:+C8X0N378fQ/+YmopvRHflj2JFrVP8sPs9xL7v4A6eM= +github.com/aws/karpenter-core v0.29.2-0.20230808175334-44f8af74b472 h1:+6nHi+A0/D0VfAOcVR/bgOsETOjpwBXOucxGuJoaucA= +github.com/aws/karpenter-core v0.29.2-0.20230808175334-44f8af74b472/go.mod h1:+C8X0N378fQ/+YmopvRHflj2JFrVP8sPs9xL7v4A6eM= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= diff --git a/pkg/cloudprovider/cloudprovider.go b/pkg/cloudprovider/cloudprovider.go index 82c7f31d1790..ef27b5117fe1 100644 --- a/pkg/cloudprovider/cloudprovider.go +++ b/pkg/cloudprovider/cloudprovider.go @@ -25,6 +25,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/runtime/schema" + "github.com/aws/karpenter-core/pkg/controllers/machine/disruption" "github.com/aws/karpenter-core/pkg/utils/functional" "github.com/aws/karpenter/pkg/apis" "github.com/aws/karpenter/pkg/apis/v1alpha1" @@ -189,22 +190,22 @@ func (c *CloudProvider) Delete(ctx context.Context, machine *v1alpha5.Machine) e return c.instanceProvider.Delete(ctx, id) } -func (c *CloudProvider) IsMachineDrifted(ctx context.Context, machine *v1alpha5.Machine) (bool, error) { +func (c *CloudProvider) IsMachineDrifted(ctx context.Context, machine *v1alpha5.Machine) (cloudprovider.DriftReason, error) { // Not needed when GetInstanceTypes removes provisioner dependency provisioner := &v1alpha5.Provisioner{} if err := c.kubeClient.Get(ctx, types.NamespacedName{Name: machine.Labels[v1alpha5.ProvisionerNameLabelKey]}, provisioner); err != nil { - return false, client.IgnoreNotFound(fmt.Errorf("getting provisioner, %w", err)) + return disruption.NotDrifted, client.IgnoreNotFound(fmt.Errorf("getting provisioner, %w", err)) } if provisioner.Spec.ProviderRef == nil { - return false, nil + return "", nil } nodeTemplate, err := c.resolveNodeTemplate(ctx, nil, provisioner.Spec.ProviderRef) if err != nil { - return false, client.IgnoreNotFound(fmt.Errorf("resolving node template, %w", err)) + return "", client.IgnoreNotFound(fmt.Errorf("resolving node template, %w", err)) } drifted, err := c.isNodeTemplateDrifted(ctx, machine, provisioner, nodeTemplate) if err != nil { - return false, err + return "", err } return drifted, nil } diff --git a/pkg/cloudprovider/drift.go b/pkg/cloudprovider/drift.go index 02afdf7a401f..01e3fd277b40 100644 --- a/pkg/cloudprovider/drift.go +++ b/pkg/cloudprovider/drift.go @@ -23,6 +23,7 @@ import ( "github.com/aws/karpenter-core/pkg/apis/v1alpha5" "github.com/aws/karpenter-core/pkg/cloudprovider" + "github.com/aws/karpenter-core/pkg/controllers/machine/disruption" "github.com/aws/karpenter-core/pkg/utils/sets" "github.com/aws/karpenter/pkg/apis/v1alpha1" "github.com/aws/karpenter/pkg/providers/amifamily" @@ -30,90 +31,110 @@ import ( "github.com/aws/karpenter/pkg/utils" ) -func (c *CloudProvider) isNodeTemplateDrifted(ctx context.Context, machine *v1alpha5.Machine, provisioner *v1alpha5.Provisioner, nodeTemplate *v1alpha1.AWSNodeTemplate) (bool, error) { +const ( + AMIDrift cloudprovider.DriftReason = "AMIDrift" + SubnetDrift cloudprovider.DriftReason = "SubnetDrift" + SecurityGroupDrift cloudprovider.DriftReason = "SecurityGroupDrift" +) + +func (c *CloudProvider) isNodeTemplateDrifted(ctx context.Context, machine *v1alpha5.Machine, provisioner *v1alpha5.Provisioner, nodeTemplate *v1alpha1.AWSNodeTemplate) (cloudprovider.DriftReason, error) { instance, err := c.getInstance(ctx, machine.Status.ProviderID) if err != nil { - return false, err + return disruption.NotDrifted, err } amiDrifted, err := c.isAMIDrifted(ctx, machine, provisioner, instance, nodeTemplate) if err != nil { - return false, fmt.Errorf("calculating ami drift, %w", err) + return disruption.NotDrifted, fmt.Errorf("calculating ami drift, %w", err) } securitygroupDrifted, err := c.areSecurityGroupsDrifted(instance, nodeTemplate) if err != nil { - return false, fmt.Errorf("calculating securitygroup drift, %w", err) + return disruption.NotDrifted, fmt.Errorf("calculating securitygroup drift, %w", err) } subnetDrifted, err := c.isSubnetDrifted(instance, nodeTemplate) if err != nil { - return false, fmt.Errorf("calculating subnet drift, %w", err) + return disruption.NotDrifted, fmt.Errorf("calculating subnet drift, %w", err) } - - return amiDrifted || securitygroupDrifted || subnetDrifted || c.areStaticFieldsDrifted(machine, nodeTemplate), nil + drifted := lo.FindOrElse([]cloudprovider.DriftReason{amiDrifted, securitygroupDrifted, subnetDrifted, c.areStaticFieldsDrifted(machine, nodeTemplate)}, disruption.NotDrifted, func(i cloudprovider.DriftReason) bool { + return i != disruption.NotDrifted + }) + return drifted, nil } func (c *CloudProvider) isAMIDrifted(ctx context.Context, machine *v1alpha5.Machine, provisioner *v1alpha5.Provisioner, - instance *instance.Instance, nodeTemplate *v1alpha1.AWSNodeTemplate) (bool, error) { + instance *instance.Instance, nodeTemplate *v1alpha1.AWSNodeTemplate) (cloudprovider.DriftReason, error) { instanceTypes, err := c.GetInstanceTypes(ctx, provisioner) if err != nil { - return false, fmt.Errorf("getting instanceTypes, %w", err) + return disruption.NotDrifted, fmt.Errorf("getting instanceTypes, %w", err) } nodeInstanceType, found := lo.Find(instanceTypes, func(instType *cloudprovider.InstanceType) bool { return instType.Name == machine.Labels[v1.LabelInstanceTypeStable] }) if !found { - return false, fmt.Errorf(`finding node instance type "%s"`, machine.Labels[v1.LabelInstanceTypeStable]) + return disruption.NotDrifted, fmt.Errorf(`finding node instance type "%s"`, machine.Labels[v1.LabelInstanceTypeStable]) } if nodeTemplate.Spec.LaunchTemplateName != nil { - return false, nil + return disruption.NotDrifted, nil } amis, err := c.amiProvider.Get(ctx, nodeTemplate, &amifamily.Options{}) if err != nil { - return false, fmt.Errorf("getting amis, %w", err) + return disruption.NotDrifted, fmt.Errorf("getting amis, %w", err) } if len(amis) == 0 { - return false, fmt.Errorf("no amis exist given constraints") + return disruption.NotDrifted, fmt.Errorf("no amis exist given constraints") } mappedAMIs := amifamily.MapInstanceTypes(amis, []*cloudprovider.InstanceType{nodeInstanceType}) if len(mappedAMIs) == 0 { - return false, fmt.Errorf("no instance types satisfy requirements of amis %v,", amis) + return disruption.NotDrifted, fmt.Errorf("no instance types satisfy requirements of amis %v,", amis) + } + if !lo.Contains(lo.Keys(mappedAMIs), instance.ImageID) { + return AMIDrift, nil } - return !lo.Contains(lo.Keys(mappedAMIs), instance.ImageID), nil + return disruption.NotDrifted, nil } -func (c *CloudProvider) isSubnetDrifted(instance *instance.Instance, nodeTemplate *v1alpha1.AWSNodeTemplate) (bool, error) { +func (c *CloudProvider) isSubnetDrifted(instance *instance.Instance, nodeTemplate *v1alpha1.AWSNodeTemplate) (cloudprovider.DriftReason, error) { // If the node template status does not have subnets, wait for the subnets to be populated before continuing if nodeTemplate.Status.Subnets == nil { - return false, fmt.Errorf("AWSNodeTemplate has no subnets") + return disruption.NotDrifted, fmt.Errorf("AWSNodeTemplate has no subnets") } _, found := lo.Find(nodeTemplate.Status.Subnets, func(subnet v1alpha1.Subnet) bool { return subnet.ID == instance.SubnetID }) - return !found, nil + if !found { + return SubnetDrift, nil + } + return disruption.NotDrifted, nil } // Checks if the security groups are drifted, by comparing the AWSNodeTemplate.Status.SecurityGroups // to the ec2 instance security groups -func (c *CloudProvider) areSecurityGroupsDrifted(ec2Instance *instance.Instance, nodeTemplate *v1alpha1.AWSNodeTemplate) (bool, error) { +func (c *CloudProvider) areSecurityGroupsDrifted(ec2Instance *instance.Instance, nodeTemplate *v1alpha1.AWSNodeTemplate) (cloudprovider.DriftReason, error) { // nodeTemplate.Spec.SecurityGroupSelector can be nil if the user is using a launchTemplateName to define SecurityGroups // Karpenter will not drift on changes to securitygroup in the launchTemplateName if nodeTemplate.Spec.LaunchTemplateName != nil { - return false, nil + return disruption.NotDrifted, nil } securityGroupIds := sets.New(lo.Map(nodeTemplate.Status.SecurityGroups, func(sg v1alpha1.SecurityGroup, _ int) string { return sg.ID })...) if len(securityGroupIds) == 0 { - return false, fmt.Errorf("no security groups exist in the AWSNodeTemplate Status") + return disruption.NotDrifted, fmt.Errorf("no security groups exist in the AWSNodeTemplate Status") + } + + if !securityGroupIds.Equal(sets.New(ec2Instance.SecurityGroupIDs...)) { + return SecurityGroupDrift, nil } - return !securityGroupIds.Equal(sets.New(ec2Instance.SecurityGroupIDs...)), nil + return disruption.NotDrifted, nil } -func (c *CloudProvider) areStaticFieldsDrifted(machine *v1alpha5.Machine, nodeTemplate *v1alpha1.AWSNodeTemplate) bool { +func (c *CloudProvider) areStaticFieldsDrifted(machine *v1alpha5.Machine, nodeTemplate *v1alpha1.AWSNodeTemplate) cloudprovider.DriftReason { nodeTemplateHash, foundHashNodeTemplate := nodeTemplate.ObjectMeta.Annotations[v1alpha1.AnnotationNodeTemplateHash] machineHash, foundHashMachine := machine.ObjectMeta.Annotations[v1alpha1.AnnotationNodeTemplateHash] if !foundHashNodeTemplate || !foundHashMachine { - return false + return disruption.NotDrifted } - - return nodeTemplateHash != machineHash + if nodeTemplateHash != machineHash { + return disruption.StaticDrift + } + return disruption.NotDrifted } func (c *CloudProvider) getInstance(ctx context.Context, providerID string) (*instance.Instance, error) { diff --git a/pkg/cloudprovider/suite_test.go b/pkg/cloudprovider/suite_test.go index c806db39a257..f01671402916 100644 --- a/pkg/cloudprovider/suite_test.go +++ b/pkg/cloudprovider/suite_test.go @@ -49,6 +49,7 @@ import ( coresettings "github.com/aws/karpenter-core/pkg/apis/settings" "github.com/aws/karpenter-core/pkg/apis/v1alpha5" corecloudproivder "github.com/aws/karpenter-core/pkg/cloudprovider" + "github.com/aws/karpenter-core/pkg/controllers/machine/disruption" "github.com/aws/karpenter-core/pkg/controllers/provisioning" "github.com/aws/karpenter-core/pkg/controllers/state" "github.com/aws/karpenter-core/pkg/events" @@ -307,33 +308,33 @@ var _ = Describe("CloudProvider", func() { ExpectDeleted(ctx, env.Client, nodeTemplate) drifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(drifted).To(BeFalse()) + Expect(drifted).To(Equal(disruption.NotDrifted)) }) It("should return false if providerRef is not defined", func() { provisioner.Spec.ProviderRef = nil ExpectApplied(ctx, env.Client, provisioner) drifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(drifted).To(BeFalse()) + Expect(drifted).To(Equal(disruption.NotDrifted)) }) It("should not fail if provisioner does not exist", func() { ExpectDeleted(ctx, env.Client, provisioner) drifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(drifted).To(BeFalse()) + Expect(drifted).To(Equal(disruption.NotDrifted)) }) It("should return drifted if the AMI is not valid", func() { // Instance is a reference to what we return in the GetInstances call instance.ImageId = aws.String(fake.ImageID()) isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeTrue()) + Expect(isDrifted).To(Equal(cloudprovider.AMIDrift)) }) It("should return drifted if the subnet is not valid", func() { instance.SubnetId = aws.String(fake.SubnetID()) isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeTrue()) + Expect(isDrifted).To(Equal(cloudprovider.SubnetDrift)) }) It("should return an error if AWSNodeTemplate subnets are empty", func() { nodeTemplate.Status.Subnets = []v1alpha1.Subnet{} @@ -344,7 +345,7 @@ var _ = Describe("CloudProvider", func() { It("should not return drifted if the machine is valid", func() { isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(disruption.NotDrifted)) }) It("should return an error if the AWSNodeTemplate securitygroup are empty", func() { nodeTemplate.Status.SecurityGroups = []v1alpha1.SecurityGroup{} @@ -359,14 +360,14 @@ var _ = Describe("CloudProvider", func() { instance.SecurityGroups = []*ec2.GroupIdentifier{{GroupId: aws.String(fake.SecurityGroupID())}} isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeTrue()) + Expect(isDrifted).To(Equal(cloudprovider.SecurityGroupDrift)) }) It("should return drifted if there are more instance securitygroups are present than AWSNodeTemplate Status", func() { // Instance is a reference to what we return in the GetInstances call instance.SecurityGroups = []*ec2.GroupIdentifier{{GroupId: aws.String(fake.SecurityGroupID())}, {GroupId: aws.String(validSecurityGroup)}} isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeTrue()) + Expect(isDrifted).To(Equal(cloudprovider.SecurityGroupDrift)) }) It("should return drifted if more AWSNodeTemplate securitygroups are present than instance securitygroups", func() { nodeTemplate.Status.SecurityGroups = []v1alpha1.SecurityGroup{ @@ -382,7 +383,7 @@ var _ = Describe("CloudProvider", func() { ExpectApplied(ctx, env.Client, nodeTemplate) isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeTrue()) + Expect(isDrifted).To(Equal(cloudprovider.SecurityGroupDrift)) }) It("should not return drifted if launchTemplateName is defined", func() { nodeTemplate.Spec.LaunchTemplateName = aws.String("validLaunchTemplateName") @@ -390,12 +391,12 @@ var _ = Describe("CloudProvider", func() { nodeTemplate.Status.SecurityGroups = nil isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(disruption.NotDrifted)) }) It("should not return drifted if the securitygroups match", func() { isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).ToNot(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(disruption.NotDrifted)) }) It("should error if the machine doesn't have the instance-type label", func() { machine.Labels = map[string]string{ @@ -408,7 +409,7 @@ var _ = Describe("CloudProvider", func() { machine.Status = v1alpha5.MachineStatus{} isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).To(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(disruption.NotDrifted)) }) It("should error drift if the underlying machine does not exist", func() { awsEnv.EC2API.DescribeInstancesBehavior.Output.Set(&ec2.DescribeInstancesOutput{ @@ -431,7 +432,7 @@ var _ = Describe("CloudProvider", func() { ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).NotTo(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(disruption.NotDrifted)) updatedAWSNodeTemplate := test.AWSNodeTemplate(*nodeTemplate.Spec.DeepCopy(), awsnodetemplatespec) updatedAWSNodeTemplate.ObjectMeta = nodeTemplate.ObjectMeta @@ -441,7 +442,7 @@ var _ = Describe("CloudProvider", func() { ExpectApplied(ctx, env.Client, updatedAWSNodeTemplate) isDrifted, err = cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).NotTo(HaveOccurred()) - Expect(isDrifted).To(BeTrue()) + Expect(isDrifted).To(Equal(disruption.StaticDrift)) }, Entry("InstanceProfile Drift", v1alpha1.AWSNodeTemplateSpec{AWS: v1alpha1.AWS{InstanceProfile: aws.String("profile-2")}}), Entry("UserData Drift", v1alpha1.AWSNodeTemplateSpec{UserData: aws.String("userdata-test-2")}), @@ -457,7 +458,7 @@ var _ = Describe("CloudProvider", func() { ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).NotTo(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(disruption.NotDrifted)) updatedAWSNodeTemplate := test.AWSNodeTemplate(*nodeTemplate.Spec.DeepCopy(), awsnodetemplatespec) updatedAWSNodeTemplate.ObjectMeta = nodeTemplate.ObjectMeta @@ -467,7 +468,7 @@ var _ = Describe("CloudProvider", func() { ExpectApplied(ctx, env.Client, updatedAWSNodeTemplate) isDrifted, err = cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).NotTo(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(disruption.NotDrifted)) }, Entry("AMISelector Drift", v1alpha1.AWSNodeTemplateSpec{AMISelector: map[string]string{"aws::ids": validAMI}}), Entry("SubnetSelector Drift", v1alpha1.AWSNodeTemplateSpec{AWS: v1alpha1.AWS{SubnetSelector: map[string]string{"aws-ids": "subnet-test1"}}}), @@ -481,7 +482,7 @@ var _ = Describe("CloudProvider", func() { ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) isDrifted, err := cloudProvider.IsMachineDrifted(ctx, machine) Expect(err).NotTo(HaveOccurred()) - Expect(isDrifted).To(BeFalse()) + Expect(isDrifted).To(Equal(disruption.NotDrifted)) }) }) }) diff --git a/pkg/fake/cloudprovider.go b/pkg/fake/cloudprovider.go index 51b12e5b77da..b03f43ff7307 100644 --- a/pkg/fake/cloudprovider.go +++ b/pkg/fake/cloudprovider.go @@ -21,6 +21,7 @@ import ( "github.com/aws/karpenter-core/pkg/apis/v1alpha5" corecloudprovider "github.com/aws/karpenter-core/pkg/cloudprovider" + "github.com/aws/karpenter-core/pkg/controllers/machine/disruption" "github.com/aws/karpenter-core/pkg/test" "github.com/aws/karpenter/pkg/apis/v1alpha1" ) @@ -57,14 +58,14 @@ func (c *CloudProvider) GetInstanceTypes(_ context.Context, _ *v1alpha5.Provisio }, nil } -func (c *CloudProvider) IsMachineDrifted(_ context.Context, machine *v1alpha5.Machine) (bool, error) { +func (c *CloudProvider) IsMachineDrifted(_ context.Context, machine *v1alpha5.Machine) (corecloudprovider.DriftReason, error) { nodeAMI := machine.Labels[v1alpha1.LabelInstanceAMIID] for _, ami := range c.ValidAMIs { if nodeAMI == ami { - return false, nil + return disruption.NotDrifted, nil } } - return true, nil + return "drifted", nil } func (c *CloudProvider) Get(context.Context, string) (*v1alpha5.Machine, error) {