From 4de0c5054087cb5b81f0318aa440f77589fe1edf Mon Sep 17 00:00:00 2001 From: Jonathan Innis Date: Thu, 1 Feb 2024 13:20:42 -0800 Subject: [PATCH] Allow passing queueURL through interruption queue --- pkg/providers/sqs/sqs.go | 24 ++++++++++++----- test/pkg/environment/aws/environment.go | 5 +++- test/pkg/environment/aws/expectations.go | 10 ++++++- test/suites/interruption/suite_test.go | 33 ++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 9 deletions(-) diff --git a/pkg/providers/sqs/sqs.go b/pkg/providers/sqs/sqs.go index 0bf8df4e9b06..71547a545e87 100644 --- a/pkg/providers/sqs/sqs.go +++ b/pkg/providers/sqs/sqs.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/sqs" @@ -31,17 +32,26 @@ type Provider struct { url string } -func NewProvider(ctx context.Context, client sqsiface.SQSAPI, queueName string) (*Provider, error) { - ret, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ - QueueName: aws.String(queueName), - }) - if err != nil { - return nil, fmt.Errorf("fetching queue url, %w", err) +func NewProvider(ctx context.Context, client sqsiface.SQSAPI, interruptionQueue string) (*Provider, error) { + var queueURL, queueName string + if !strings.HasPrefix(interruptionQueue, "https://") { + ret, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ + QueueName: aws.String(interruptionQueue), + }) + if err != nil { + return nil, fmt.Errorf("fetching queue url, %w", err) + } + queueName = interruptionQueue + queueURL = aws.StringValue(ret.QueueUrl) + } else { + ss := strings.Split(interruptionQueue, "/") + queueName = ss[len(ss)-1] + queueURL = interruptionQueue } return &Provider{ client: client, name: queueName, - url: aws.StringValue(ret.QueueUrl), + url: queueURL, }, nil } diff --git a/test/pkg/environment/aws/environment.go b/test/pkg/environment/aws/environment.go index c243c7c07205..99b34b135daf 100644 --- a/test/pkg/environment/aws/environment.go +++ b/test/pkg/environment/aws/environment.go @@ -69,6 +69,7 @@ type Environment struct { IAMAPI *iam.IAM FISAPI *fis.FIS EKSAPI *eks.EKS + SQSAPI *servicesqs.SQS TimeStreamAPI timestreamwriteiface.TimestreamWriteAPI SQSProvider *sqs.Provider @@ -94,6 +95,7 @@ func NewEnvironment(t *testing.T) *Environment { Region: *session.Config.Region, Environment: env, + SQSAPI: servicesqs.New(session), STSAPI: sts.New(session), EC2API: ec2.New(session), SSMAPI: ssm.New(session), @@ -107,7 +109,8 @@ func NewEnvironment(t *testing.T) *Environment { } // Initialize the provider only if the INTERRUPTION_QUEUE environment variable is defined if v, ok := os.LookupEnv("INTERRUPTION_QUEUE"); ok { - awsEnv.SQSProvider = lo.Must(sqs.NewProvider(env.Context, servicesqs.New(session), v)) + awsEnv.SQSProvider = lo.Must(sqs.NewProvider(env.Context, awsEnv.SQSAPI, v)) + awsEnv.InterruptionQueue = v } return awsEnv } diff --git a/test/pkg/environment/aws/expectations.go b/test/pkg/environment/aws/expectations.go index 6703a651c17d..fd86666a8d8e 100644 --- a/test/pkg/environment/aws/expectations.go +++ b/test/pkg/environment/aws/expectations.go @@ -26,6 +26,7 @@ import ( "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/fis" "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/sqs" "github.com/aws/aws-sdk-go/service/ssm" "github.com/aws/aws-sdk-go/service/sts" "github.com/mitchellh/hashstructure/v2" @@ -33,7 +34,6 @@ import ( "go.uber.org/multierr" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" - coretest "sigs.k8s.io/karpenter/pkg/test" "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1" @@ -417,6 +417,14 @@ func (env *Environment) ExpectInstanceProfileDeleted(instanceProfileName, roleNa Expect(awserrors.IgnoreNotFound(err)).ToNot(HaveOccurred()) } +func (env *Environment) ExpectQueueURL(name string) string { + ret, err := env.SQSAPI.GetQueueUrlWithContext(env, &sqs.GetQueueUrlInput{ + QueueName: aws.String(name), + }) + Expect(err).ToNot(HaveOccurred()) + return aws.StringValue(ret.QueueUrl) +} + func ignoreAlreadyContainsRole(err error) error { if err != nil { if strings.Contains(err.Error(), "Cannot exceed quota for InstanceSessionsPerInstanceProfile") { diff --git a/test/suites/interruption/suite_test.go b/test/suites/interruption/suite_test.go index bd338cd39899..01d31605cd11 100644 --- a/test/suites/interruption/suite_test.go +++ b/test/suites/interruption/suite_test.go @@ -182,6 +182,39 @@ var _ = Describe("Interruption", func() { instanceID, err := utils.ParseInstanceID(node.Spec.ProviderID) Expect(err).ToNot(HaveOccurred()) + By("Creating a scheduled change health event in the SQS message queue") + env.ExpectMessagesCreated(scheduledChangeMessage(env.Region, "000000000000", instanceID)) + env.EventuallyExpectNotFoundAssertion(node).WithTimeout(time.Minute).Should(Succeed()) // shorten the timeout since we should react faster + env.EventuallyExpectHealthyPodCount(selector, 1) + }) + It("should terminate the node when receiving a scheduled change health event when using the interruption queue url", func() { + env.ExpectSettingsOverridden(v1.EnvVar{ + Name: "INTERRUPTION_QUEUE", + Value: env.ExpectQueueURL(env.InterruptionQueue), + }) + + By("Creating a single healthy node with a healthy deployment") + numPods := 1 + dep := coretest.Deployment(coretest.DeploymentOptions{ + Replicas: int32(numPods), + PodOptions: coretest.PodOptions{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{"app": "my-app"}, + }, + TerminationGracePeriodSeconds: ptr.Int64(0), + }, + }) + selector := labels.SelectorFromSet(dep.Spec.Selector.MatchLabels) + + env.ExpectCreated(nodeClass, nodePool, dep) + + env.EventuallyExpectHealthyPodCount(selector, numPods) + env.ExpectCreatedNodeCount("==", 1) + + node := env.Monitor.CreatedNodes()[0] + instanceID, err := utils.ParseInstanceID(node.Spec.ProviderID) + Expect(err).ToNot(HaveOccurred()) + By("Creating a scheduled change health event in the SQS message queue") env.ExpectMessagesCreated(scheduledChangeMessage(env.Region, "000000000000", instanceID)) env.EventuallyExpectNotFoundAssertion(node).WithTimeout(time.Minute).Should(Succeed()) // shorten the timeout since we should react faster