diff --git a/pkg/providers/amifamily/al2.go b/pkg/providers/amifamily/al2.go index 035fd4459c5a..ebd385250dc7 100644 --- a/pkg/providers/amifamily/al2.go +++ b/pkg/providers/amifamily/al2.go @@ -69,29 +69,14 @@ func (a AL2) DescribeImageQuery(ctx context.Context, ssmProvider ssm.Provider, k return DescribeImageQuery{}, fmt.Errorf(`failed to discover any AMIs for alias "al2@%s"`, amiVersion) } - // Only inject requirements if we were able to discover accelerated AMIs. If we weren't, we should use the standard - // AMI for all nodes. - hasAcceleratedAMIs := lo.ContainsBy(lo.Values(ids), func(variants []Variant) bool { - for _, v := range variants { - if v != VariantStandard { - return true - } - } - return false - }) - requirements := map[string][]scheduling.Requirements{} - if hasAcceleratedAMIs { - requirements = lo.MapValues(ids, func(variants []Variant, _ string) []scheduling.Requirements { - return lo.Map(variants, func(v Variant, _ int) scheduling.Requirements { return v.Requirements() }) - }) - } - return DescribeImageQuery{ Filters: []*ec2.Filter{{ Name: lo.ToPtr("image-id"), Values: lo.ToSlicePtr(lo.Keys(ids)), }}, - KnownRequirements: requirements, + KnownRequirements: lo.MapValues(ids, func(variants []Variant, _ string) []scheduling.Requirements { + return lo.Map(variants, func(v Variant, _ int) scheduling.Requirements { return v.Requirements() }) + }), }, nil } diff --git a/pkg/providers/amifamily/al2023.go b/pkg/providers/amifamily/al2023.go index 591c0d98512b..4626b21198f6 100644 --- a/pkg/providers/amifamily/al2023.go +++ b/pkg/providers/amifamily/al2023.go @@ -37,15 +37,11 @@ type AL2023 struct { func (a AL2023) DescribeImageQuery(ctx context.Context, ssmProvider ssm.Provider, k8sVersion string, amiVersion string) (DescribeImageQuery, error) { ids := map[string]Variant{} - for _, arch := range []string{ - "x86_64", - "arm64", + for arch, variants := range map[string][]Variant{ + "x86_64": []Variant{VariantStandard, VariantNvidia, VariantNeuron}, + "arm64": []Variant{VariantStandard}, } { - for _, variant := range []Variant{ - VariantStandard, - VariantNvidia, - VariantNeuron, - } { + for _, variant := range variants { path := a.resolvePath(arch, string(variant), k8sVersion, amiVersion) imageID, err := ssmProvider.Get(ctx, path) if err != nil { @@ -59,24 +55,14 @@ func (a AL2023) DescribeImageQuery(ctx context.Context, ssmProvider ssm.Provider return DescribeImageQuery{}, fmt.Errorf(`failed to discover AMIs for alias "al2023@%s"`, amiVersion) } - // Only inject requirements if we were able to discover accelerated AMIs. If we weren't, we should use the standard - // AMI for all nodes. - hasAcceleratedAMIs := lo.ContainsBy(lo.Values(ids), func(v Variant) bool { - return v != VariantStandard - }) - requirements := map[string][]scheduling.Requirements{} - if hasAcceleratedAMIs { - requirements = lo.MapValues(ids, func(v Variant, _ string) []scheduling.Requirements { - return []scheduling.Requirements{v.Requirements()} - }) - } - return DescribeImageQuery{ Filters: []*ec2.Filter{{ Name: lo.ToPtr("image-id"), Values: lo.ToSlicePtr(lo.Keys(ids)), }}, - KnownRequirements: requirements, + KnownRequirements: lo.MapValues(ids, func(v Variant, _ string) []scheduling.Requirements { + return []scheduling.Requirements{v.Requirements()} + }), }, nil } diff --git a/pkg/providers/amifamily/bottlerocket.go b/pkg/providers/amifamily/bottlerocket.go index bf5873ef4607..536f1e8d4972 100644 --- a/pkg/providers/amifamily/bottlerocket.go +++ b/pkg/providers/amifamily/bottlerocket.go @@ -57,29 +57,14 @@ func (b Bottlerocket) DescribeImageQuery(ctx context.Context, ssmProvider ssm.Pr return DescribeImageQuery{}, fmt.Errorf(`failed to discover any AMIs for alias "bottlerocket@%s"`, amiVersion) } - // Only inject requirements if we were able to discover accelerated AMIs. If we weren't, we should use the standard - // AMI for all nodes. - hasAcceleratedAMIs := lo.ContainsBy(lo.Values(ids), func(variants []Variant) bool { - for _, v := range variants { - if v != VariantStandard { - return true - } - } - return false - }) - requirements := map[string][]scheduling.Requirements{} - if hasAcceleratedAMIs { - requirements = lo.MapValues(ids, func(variants []Variant, _ string) []scheduling.Requirements { - return lo.Map(variants, func(v Variant, _ int) scheduling.Requirements { return v.Requirements() }) - }) - } - return DescribeImageQuery{ Filters: []*ec2.Filter{{ Name: lo.ToPtr("image-id"), Values: lo.ToSlicePtr(lo.Keys(ids)), }}, - KnownRequirements: requirements, + KnownRequirements: lo.MapValues(ids, func(variants []Variant, _ string) []scheduling.Requirements { + return lo.Map(variants, func(v Variant, _ int) scheduling.Requirements { return v.Requirements() }) + }), }, nil }