From bf816f5f1734c55422c0d9c393e192cb5190b9f6 Mon Sep 17 00:00:00 2001 From: Jonathan Innis Date: Fri, 10 May 2024 14:51:18 -0500 Subject: [PATCH] fix: Ensure shallow copy of data when returning back cached data (#6167) --- go.mod | 6 +- go.sum | 12 +-- .../karpenter.k8s.aws_ec2nodeclasses.yaml | 2 +- pkg/providers/amifamily/ami.go | 8 +- pkg/providers/amifamily/suite_test.go | 52 ++++++++++- pkg/providers/instancetype/instancetype.go | 4 +- pkg/providers/instancetype/suite_test.go | 37 +++++++- pkg/providers/securitygroup/securitygroup.go | 4 +- pkg/providers/securitygroup/suite_test.go | 68 ++++++++++++++ pkg/providers/subnet/subnet.go | 4 +- pkg/providers/subnet/suite_test.go | 89 +++++++++++++++++++ 11 files changed, 266 insertions(+), 20 deletions(-) diff --git a/go.mod b/go.mod index 13892951405b..ceed5fe1bde7 100644 --- a/go.mod +++ b/go.mod @@ -89,10 +89,10 @@ require ( go.opencensus.io v0.24.0 // indirect go.uber.org/automaxprocs v1.5.3 // indirect golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect - golang.org/x/net v0.19.0 // indirect + golang.org/x/net v0.23.0 // indirect golang.org/x/oauth2 v0.13.0 // indirect - golang.org/x/sys v0.16.0 // indirect - golang.org/x/term v0.15.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/term v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.16.1 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index e08ffd88a75e..05aa672dbb80 100644 --- a/go.sum +++ b/go.sum @@ -461,8 +461,8 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -531,14 +531,14 @@ golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= -golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= -golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= +golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= +golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/pkg/apis/crds/karpenter.k8s.aws_ec2nodeclasses.yaml b/pkg/apis/crds/karpenter.k8s.aws_ec2nodeclasses.yaml index af4acba6004c..79d39053c493 100644 --- a/pkg/apis/crds/karpenter.k8s.aws_ec2nodeclasses.yaml +++ b/pkg/apis/crds/karpenter.k8s.aws_ec2nodeclasses.yaml @@ -3,7 +3,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.14.0 + controller-gen.kubebuilder.io/version: v0.15.0 name: ec2nodeclasses.karpenter.k8s.aws spec: group: karpenter.k8s.aws diff --git a/pkg/providers/amifamily/ami.go b/pkg/providers/amifamily/ami.go index 83bb807a6a6d..e163d114cc1a 100644 --- a/pkg/providers/amifamily/ami.go +++ b/pkg/providers/amifamily/ami.go @@ -135,7 +135,9 @@ func (p *Provider) Get(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, opt func (p *Provider) getDefaultAMIs(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, options *Options) (res AMIs, err error) { if images, ok := p.cache.Get(lo.FromPtr(nodeClass.Spec.AMIFamily)); ok { - return images.(AMIs), nil + // Ensure what's returned from this function is a deep-copy of AMIs so alterations + // to the data don't affect the original + return append(AMIs{}, images.(AMIs)...), nil } amiFamily := GetAMIFamily(nodeClass.Spec.AMIFamily, options) kubernetesVersion, err := p.versionProvider.Get(ctx) @@ -187,7 +189,9 @@ func (p *Provider) getAMIs(ctx context.Context, terms []v1beta1.AMISelectorTerm) return nil, err } if images, ok := p.cache.Get(fmt.Sprintf("%d", hash)); ok { - return images.(AMIs), nil + // Ensure what's returned from this function is a deep-copy of AMIs so alterations + // to the data don't affect the original + return append(AMIs{}, images.(AMIs)...), nil } images := map[uint64]AMI{} for _, filtersAndOwners := range filterAndOwnerSets { diff --git a/pkg/providers/amifamily/suite_test.go b/pkg/providers/amifamily/suite_test.go index c7f5cb2da1a2..421abae641c5 100644 --- a/pkg/providers/amifamily/suite_test.go +++ b/pkg/providers/amifamily/suite_test.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "sort" + "sync" "testing" "time" @@ -74,7 +75,7 @@ var _ = BeforeEach(func() { { Name: aws.String(amd64AMI), ImageId: aws.String("amd64-ami-id"), - CreationDate: aws.String(time.Now().Format(time.RFC3339)), + CreationDate: aws.String(time.Time{}.Format(time.RFC3339)), Architecture: aws.String("x86_64"), Tags: []*ec2.Tag{ {Key: aws.String("Name"), Value: aws.String(amd64AMI)}, @@ -84,7 +85,7 @@ var _ = BeforeEach(func() { { Name: aws.String(arm64AMI), ImageId: aws.String("arm64-ami-id"), - CreationDate: aws.String(time.Now().Add(time.Minute).Format(time.RFC3339)), + CreationDate: aws.String(time.Time{}.Add(time.Minute).Format(time.RFC3339)), Architecture: aws.String("arm64"), Tags: []*ec2.Tag{ {Key: aws.String("Name"), Value: aws.String(arm64AMI)}, @@ -94,7 +95,7 @@ var _ = BeforeEach(func() { { Name: aws.String(amd64NvidiaAMI), ImageId: aws.String("amd64-nvidia-ami-id"), - CreationDate: aws.String(time.Now().Add(2 * time.Minute).Format(time.RFC3339)), + CreationDate: aws.String(time.Time{}.Add(2 * time.Minute).Format(time.RFC3339)), Architecture: aws.String("x86_64"), Tags: []*ec2.Tag{ {Key: aws.String("Name"), Value: aws.String(amd64NvidiaAMI)}, @@ -104,7 +105,7 @@ var _ = BeforeEach(func() { { Name: aws.String(arm64NvidiaAMI), ImageId: aws.String("arm64-nvidia-ami-id"), - CreationDate: aws.String(time.Now().Add(2 * time.Minute).Format(time.RFC3339)), + CreationDate: aws.String(time.Time{}.Add(2 * time.Minute).Format(time.RFC3339)), Architecture: aws.String("arm64"), Tags: []*ec2.Tag{ {Key: aws.String("Name"), Value: aws.String(arm64NvidiaAMI)}, @@ -186,6 +187,49 @@ var _ = Describe("AMIProvider", func() { Expect(err).ToNot(HaveOccurred()) Expect(amis).To(HaveLen(0)) }) + It("should not cause data races when calling Get() simultaneously", func() { + nodeClass.Spec.AMISelectorTerms = []v1beta1.AMISelectorTerm{ + { + ID: "amd64-ami-id", + }, + { + ID: "arm64-ami-id", + }, + } + wg := sync.WaitGroup{} + for i := 0; i < 10000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + defer GinkgoRecover() + images, err := awsEnv.AMIProvider.Get(ctx, nodeClass, &amifamily.Options{}) + Expect(err).ToNot(HaveOccurred()) + + Expect(images).To(HaveLen(2)) + // Sort everything in parallel and ensure that we don't get data races + images.Sort() + Expect(images).To(BeEquivalentTo([]amifamily.AMI{ + { + Name: arm64AMI, + AmiID: "arm64-ami-id", + CreationDate: time.Time{}.Add(time.Minute).Format(time.RFC3339), + Requirements: scheduling.NewLabelRequirements(map[string]string{ + v1.LabelArchStable: corev1beta1.ArchitectureArm64, + }), + }, + { + Name: amd64AMI, + AmiID: "amd64-ami-id", + CreationDate: time.Time{}.Format(time.RFC3339), + Requirements: scheduling.NewLabelRequirements(map[string]string{ + v1.LabelArchStable: corev1beta1.ArchitectureAmd64, + }), + }, + })) + }() + } + wg.Wait() + }) Context("SSM Alias Missing", func() { It("should succeed to partially resolve AMIs if all SSM aliases don't exist (Al2)", func() { nodeClass.Spec.AMIFamily = &v1beta1.AMIFamilyAL2 diff --git a/pkg/providers/instancetype/instancetype.go b/pkg/providers/instancetype/instancetype.go index a4f007aa43e2..e822654846d7 100644 --- a/pkg/providers/instancetype/instancetype.go +++ b/pkg/providers/instancetype/instancetype.go @@ -140,7 +140,9 @@ func (p *Provider) List(ctx context.Context, kc *corev1beta1.KubeletConfiguratio systemReservedHash, ) if item, ok := p.cache.Get(key); ok { - return item.([]*cloudprovider.InstanceType), nil + // Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself) + // so that modifications to the ordering of the data don't affect the original + return append([]*cloudprovider.InstanceType{}, item.([]*cloudprovider.InstanceType)...), nil } result := lo.Map(instanceTypes, func(i *ec2.InstanceTypeInfo, _ int) *cloudprovider.InstanceType { return NewInstanceType(ctx, i, kc, p.region, nodeClass, p.createOfferings(ctx, i, instanceTypeOfferings[aws.StringValue(i.InstanceType)], zones, subnetZones)) diff --git a/pkg/providers/instancetype/suite_test.go b/pkg/providers/instancetype/suite_test.go index 49a01ca4df3d..f7134570104f 100644 --- a/pkg/providers/instancetype/suite_test.go +++ b/pkg/providers/instancetype/suite_test.go @@ -21,6 +21,7 @@ import ( "net" "sort" "strings" + "sync" "testing" "time" @@ -736,7 +737,6 @@ var _ = Describe("InstanceTypes", func() { ExpectScheduled(ctx, env.Client, pod) }) - Context("Overhead", func() { var info *ec2.InstanceTypeInfo BeforeEach(func() { @@ -1606,6 +1606,41 @@ var _ = Describe("InstanceTypes", func() { }) }) }) + It("should not cause data races when calling List() simultaneously", func() { + mu := sync.RWMutex{} + var instanceTypeOrder []string + wg := sync.WaitGroup{} + for i := 0; i < 10000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + defer GinkgoRecover() + instanceTypes, err := awsEnv.InstanceTypesProvider.List(ctx, &corev1beta1.KubeletConfiguration{}, nodeClass) + Expect(err).ToNot(HaveOccurred()) + + // Sort everything in parallel and ensure that we don't get data races + sort.Slice(instanceTypes, func(i, j int) bool { + return instanceTypes[i].Name < instanceTypes[j].Name + }) + // Get the ordering of the instance types based on name + tempInstanceTypeOrder := lo.Map(instanceTypes, func(i *corecloudprovider.InstanceType, _ int) string { + return i.Name + }) + // Expect that all the elements in the instance type list are unique + Expect(lo.Uniq(tempInstanceTypeOrder)).To(HaveLen(len(tempInstanceTypeOrder))) + + // We have to lock since we are doing simultaneous access to this value + mu.Lock() + if len(instanceTypeOrder) == 0 { + instanceTypeOrder = tempInstanceTypeOrder + } else { + Expect(tempInstanceTypeOrder).To(BeEquivalentTo(instanceTypeOrder)) + } + mu.Unlock() + }() + } + wg.Wait() + }) }) // generateSpotPricing creates a spot price history output for use in a mock that has all spot offerings discounted by 50% diff --git a/pkg/providers/securitygroup/securitygroup.go b/pkg/providers/securitygroup/securitygroup.go index c6efc9acca70..dd1f5ebf1429 100644 --- a/pkg/providers/securitygroup/securitygroup.go +++ b/pkg/providers/securitygroup/securitygroup.go @@ -77,7 +77,9 @@ func (p *Provider) getSecurityGroups(ctx context.Context, filterSets [][]*ec2.Fi return nil, err } if sg, ok := p.cache.Get(fmt.Sprint(hash)); ok { - return sg.([]*ec2.SecurityGroup), nil + // Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself) + // so that modifications to the ordering of the data don't affect the original + return append([]*ec2.SecurityGroup{}, sg.([]*ec2.SecurityGroup)...), nil } securityGroups := map[string]*ec2.SecurityGroup{} for _, filters := range filterSets { diff --git a/pkg/providers/securitygroup/suite_test.go b/pkg/providers/securitygroup/suite_test.go index 7756f747222d..30433be3744e 100644 --- a/pkg/providers/securitygroup/suite_test.go +++ b/pkg/providers/securitygroup/suite_test.go @@ -16,6 +16,8 @@ package securitygroup_test import ( "context" + "sort" + "sync" "testing" "github.com/aws/aws-sdk-go/aws" @@ -265,6 +267,72 @@ var _ = Describe("SecurityGroupProvider", func() { }, }, securityGroups) }) + It("should not cause data races when calling List() simultaneously", func() { + wg := sync.WaitGroup{} + for i := 0; i < 10000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + defer GinkgoRecover() + securityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeClass) + Expect(err).ToNot(HaveOccurred()) + + Expect(securityGroups).To(HaveLen(3)) + // Sort everything in parallel and ensure that we don't get data races + sort.Slice(securityGroups, func(i, j int) bool { + return *securityGroups[i].GroupId < *securityGroups[j].GroupId + }) + Expect(securityGroups).To(BeEquivalentTo([]*ec2.SecurityGroup{ + { + GroupId: lo.ToPtr("sg-test1"), + GroupName: lo.ToPtr("securityGroup-test1"), + Tags: []*ec2.Tag{ + { + Key: lo.ToPtr("Name"), + Value: lo.ToPtr("test-security-group-1"), + }, + { + Key: lo.ToPtr("foo"), + Value: lo.ToPtr("bar"), + }, + }, + }, + { + GroupId: lo.ToPtr("sg-test2"), + GroupName: lo.ToPtr("securityGroup-test2"), + Tags: []*ec2.Tag{ + { + Key: lo.ToPtr("Name"), + Value: lo.ToPtr("test-security-group-2"), + }, + { + Key: lo.ToPtr("foo"), + Value: lo.ToPtr("bar"), + }, + }, + }, + { + GroupId: lo.ToPtr("sg-test3"), + GroupName: lo.ToPtr("securityGroup-test3"), + Tags: []*ec2.Tag{ + { + Key: lo.ToPtr("Name"), + Value: lo.ToPtr("test-security-group-3"), + }, + { + Key: lo.ToPtr("TestTag"), + }, + { + Key: lo.ToPtr("foo"), + Value: lo.ToPtr("bar"), + }, + }, + }, + })) + }() + } + wg.Wait() + }) }) func ExpectConsistsOfSecurityGroups(expected, actual []*ec2.SecurityGroup) { diff --git a/pkg/providers/subnet/subnet.go b/pkg/providers/subnet/subnet.go index b343bfc5af3e..564cbe7788f9 100644 --- a/pkg/providers/subnet/subnet.go +++ b/pkg/providers/subnet/subnet.go @@ -67,7 +67,9 @@ func (p *Provider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeClass) ([ return nil, err } if subnets, ok := p.cache.Get(fmt.Sprint(hash)); ok { - return subnets.([]*ec2.Subnet), nil + // Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself) + // so that modifications to the ordering of the data don't affect the original + return append([]*ec2.Subnet{}, subnets.([]*ec2.Subnet)...), nil } // Ensure that all the subnets that are returned here are unique diff --git a/pkg/providers/subnet/suite_test.go b/pkg/providers/subnet/suite_test.go index 31daa9058f31..6840a22a3ffd 100644 --- a/pkg/providers/subnet/suite_test.go +++ b/pkg/providers/subnet/suite_test.go @@ -16,6 +16,8 @@ package subnet_test import ( "context" + "sort" + "sync" "testing" "github.com/aws/aws-sdk-go/aws" @@ -240,6 +242,93 @@ var _ = Describe("SubnetProvider", func() { Expect(onlyPrivate).To(BeTrue()) }) }) + It("should not cause data races when calling List() simultaneously", func() { + wg := sync.WaitGroup{} + for i := 0; i < 10000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + defer GinkgoRecover() + subnets, err := awsEnv.SubnetProvider.List(ctx, nodeClass) + Expect(err).ToNot(HaveOccurred()) + + Expect(subnets).To(HaveLen(4)) + // Sort everything in parallel and ensure that we don't get data races + sort.Slice(subnets, func(i, j int) bool { + if int(*subnets[i].AvailableIpAddressCount) != int(*subnets[j].AvailableIpAddressCount) { + return int(*subnets[i].AvailableIpAddressCount) > int(*subnets[j].AvailableIpAddressCount) + } + return *subnets[i].SubnetId < *subnets[j].SubnetId + }) + Expect(subnets).To(BeEquivalentTo([]*ec2.Subnet{ + { + AvailabilityZone: lo.ToPtr("test-zone-1a"), + AvailableIpAddressCount: lo.ToPtr[int64](100), + SubnetId: lo.ToPtr("subnet-test1"), + MapPublicIpOnLaunch: lo.ToPtr(false), + Tags: []*ec2.Tag{ + { + Key: lo.ToPtr("Name"), + Value: lo.ToPtr("test-subnet-1"), + }, + { + Key: lo.ToPtr("foo"), + Value: lo.ToPtr("bar"), + }, + }, + }, + { + AvailabilityZone: lo.ToPtr("test-zone-1b"), + AvailableIpAddressCount: lo.ToPtr[int64](100), + MapPublicIpOnLaunch: lo.ToPtr(true), + SubnetId: lo.ToPtr("subnet-test2"), + + Tags: []*ec2.Tag{ + { + Key: lo.ToPtr("Name"), + Value: lo.ToPtr("test-subnet-2"), + }, + { + Key: lo.ToPtr("foo"), + Value: lo.ToPtr("bar"), + }, + }, + }, + { + AvailabilityZone: lo.ToPtr("test-zone-1c"), + AvailableIpAddressCount: lo.ToPtr[int64](100), + SubnetId: lo.ToPtr("subnet-test3"), + Tags: []*ec2.Tag{ + { + Key: lo.ToPtr("Name"), + Value: lo.ToPtr("test-subnet-3"), + }, + { + Key: lo.ToPtr("TestTag"), + }, + { + Key: lo.ToPtr("foo"), + Value: lo.ToPtr("bar"), + }, + }, + }, + { + AvailabilityZone: lo.ToPtr("test-zone-1a-local"), + AvailableIpAddressCount: lo.ToPtr[int64](100), + SubnetId: lo.ToPtr("subnet-test4"), + MapPublicIpOnLaunch: lo.ToPtr(true), + Tags: []*ec2.Tag{ + { + Key: lo.ToPtr("Name"), + Value: lo.ToPtr("test-subnet-4"), + }, + }, + }, + })) + }() + } + wg.Wait() + }) }) func ExpectConsistsOfSubnets(expected, actual []*ec2.Subnet) {