diff --git a/test/suites/integration/networkinterfaces_test.go b/test/suites/integration/networkinterfaces_test.go index a39e461dc8c6..0db9014b265b 100644 --- a/test/suites/integration/networkinterfaces_test.go +++ b/test/suites/integration/networkinterfaces_test.go @@ -34,7 +34,8 @@ var _ = Describe("NetworkInterfaces", func() { subnets := env.GetSubnets(map[string]string{"karpenter.sh/discovery": settings.FromContext(env.Context).ClusterName}) Expect(len(subnets)).ToNot(Equal(0)) allSubnets := lo.Flatten(lo.Values(subnets)) - ExpectAssociatePublicIPAddressToBe(false, allSubnets...) + subnetCleanUpFn := ExpectAssociatePublicIPAddressToBe(false, allSubnets...) + defer subnetCleanUpFn() provider := awstest.AWSNodeTemplate(v1alpha1.AWSNodeTemplateSpec{ AWS: v1alpha1.AWS{ @@ -58,7 +59,8 @@ var _ = Describe("NetworkInterfaces", func() { subnets := env.GetSubnets(map[string]string{"karpenter.sh/discovery": settings.FromContext(env.Context).ClusterName}) Expect(len(subnets)).ToNot(Equal(0)) allSubnets := lo.Flatten(lo.Values(subnets)) - ExpectAssociatePublicIPAddressToBe(true, allSubnets...) + subnetCleanUpFn := ExpectAssociatePublicIPAddressToBe(true, allSubnets...) + defer subnetCleanUpFn() provider := awstest.AWSNodeTemplate(v1alpha1.AWSNodeTemplateSpec{ AWS: v1alpha1.AWS{ @@ -82,7 +84,8 @@ var _ = Describe("NetworkInterfaces", func() { subnets := env.GetSubnets(map[string]string{"karpenter.sh/discovery": settings.FromContext(env.Context).ClusterName}) Expect(len(subnets)).ToNot(Equal(0)) allSubnets := lo.Flatten(lo.Values(subnets)) - ExpectAssociatePublicIPAddressToBe(false, allSubnets...) + subnetCleanUpFn := ExpectAssociatePublicIPAddressToBe(false, allSubnets...) + defer subnetCleanUpFn() desc := "a test network interface" provider := awstest.AWSNodeTemplate(v1alpha1.AWSNodeTemplateSpec{ @@ -160,14 +163,43 @@ var _ = Describe("NetworkInterfaces", func() { }) }) -func ExpectAssociatePublicIPAddressToBe(enabled bool, subnetIDs ...string) { - for subnetID := range subnetIDs { +func ExpectAssociatePublicIPAddressToBe(enabled bool, subnetIDs ...string) func() { + var subnetIDsToModify []string + // check if the subnet matches the desired "enabled" state + for _, subnetID := range subnetIDs { + subnetsOut, err := env.EC2API.DescribeSubnets(&ec2.DescribeSubnetsInput{ + SubnetIds: []*string{&subnetID}, + }) + Expect(err).To(BeNil()) + Expect(len(subnetsOut.Subnets)).To(Equal(1)) + subnet := subnetsOut.Subnets[0] + if *subnet.MapPublicIpOnLaunch == enabled { + continue + } + subnetIDsToModify = append(subnetIDsToModify, subnetID) + } + + // modify the subnets that do not match the desired "enabled" state + for _, subnetID := range subnetIDsToModify { _, err := env.EC2API.ModifySubnetAttribute(&ec2.ModifySubnetAttributeInput{ MapPublicIpOnLaunch: &ec2.AttributeBooleanValue{ Value: aws.Bool(enabled), }, - SubnetId: aws.String(subnetIDs[subnetID]), + SubnetId: aws.String(subnetID), }) Expect(err).To(BeNil()) } + + // return a clean-up func to revert the modified subnets to the previous state + return func() { + for _, subnetID := range subnetIDsToModify { + _, err := env.EC2API.ModifySubnetAttribute(&ec2.ModifySubnetAttributeInput{ + MapPublicIpOnLaunch: &ec2.AttributeBooleanValue{ + Value: aws.Bool(!enabled), + }, + SubnetId: aws.String(subnetID), + }) + Expect(err).To(BeNil()) + } + } }