Skip to content

Commit

Permalink
chore: Convert all AWS providers to interfaces (#6001)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-innis authored Apr 8, 2024
1 parent 8a4458e commit af43cf8
Show file tree
Hide file tree
Showing 30 changed files with 257 additions and 194 deletions.
2 changes: 1 addition & 1 deletion hack/code/prices_gen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func main() {
// record prices for each region we are interested in
for _, region := range getAWSRegions(opts.partition) {
log.Println("fetching for", region)
pricingProvider := pricing.NewProvider(ctx, pricing.NewAPI(sess, region), ec2, region)
pricingProvider := pricing.NewDefaultProvider(ctx, pricing.NewAPI(sess, region), ec2, region)
controller := controllerspricing.NewController(pricingProvider)
_, err := controller.Reconcile(ctx, reconcile.Request{NamespacedName: types.NamespacedName{}})
if err != nil {
Expand Down
19 changes: 10 additions & 9 deletions pkg/cloudprovider/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,18 @@ import (
var _ cloudprovider.CloudProvider = (*CloudProvider)(nil)

type CloudProvider struct {
instanceTypeProvider *instancetype.Provider
instanceProvider *instance.Provider
kubeClient client.Client
amiProvider *amifamily.Provider
securityGroupProvider *securitygroup.Provider
subnetProvider *subnet.Provider
recorder events.Recorder
kubeClient client.Client
recorder events.Recorder

instanceTypeProvider instancetype.Provider
instanceProvider instance.Provider
amiProvider amifamily.Provider
securityGroupProvider securitygroup.Provider
subnetProvider subnet.Provider
}

func New(instanceTypeProvider *instancetype.Provider, instanceProvider *instance.Provider, recorder events.Recorder,
kubeClient client.Client, amiProvider *amifamily.Provider, securityGroupProvider *securitygroup.Provider, subnetProvider *subnet.Provider) *CloudProvider {
func New(instanceTypeProvider instancetype.Provider, instanceProvider instance.Provider, recorder events.Recorder,
kubeClient client.Client, amiProvider amifamily.Provider, securityGroupProvider securitygroup.Provider, subnetProvider subnet.Provider) *CloudProvider {
return &CloudProvider{
instanceTypeProvider: instanceTypeProvider,
instanceProvider: instanceProvider,
Expand Down
6 changes: 3 additions & 3 deletions pkg/controllers/controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ import (
)

func NewControllers(ctx context.Context, sess *session.Session, clk clock.Clock, kubeClient client.Client, recorder events.Recorder,
unavailableOfferings *cache.UnavailableOfferings, cloudProvider cloudprovider.CloudProvider, subnetProvider *subnet.Provider,
securityGroupProvider *securitygroup.Provider, instanceProfileProvider *instanceprofile.Provider, instanceProvider *instance.Provider,
pricingProvider *pricing.Provider, amiProvider *amifamily.Provider, launchTemplateProvider *launchtemplate.Provider) []controller.Controller {
unavailableOfferings *cache.UnavailableOfferings, cloudProvider cloudprovider.CloudProvider, subnetProvider subnet.Provider,
securityGroupProvider securitygroup.Provider, instanceProfileProvider instanceprofile.Provider, instanceProvider instance.Provider,
pricingProvider pricing.Provider, amiProvider amifamily.Provider, launchTemplateProvider launchtemplate.Provider) []controller.Controller {

controllers := []controller.Controller{
nodeclass.NewController(kubeClient, recorder, subnetProvider, securityGroupProvider, amiProvider, instanceProfileProvider, launchTemplateProvider),
Expand Down
4 changes: 2 additions & 2 deletions pkg/controllers/interruption/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ type Controller struct {
kubeClient client.Client
clk clock.Clock
recorder events.Recorder
sqsProvider *sqs.Provider
sqsProvider sqs.Provider
unavailableOfferingsCache *cache.UnavailableOfferings
parser *EventParser
cm *pretty.ChangeMonitor
}

func NewController(kubeClient client.Client, clk clock.Clock, recorder events.Recorder,
sqsProvider *sqs.Provider, unavailableOfferingsCache *cache.UnavailableOfferings) *Controller {
sqsProvider sqs.Provider, unavailableOfferingsCache *cache.UnavailableOfferings) *Controller {

return &Controller{
kubeClient: kubeClient,
Expand Down
5 changes: 3 additions & 2 deletions pkg/controllers/interruption/interruption_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
servicesqs "github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
"github.com/samber/lo"
"go.uber.org/multierr"
"go.uber.org/zap"
Expand Down Expand Up @@ -161,8 +162,8 @@ func benchmarkNotificationController(b *testing.B, messageCount int) {

type providerSet struct {
kubeClient client.Client
sqsAPI *servicesqs.SQS
sqsProvider *sqs.Provider
sqsAPI sqsiface.SQSAPI
sqsProvider sqs.Provider
}

func newProviders(ctx context.Context, kubeClient client.Client) providerSet {
Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/interruption/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ const (
var ctx context.Context
var env *coretest.Environment
var sqsapi *fake.SQSAPI
var sqsProvider *sqs.Provider
var sqsProvider *sqs.DefaultProvider
var unavailableOfferingsCache *awscache.UnavailableOfferings
var fakeClock *clock.FakeClock
var controller *interruption.Controller
Expand Down
4 changes: 2 additions & 2 deletions pkg/controllers/nodeclaim/tagging/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ import (

type Controller struct {
kubeClient client.Client
instanceProvider *instance.Provider
instanceProvider instance.Provider
}

func NewController(kubeClient client.Client, instanceProvider *instance.Provider) corecontroller.Controller {
func NewController(kubeClient client.Client, instanceProvider instance.Provider) corecontroller.Controller {
return corecontroller.Typed[*corev1beta1.NodeClaim](kubeClient, &Controller{
kubeClient: kubeClient,
instanceProvider: instanceProvider,
Expand Down
26 changes: 14 additions & 12 deletions pkg/controllers/nodeclass/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,23 @@ import (
var _ corecontroller.FinalizingTypedController[*v1beta1.EC2NodeClass] = (*Controller)(nil)

type Controller struct {
kubeClient client.Client
recorder events.Recorder
subnetProvider *subnet.Provider
securityGroupProvider *securitygroup.Provider
amiProvider *amifamily.Provider
instanceProfileProvider *instanceprofile.Provider
launchTemplateProvider *launchtemplate.Provider
kubeClient client.Client
recorder events.Recorder

subnetProvider subnet.Provider
securityGroupProvider securitygroup.Provider
amiProvider amifamily.Provider
instanceProfileProvider instanceprofile.Provider
launchTemplateProvider launchtemplate.Provider
}

func NewController(kubeClient client.Client, recorder events.Recorder, subnetProvider *subnet.Provider, securityGroupProvider *securitygroup.Provider,
amiProvider *amifamily.Provider, instanceProfileProvider *instanceprofile.Provider, launchTemplateProvider *launchtemplate.Provider) corecontroller.Controller {
func NewController(kubeClient client.Client, recorder events.Recorder, subnetProvider subnet.Provider, securityGroupProvider securitygroup.Provider,
amiProvider amifamily.Provider, instanceProfileProvider instanceprofile.Provider, launchTemplateProvider launchtemplate.Provider) corecontroller.Controller {

return corecontroller.Typed[*v1beta1.EC2NodeClass](kubeClient, &Controller{
kubeClient: kubeClient,
recorder: recorder,
kubeClient: kubeClient,
recorder: recorder,

subnetProvider: subnetProvider,
securityGroupProvider: securityGroupProvider,
amiProvider: amiProvider,
Expand Down Expand Up @@ -137,7 +139,7 @@ func (c *Controller) Finalize(ctx context.Context, nodeClass *v1beta1.EC2NodeCla
return reconcile.Result{}, fmt.Errorf("deleting instance profile, %w", err)
}
}
if err := c.launchTemplateProvider.DeleteLaunchTemplates(ctx, nodeClass); err != nil {
if err := c.launchTemplateProvider.DeleteAll(ctx, nodeClass); err != nil {
return reconcile.Result{}, fmt.Errorf("deleting launch templates, %w", err)
}
controllerutil.RemoveFinalizer(nodeClass, v1beta1.TerminationFinalizer)
Expand Down
4 changes: 2 additions & 2 deletions pkg/controllers/pricing/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ import (
)

type Controller struct {
pricingProvider *pricing.Provider
pricingProvider pricing.Provider
}

func NewController(pricingProvider *pricing.Provider) *Controller {
func NewController(pricingProvider pricing.Provider) *Controller {
return &Controller{
pricingProvider: pricingProvider,
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/controllers/pricing/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ var controller *controllerspricing.Controller
func TestAWS(t *testing.T) {
ctx = TestContextWithLogger(t)
RegisterFailHandler(Fail)
RunSpecs(t, "Provider/AWS")
RunSpecs(t, "Pricing")
}

var _ = BeforeSuite(func() {
Expand Down Expand Up @@ -84,7 +84,7 @@ var _ = Describe("Pricing", func() {
"should return correct static data for all partitions",
func(staticPricing map[string]map[string]float64) {
for region, prices := range staticPricing {
provider := pricing.NewProvider(ctx, awsEnv.PricingAPI, awsEnv.EC2API, region)
provider := pricing.NewDefaultProvider(ctx, awsEnv.PricingAPI, awsEnv.EC2API, region)
for instance, price := range prices {
val, ok := provider.OnDemandPrice(instance)
Expect(ok).To(BeTrue())
Expand Down Expand Up @@ -298,7 +298,7 @@ var _ = Describe("Pricing", func() {
Expect(price).To(BeNumerically("==", 1.10))
})
It("should update on-demand pricing with response from the pricing API when in the CN partition", func() {
tmpPricingProvider := pricing.NewProvider(ctx, awsEnv.PricingAPI, awsEnv.EC2API, "cn-anywhere-1")
tmpPricingProvider := pricing.NewDefaultProvider(ctx, awsEnv.PricingAPI, awsEnv.EC2API, "cn-anywhere-1")
tmpController := controllerspricing.NewController(tmpPricingProvider)

now := time.Now()
Expand Down
2 changes: 1 addition & 1 deletion pkg/fake/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func MakeInstances() []*ec2.InstanceTypeInfo {
ctx := options.ToContext(context.Background(), &options.Options{IsolatedVPC: true})
// Use keys from the static pricing data so that we guarantee pricing for the data
// Create uniform instance data so all of them schedule for a given pod
for _, it := range pricing.NewProvider(ctx, nil, nil, "us-east-1").InstanceTypes() {
for _, it := range pricing.NewDefaultProvider(ctx, nil, nil, "us-east-1").InstanceTypes() {
instanceTypes = append(instanceTypes, &ec2.InstanceTypeInfo{
InstanceType: aws.String(it),
ProcessorInfo: &ec2.ProcessorInfo{
Expand Down
34 changes: 17 additions & 17 deletions pkg/operator/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,16 @@ type Operator struct {
Session *session.Session
UnavailableOfferingsCache *awscache.UnavailableOfferings
EC2API ec2iface.EC2API
SubnetProvider *subnet.Provider
SecurityGroupProvider *securitygroup.Provider
InstanceProfileProvider *instanceprofile.Provider
AMIProvider *amifamily.Provider
SubnetProvider subnet.Provider
SecurityGroupProvider securitygroup.Provider
InstanceProfileProvider instanceprofile.Provider
AMIProvider amifamily.Provider
AMIResolver *amifamily.Resolver
LaunchTemplateProvider *launchtemplate.Provider
PricingProvider *pricing.Provider
VersionProvider *version.Provider
InstanceTypesProvider *instancetype.Provider
InstanceProvider *instance.Provider
LaunchTemplateProvider launchtemplate.Provider
PricingProvider pricing.Provider
VersionProvider version.Provider
InstanceTypesProvider instancetype.Provider
InstanceProvider instance.Provider
}

func NewOperator(ctx context.Context, operator *operator.Operator) (context.Context, *Operator) {
Expand Down Expand Up @@ -132,18 +132,18 @@ func NewOperator(ctx context.Context, operator *operator.Operator) (context.Cont
}

unavailableOfferingsCache := awscache.NewUnavailableOfferings()
subnetProvider := subnet.NewProvider(ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval))
securityGroupProvider := securitygroup.NewProvider(ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval))
subnetProvider := subnet.NewDefaultProvider(ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval))
securityGroupProvider := securitygroup.NewDefaultProvider(ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval))
instanceProfileProvider := instanceprofile.NewProvider(*sess.Config.Region, iam.New(sess), cache.New(awscache.InstanceProfileTTL, awscache.DefaultCleanupInterval))
pricingProvider := pricing.NewProvider(
pricingProvider := pricing.NewDefaultProvider(
ctx,
pricing.NewAPI(sess, *sess.Config.Region),
ec2api,
*sess.Config.Region,
)
versionProvider := version.NewProvider(operator.KubernetesInterface, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval))
amiProvider := amifamily.NewProvider(versionProvider, ssm.New(sess), ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval))
amiResolver := amifamily.New(amiProvider)
versionProvider := version.NewDefaultProvider(operator.KubernetesInterface, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval))
amiProvider := amifamily.NewDefaultProvider(versionProvider, ssm.New(sess), ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval))
amiResolver := amifamily.NewResolver(amiProvider)
launchTemplateProvider := launchtemplate.NewProvider(
ctx,
cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval),
Expand All @@ -158,15 +158,15 @@ func NewOperator(ctx context.Context, operator *operator.Operator) (context.Cont
kubeDNSIP,
clusterEndpoint,
)
instanceTypeProvider := instancetype.NewProvider(
instanceTypeProvider := instancetype.NewDefaultProvider(
*sess.Config.Region,
cache.New(awscache.InstanceTypesAndZonesTTL, awscache.DefaultCleanupInterval),
ec2api,
subnetProvider,
unavailableOfferingsCache,
pricingProvider,
)
instanceProvider := instance.NewProvider(
instanceProvider := instance.NewDefaultProvider(
ctx,
aws.StringValue(sess.Config.Region),
ec2api,
Expand Down
22 changes: 13 additions & 9 deletions pkg/providers/amifamily/ami.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,16 @@ import (
"sigs.k8s.io/karpenter/pkg/utils/pretty"
)

type Provider struct {
type Provider interface {
Get(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, options *Options) (AMIs, error)
}

type DefaultProvider struct {
cache *cache.Cache
ssm ssmiface.SSMAPI
ec2api ec2iface.EC2API
cm *pretty.ChangeMonitor
versionProvider *version.Provider
versionProvider version.Provider
}

type AMI struct {
Expand Down Expand Up @@ -101,8 +105,8 @@ func (a AMIs) MapToInstanceTypes(instanceTypes []*cloudprovider.InstanceType) ma
return amiIDs
}

func NewProvider(versionProvider *version.Provider, ssm ssmiface.SSMAPI, ec2api ec2iface.EC2API, cache *cache.Cache) *Provider {
return &Provider{
func NewDefaultProvider(versionProvider version.Provider, ssm ssmiface.SSMAPI, ec2api ec2iface.EC2API, cache *cache.Cache) *DefaultProvider {
return &DefaultProvider{
cache: cache,
ssm: ssm,
ec2api: ec2api,
Expand All @@ -112,7 +116,7 @@ func NewProvider(versionProvider *version.Provider, ssm ssmiface.SSMAPI, ec2api
}

// Get Returning a list of AMIs with its associated requirements
func (p *Provider) Get(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, options *Options) (AMIs, error) {
func (p *DefaultProvider) Get(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, options *Options) (AMIs, error) {
var err error
var amis AMIs
if len(nodeClass.Spec.AMISelectorTerms) == 0 {
Expand All @@ -133,7 +137,7 @@ func (p *Provider) Get(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, opt
return amis, nil
}

func (p *Provider) getDefaultAMIs(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, options *Options) (res AMIs, err error) {
func (p *DefaultProvider) getDefaultAMIs(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, options *Options) (res AMIs, err error) {
if images, ok := p.cache.Get(lo.FromPtr(nodeClass.Spec.AMIFamily)); ok {
return images.(AMIs), nil
}
Expand Down Expand Up @@ -171,7 +175,7 @@ func (p *Provider) getDefaultAMIs(ctx context.Context, nodeClass *v1beta1.EC2Nod
return res, nil
}

func (p *Provider) resolveSSMParameter(ctx context.Context, ssmQuery string) (string, error) {
func (p *DefaultProvider) resolveSSMParameter(ctx context.Context, ssmQuery string) (string, error) {
output, err := p.ssm.GetParameterWithContext(ctx, &ssm.GetParameterInput{Name: aws.String(ssmQuery)})
if err != nil {
return "", fmt.Errorf("getting ssm parameter %q, %w", ssmQuery, err)
Expand All @@ -180,7 +184,7 @@ func (p *Provider) resolveSSMParameter(ctx context.Context, ssmQuery string) (st
return ami, nil
}

func (p *Provider) getAMIs(ctx context.Context, terms []v1beta1.AMISelectorTerm) (AMIs, error) {
func (p *DefaultProvider) getAMIs(ctx context.Context, terms []v1beta1.AMISelectorTerm) (AMIs, error) {
filterAndOwnerSets := GetFilterAndOwnerSets(terms)
hash, err := hashstructure.Hash(filterAndOwnerSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true})
if err != nil {
Expand Down Expand Up @@ -279,7 +283,7 @@ func GetFilterAndOwnerSets(terms []v1beta1.AMISelectorTerm) (res []FiltersAndOwn
return res
}

func (p *Provider) getRequirementsFromImage(ec2Image *ec2.Image) scheduling.Requirements {
func (p *DefaultProvider) getRequirementsFromImage(ec2Image *ec2.Image) scheduling.Requirements {
requirements := scheduling.NewRequirements()
// Always add the architecture of an image as a requirement, irrespective of what's specified in EC2 tags.
architecture := *ec2Image.Architecture
Expand Down
6 changes: 3 additions & 3 deletions pkg/providers/amifamily/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var DefaultEBS = v1beta1.BlockDevice{

// Resolver is able to fill-in dynamic launch template parameters
type Resolver struct {
amiProvider *Provider
amiProvider Provider
}

// Options define the static launch template parameters
Expand Down Expand Up @@ -111,8 +111,8 @@ func (d DefaultFamily) FeatureFlags() FeatureFlags {
}
}

// New constructs a new launch template Resolver
func New(amiProvider *Provider) *Resolver {
// NewResolver constructs a new launch template Resolver
func NewResolver(amiProvider Provider) *Resolver {
return &Resolver{
amiProvider: amiProvider,
}
Expand Down
Loading

0 comments on commit af43cf8

Please sign in to comment.