Skip to content

Commit

Permalink
Allow passing queueURL through interruption queue
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-innis committed Feb 1, 2024
1 parent c9ef004 commit 4de0c50
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 9 deletions.
24 changes: 17 additions & 7 deletions pkg/providers/sqs/sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sqs"
Expand All @@ -31,17 +32,26 @@ type Provider struct {
url string
}

func NewProvider(ctx context.Context, client sqsiface.SQSAPI, queueName string) (*Provider, error) {
ret, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{
QueueName: aws.String(queueName),
})
if err != nil {
return nil, fmt.Errorf("fetching queue url, %w", err)
// NewProvider builds a Provider for the interruption queue identified by
// interruptionQueue, which may be either a queue name or a full queue URL.
//
// A value that begins with "https://" is treated as the queue URL: it is
// stored unchanged and the queue name is taken to be the text following the
// last "/", so no SQS call is made. Any other value is treated as a queue
// name and is resolved to its URL with GetQueueUrlWithContext; an error from
// that call is returned wrapped.
func NewProvider(ctx context.Context, client sqsiface.SQSAPI, interruptionQueue string) (*Provider, error) {
	if strings.HasPrefix(interruptionQueue, "https://") {
		// The "https://" prefix guarantees at least one "/", so LastIndex
		// always finds a separator; everything after it is the queue name.
		name := interruptionQueue[strings.LastIndex(interruptionQueue, "/")+1:]
		return &Provider{
			client: client,
			name:   name,
			url:    interruptionQueue,
		}, nil
	}

	// A bare queue name was supplied, so ask SQS for the matching URL.
	ret, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{
		QueueName: aws.String(interruptionQueue),
	})
	if err != nil {
		return nil, fmt.Errorf("fetching queue url, %w", err)
	}
	return &Provider{
		client: client,
		name:   interruptionQueue,
		url:    aws.StringValue(ret.QueueUrl),
	}, nil
}

Expand Down
5 changes: 4 additions & 1 deletion test/pkg/environment/aws/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type Environment struct {
IAMAPI *iam.IAM
FISAPI *fis.FIS
EKSAPI *eks.EKS
SQSAPI *servicesqs.SQS
TimeStreamAPI timestreamwriteiface.TimestreamWriteAPI

SQSProvider *sqs.Provider
Expand All @@ -94,6 +95,7 @@ func NewEnvironment(t *testing.T) *Environment {
Region: *session.Config.Region,
Environment: env,

SQSAPI: servicesqs.New(session),
STSAPI: sts.New(session),
EC2API: ec2.New(session),
SSMAPI: ssm.New(session),
Expand All @@ -107,7 +109,8 @@ func NewEnvironment(t *testing.T) *Environment {
}
// Initialize the provider only if the INTERRUPTION_QUEUE environment variable is defined
if v, ok := os.LookupEnv("INTERRUPTION_QUEUE"); ok {
awsEnv.SQSProvider = lo.Must(sqs.NewProvider(env.Context, servicesqs.New(session), v))
awsEnv.SQSProvider = lo.Must(sqs.NewProvider(env.Context, awsEnv.SQSAPI, v))
awsEnv.InterruptionQueue = v
}
return awsEnv
}
Expand Down
10 changes: 9 additions & 1 deletion test/pkg/environment/aws/expectations.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ import (
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/fis"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/mitchellh/hashstructure/v2"
"github.com/samber/lo"
"go.uber.org/multierr"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"

coretest "sigs.k8s.io/karpenter/pkg/test"

"github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1"
Expand Down Expand Up @@ -417,6 +417,14 @@ func (env *Environment) ExpectInstanceProfileDeleted(instanceProfileName, roleNa
Expect(awserrors.IgnoreNotFound(err)).ToNot(HaveOccurred())
}

// ExpectQueueURL looks up the URL of the SQS queue called name through the
// environment's SQS client and returns it. The lookup is asserted to succeed,
// so a failed GetQueueUrl call fails the running test.
func (env *Environment) ExpectQueueURL(name string) string {
	input := &sqs.GetQueueUrlInput{QueueName: aws.String(name)}
	out, err := env.SQSAPI.GetQueueUrlWithContext(env, input)
	Expect(err).ToNot(HaveOccurred())
	return aws.StringValue(out.QueueUrl)
}

func ignoreAlreadyContainsRole(err error) error {
if err != nil {
if strings.Contains(err.Error(), "Cannot exceed quota for InstanceSessionsPerInstanceProfile") {
Expand Down
33 changes: 33 additions & 0 deletions test/suites/interruption/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,39 @@ var _ = Describe("Interruption", func() {
instanceID, err := utils.ParseInstanceID(node.Spec.ProviderID)
Expect(err).ToNot(HaveOccurred())

By("Creating a scheduled change health event in the SQS message queue")
env.ExpectMessagesCreated(scheduledChangeMessage(env.Region, "000000000000", instanceID))
env.EventuallyExpectNotFoundAssertion(node).WithTimeout(time.Minute).Should(Succeed()) // shorten the timeout since we should react faster
env.EventuallyExpectHealthyPodCount(selector, 1)
})
It("should terminate the node when receiving a scheduled change health event when using the interruption queue url", func() {
env.ExpectSettingsOverridden(v1.EnvVar{
Name: "INTERRUPTION_QUEUE",
Value: env.ExpectQueueURL(env.InterruptionQueue),
})

By("Creating a single healthy node with a healthy deployment")
numPods := 1
dep := coretest.Deployment(coretest.DeploymentOptions{
Replicas: int32(numPods),
PodOptions: coretest.PodOptions{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{"app": "my-app"},
},
TerminationGracePeriodSeconds: ptr.Int64(0),
},
})
selector := labels.SelectorFromSet(dep.Spec.Selector.MatchLabels)

env.ExpectCreated(nodeClass, nodePool, dep)

env.EventuallyExpectHealthyPodCount(selector, numPods)
env.ExpectCreatedNodeCount("==", 1)

node := env.Monitor.CreatedNodes()[0]
instanceID, err := utils.ParseInstanceID(node.Spec.ProviderID)
Expect(err).ToNot(HaveOccurred())

By("Creating a scheduled change health event in the SQS message queue")
env.ExpectMessagesCreated(scheduledChangeMessage(env.Region, "000000000000", instanceID))
env.EventuallyExpectNotFoundAssertion(node).WithTimeout(time.Minute).Should(Succeed()) // shorten the timeout since we should react faster
Expand Down

0 comments on commit 4de0c50

Please sign in to comment.