Skip to content

Commit

Permalink
fix: Subnet and Security Groups to use in-memory values for evaluating Drift (aws#5518)
Browse files Browse the repository at this point in the history
  • Loading branch information
engedaam authored Jan 24, 2024
1 parent 0df4e3b commit 3bce2be
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 46 deletions.
39 changes: 26 additions & 13 deletions pkg/cloudprovider/drift.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import (
corev1beta1 "sigs.k8s.io/karpenter/pkg/apis/v1beta1"
"sigs.k8s.io/karpenter/pkg/cloudprovider"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"

"github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1"
"github.com/aws/karpenter-provider-aws/pkg/providers/amifamily"
"github.com/aws/karpenter-provider-aws/pkg/providers/instance"
Expand All @@ -51,11 +54,11 @@ func (c *CloudProvider) isNodeClassDrifted(ctx context.Context, nodeClaim *corev
if err != nil {
return "", fmt.Errorf("calculating ami drift, %w", err)
}
securitygroupDrifted, err := c.areSecurityGroupsDrifted(instance, nodeClass)
securitygroupDrifted, err := c.areSecurityGroupsDrifted(ctx, instance, nodeClass)
if err != nil {
return "", fmt.Errorf("calculating securitygroup drift, %w", err)
}
subnetDrifted, err := c.isSubnetDrifted(instance, nodeClass)
subnetDrifted, err := c.isSubnetDrifted(ctx, instance, nodeClass)
if err != nil {
return "", fmt.Errorf("calculating subnet drift, %w", err)
}
Expand Down Expand Up @@ -91,28 +94,38 @@ func (c *CloudProvider) isAMIDrifted(ctx context.Context, nodeClaim *corev1beta1
return "", nil
}

// Checks if the security groups are drifted, by comparing the EC2NodeClass.Status.Subnets
// Checks if the subnet is drifted, by comparing the subnets returned from the subnetProvider
// to the ec2 instance subnets
func (c *CloudProvider) isSubnetDrifted(instance *instance.Instance, nodeClass *v1beta1.EC2NodeClass) (cloudprovider.DriftReason, error) {
// If the node template status does not have subnets, wait for the subnets to be populated before continuing
if len(nodeClass.Status.Subnets) == 0 {
return "", fmt.Errorf("no subnets exist in status")
func (c *CloudProvider) isSubnetDrifted(ctx context.Context, instance *instance.Instance, nodeClass *v1beta1.EC2NodeClass) (cloudprovider.DriftReason, error) {
subnets, err := c.subnetProvider.List(ctx, nodeClass)
if err != nil {
return "", err
}
_, found := lo.Find(nodeClass.Status.Subnets, func(subnet v1beta1.Subnet) bool {
return subnet.ID == instance.SubnetID
// subnets need to be found to check for drift
if len(subnets) == 0 {
return "", fmt.Errorf("no subnets are discovered")
}

_, found := lo.Find(subnets, func(subnet *ec2.Subnet) bool {
return aws.StringValue(subnet.SubnetId) == instance.SubnetID
})

if !found {
return SubnetDrift, nil
}
return "", nil
}

// Checks if the security groups are drifted, by comparing the EC2NodeClass.Status.SecurityGroups
// Checks if the security groups are drifted, by comparing the security groups returned from the SecurityGroupProvider
// to the ec2 instance security groups
func (c *CloudProvider) areSecurityGroupsDrifted(ec2Instance *instance.Instance, nodeClass *v1beta1.EC2NodeClass) (cloudprovider.DriftReason, error) {
securityGroupIds := sets.New(lo.Map(nodeClass.Status.SecurityGroups, func(sg v1beta1.SecurityGroup, _ int) string { return sg.ID })...)
func (c *CloudProvider) areSecurityGroupsDrifted(ctx context.Context, ec2Instance *instance.Instance, nodeClass *v1beta1.EC2NodeClass) (cloudprovider.DriftReason, error) {
securitygroup, err := c.securityGroupProvider.List(ctx, nodeClass)
if err != nil {
return "", err
}
securityGroupIds := sets.New(lo.Map(securitygroup, func(sg *ec2.SecurityGroup, _ int) string { return aws.StringValue(sg.GroupId) })...)
if len(securityGroupIds) == 0 {
return "", fmt.Errorf("no security groups exist in status")
return "", fmt.Errorf("no security groups are discovered")
}

if !securityGroupIds.Equal(sets.New(ec2Instance.SecurityGroupIDs...)) {
Expand Down
100 changes: 67 additions & 33 deletions pkg/cloudprovider/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,31 +210,65 @@ var _ = Describe("CloudProvider", func() {
ImageId: aws.String(armAMIID),
Architecture: aws.String("arm64"),
CreationDate: aws.String("2022-08-15T12:00:00Z"),
Tags: []*ec2.Tag{
{
Key: aws.String("ami-key-1"),
Value: aws.String("ami-value-1"),
},
},
},
{
Name: aws.String(coretest.RandomName()),
ImageId: aws.String(amdAMIID),
Architecture: aws.String("x86_64"),
CreationDate: aws.String("2022-08-15T12:00:00Z"),
Tags: []*ec2.Tag{
{
Key: aws.String("ami-key-2"),
Value: aws.String("ami-value-2"),
},
},
},
},
})
nodeClass.Status.SecurityGroups = []v1beta1.SecurityGroup{
{
ID: validSecurityGroup,
Name: "test-securitygroup",
},
}
nodeClass.Status.Subnets = []v1beta1.Subnet{
{
ID: validSubnet1,
Zone: "zone-1",
awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{
SecurityGroups: []*ec2.SecurityGroup{
{
GroupId: aws.String(validSecurityGroup),
GroupName: aws.String("test-securitygroup"),
Tags: []*ec2.Tag{
{
Key: aws.String("sg-key"),
Value: aws.String("sg-value"),
},
},
},
},
{
ID: validSubnet2,
Zone: "zone-2",
})
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{
Subnets: []*ec2.Subnet{
{
SubnetId: aws.String(validSubnet1),
AvailabilityZone: aws.String("zone-1"),
Tags: []*ec2.Tag{
{
Key: aws.String("sn-key-1"),
Value: aws.String("sn-value-1"),
},
},
},
{
SubnetId: aws.String(validSubnet2),
AvailabilityZone: aws.String("zone-2"),
Tags: []*ec2.Tag{
{
Key: aws.String("sn-key-2"),
Value: aws.String("sn-value-2"),
},
},
},
},
}
})
ExpectApplied(ctx, env.Client, nodePool, nodeClass)
instanceTypes, err := cloudProvider.GetInstanceTypes(ctx, nodePool)
Expect(err).ToNot(HaveOccurred())
Expand Down Expand Up @@ -305,8 +339,8 @@ var _ = Describe("CloudProvider", func() {
Expect(isDrifted).To(Equal(cloudprovider.SubnetDrift))
})
It("should return an error if subnets are empty", func() {
nodeClass.Status.Subnets = []v1beta1.Subnet{}
ExpectApplied(ctx, env.Client, nodeClass)
awsEnv.SubnetCache.Flush()
awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{}})
_, err := cloudProvider.IsDrifted(ctx, nodeClaim)
Expect(err).To(HaveOccurred())
})
Expand All @@ -316,39 +350,39 @@ var _ = Describe("CloudProvider", func() {
Expect(isDrifted).To(BeEmpty())
})
It("should return an error if the security groups are empty", func() {
nodeClass.Status.SecurityGroups = []v1beta1.SecurityGroup{}
ExpectApplied(ctx, env.Client, nodeClass)
awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{}})
// Instance is a reference to what we return in the GetInstances call
instance.SecurityGroups = []*ec2.GroupIdentifier{{GroupId: aws.String(fake.SecurityGroupID())}}
_, err := cloudProvider.IsDrifted(ctx, nodeClaim)
Expect(err).To(HaveOccurred())
})
It("should return drifted if the instance security groups doesn't match the status", func() {
It("should return drifted if the instance security groups doesn't match the discovered values", func() {
// Instance is a reference to what we return in the GetInstances call
instance.SecurityGroups = []*ec2.GroupIdentifier{{GroupId: aws.String(fake.SecurityGroupID())}}
isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim)
Expect(err).ToNot(HaveOccurred())
Expect(isDrifted).To(Equal(cloudprovider.SecurityGroupDrift))
})
It("should return drifted if there are more instance security groups present than in the status", func() {
It("should return drifted if there are more instance security groups present than in the discovered values", func() {
// Instance is a reference to what we return in the GetInstances call
instance.SecurityGroups = []*ec2.GroupIdentifier{{GroupId: aws.String(fake.SecurityGroupID())}, {GroupId: aws.String(validSecurityGroup)}}
isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim)
Expect(err).ToNot(HaveOccurred())
Expect(isDrifted).To(Equal(cloudprovider.SecurityGroupDrift))
})
It("should return drifted if more security groups are present than instance security groups", func() {
nodeClass.Status.SecurityGroups = []v1beta1.SecurityGroup{
{
ID: validSecurityGroup,
Name: "test-securitygroup",
},
{
ID: fake.SecurityGroupID(),
Name: "test-securitygroup",
It("should return drifted if more security groups are present than instance security groups then discovered from nodeclass", func() {
awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{
SecurityGroups: []*ec2.SecurityGroup{
{
GroupId: aws.String(validSecurityGroup),
GroupName: aws.String("test-securitygroup"),
},
{
GroupId: aws.String(fake.SecurityGroupID()),
GroupName: aws.String("test-securitygroup"),
},
},
}
ExpectApplied(ctx, env.Client, nodeClass)
})
isDrifted, err := cloudProvider.IsDrifted(ctx, nodeClaim)
Expect(err).ToNot(HaveOccurred())
Expect(isDrifted).To(Equal(cloudprovider.SecurityGroupDrift))
Expand Down Expand Up @@ -422,8 +456,8 @@ var _ = Describe("CloudProvider", func() {
Expect(err).NotTo(HaveOccurred())
Expect(isDrifted).To(BeEmpty())
},
Entry("AMI Drift", v1beta1.EC2NodeClass{Spec: v1beta1.EC2NodeClassSpec{AMISelectorTerms: []v1beta1.AMISelectorTerm{{Tags: map[string]string{"*": "*"}}}}}),
Entry("Subnet Drift", v1beta1.EC2NodeClass{Spec: v1beta1.EC2NodeClassSpec{SubnetSelectorTerms: []v1beta1.SubnetSelectorTerm{{ID: "subnet-test1"}}}}),
Entry("AMI Drift", v1beta1.EC2NodeClass{Spec: v1beta1.EC2NodeClassSpec{AMISelectorTerms: []v1beta1.AMISelectorTerm{{Tags: map[string]string{"ami-key-1": "ami-value-1"}}}}}),
Entry("Subnet Drift", v1beta1.EC2NodeClass{Spec: v1beta1.EC2NodeClassSpec{SubnetSelectorTerms: []v1beta1.SubnetSelectorTerm{{Tags: map[string]string{"sn-key-1": "sn-value-1"}}}}}),
Entry("SecurityGroup Drift", v1beta1.EC2NodeClass{Spec: v1beta1.EC2NodeClassSpec{SecurityGroupSelectorTerms: []v1beta1.SecurityGroupSelectorTerm{{Tags: map[string]string{"sg-key": "sg-value"}}}}}),
)
It("should not return drifted if karpenter.k8s.aws/nodeclass-hash annotation is not present on the NodeClaim", func() {
Expand Down

0 comments on commit 3bce2be

Please sign in to comment.