diff --git a/pkg/providers/pricing/pricing.go b/pkg/providers/pricing/pricing.go index 7edc92edba20..53c9f99991e2 100644 --- a/pkg/providers/pricing/pricing.go +++ b/pkg/providers/pricing/pricing.go @@ -54,8 +54,10 @@ type Provider struct { region string cm *pretty.ChangeMonitor - mu sync.RWMutex - onDemandPrices map[string]float64 + muOnDemand sync.RWMutex + onDemandPrices map[string]float64 + + muSpot sync.RWMutex spotPrices map[string]zonal spotPricingUpdated bool } @@ -109,16 +111,18 @@ func NewProvider(_ context.Context, pricing pricingiface.PricingAPI, ec2Api ec2i // InstanceTypes returns the list of all instance types for which either a spot or on-demand price is known. func (p *Provider) InstanceTypes() []string { - p.mu.RLock() - defer p.mu.RUnlock() + p.muOnDemand.RLock() + p.muSpot.RLock() + defer p.muOnDemand.RUnlock() + defer p.muSpot.RUnlock() return lo.Union(lo.Keys(p.onDemandPrices), lo.Keys(p.spotPrices)) } // OnDemandPrice returns the last known on-demand price for a given instance type, returning an error if there is no // known on-demand pricing for the instance type. func (p *Provider) OnDemandPrice(instanceType string) (float64, bool) { - p.mu.RLock() - defer p.mu.RUnlock() + p.muOnDemand.RLock() + defer p.muOnDemand.RUnlock() price, ok := p.onDemandPrices[instanceType] if !ok { return 0.0, false @@ -129,8 +133,8 @@ func (p *Provider) OnDemandPrice(instanceType string) (float64, bool) { // SpotPrice returns the last known spot price for a given instance type and zone, returning an error // if there is no known spot pricing for that instance type or zone func (p *Provider) SpotPrice(instanceType string, zone string) (float64, bool) { - p.mu.RLock() - defer p.mu.RUnlock() + p.muSpot.RLock() + defer p.muSpot.RUnlock() if val, ok := p.spotPrices[instanceType]; ok { if !p.spotPricingUpdated { return val.defaultPrice, true @@ -158,6 +162,9 @@ func (p *Provider) UpdateOnDemandPricing(ctx context.Context) error { return nil } + p.muOnDemand.Lock() + defer p.muOnDemand.Unlock() + wg.Add(1) go func() { defer wg.Done() @@ -193,8 +200,6 @@ func (p *Provider) UpdateOnDemandPricing(ctx context.Context) error { wg.Wait() - p.mu.Lock() - defer p.mu.Unlock() err := multierr.Append(onDemandErr, onDemandMetalErr) if err != nil { return fmt.Errorf("retreiving on-demand pricing data, %w", err) @@ -245,18 +250,51 @@ func (p *Provider) fetchOnDemandPricing(ctx context.Context, additionalFilters . Value: aws.String("OnDemand"), }}, additionalFilters...) - if err := p.pricing.GetProductsPagesWithContext(ctx, &pricing.GetProductsInput{ - Filters: filters, - ServiceCode: aws.String("AmazonEC2")}, p.onDemandPage(prices)); err != nil { + + err := p.pricing.GetProductsPagesWithContext( + ctx, + &pricing.GetProductsInput{ + Filters: filters, + ServiceCode: aws.String("AmazonEC2"), + }, + p.onDemandPage(ctx, prices), + ) + if err != nil { return nil, err } + return prices, nil } +func (p *Provider) spotPage(ctx context.Context, prices map[string]map[string]float64) func(output *ec2.DescribeSpotPriceHistoryOutput, b bool) bool { + return func(output *ec2.DescribeSpotPriceHistoryOutput, b bool) bool { + for _, sph := range output.SpotPriceHistory { + spotPriceStr := aws.StringValue(sph.SpotPrice) + spotPrice, err := strconv.ParseFloat(spotPriceStr, 64) + // these errors shouldn't occur, but if pricing API does have an error, we ignore the record + if err != nil { + logging.FromContext(ctx).Debugf("unable to parse price record %#v", sph) + continue + } + if sph.Timestamp == nil { + continue + } + instanceType := aws.StringValue(sph.InstanceType) + az := aws.StringValue(sph.AvailabilityZone) + _, ok := prices[instanceType] + if !ok { + prices[instanceType] = map[string]float64{} + } + prices[instanceType][az] = spotPrice + } + return true + } +} + // turning off cyclo here, it measures as a 12 due to all of the type checks of the pricing data which returns a deeply // nested map[string]interface{} // nolint: gocyclo -func (p *Provider) onDemandPage(prices map[string]float64) func(output *pricing.GetProductsOutput, b bool) bool { +func (p *Provider) onDemandPage(ctx context.Context, prices map[string]float64) func(output *pricing.GetProductsOutput, b bool) bool { // this isn't the full pricing struct, just the portions we care about type priceItem struct { Product struct { @@ -282,12 +320,12 @@ func (p *Provider) onDemandPage(prices map[string]float64) func(output *pricing. var buf bytes.Buffer enc := json.NewEncoder(&buf) if err := enc.Encode(outer); err != nil { - logging.FromContext(context.Background()).Errorf("encoding %s", err) + logging.FromContext(ctx).Errorf("encoding %s", err) } dec := json.NewDecoder(&buf) var pItem priceItem if err := dec.Decode(&pItem); err != nil { - logging.FromContext(context.Background()).Errorf("decoding %s", err) + logging.FromContext(ctx).Errorf("decoding %s", err) } if pItem.Product.Attributes.InstanceType == "" { continue @@ -308,37 +346,22 @@ func (p *Provider) onDemandPage(prices map[string]float64) func(output *pricing. // nolint: gocyclo func (p *Provider) UpdateSpotPricing(ctx context.Context) error { - totalOfferings := 0 - prices := map[string]map[string]float64{} - err := p.ec2.DescribeSpotPriceHistoryPagesWithContext(ctx, &ec2.DescribeSpotPriceHistoryInput{ - ProductDescriptions: []*string{aws.String("Linux/UNIX"), aws.String("Linux/UNIX (Amazon VPC)")}, - // get the latest spot price for each instance type - StartTime: aws.Time(time.Now()), - }, func(output *ec2.DescribeSpotPriceHistoryOutput, b bool) bool { - for _, sph := range output.SpotPriceHistory { - spotPriceStr := aws.StringValue(sph.SpotPrice) - spotPrice, err := strconv.ParseFloat(spotPriceStr, 64) - // these errors shouldn't occur, but if pricing API does have an error, we ignore the record - if err != nil { - logging.FromContext(ctx).Debugf("unable to parse price record %#v", sph) - continue - } - if sph.Timestamp == nil { - continue - } - instanceType := aws.StringValue(sph.InstanceType) - az := aws.StringValue(sph.AvailabilityZone) - _, ok := prices[instanceType] - if !ok { - prices[instanceType] = map[string]float64{} - } - prices[instanceType][az] = spotPrice - } - return true - }) - p.mu.Lock() - defer p.mu.Unlock() + + p.muSpot.Lock() + defer p.muSpot.Unlock() + err := p.ec2.DescribeSpotPriceHistoryPagesWithContext( + ctx, + &ec2.DescribeSpotPriceHistoryInput{ + ProductDescriptions: []*string{ + aws.String("Linux/UNIX"), + aws.String("Linux/UNIX (Amazon VPC)"), + }, + // get the latest spot price for each instance type + StartTime: aws.Time(time.Now()), + }, + p.spotPage(ctx, prices), + ) if err != nil { return fmt.Errorf("retrieving spot pricing data, %w", err) @@ -346,6 +369,8 @@ func (p *Provider) UpdateSpotPricing(ctx context.Context) error { if len(prices) == 0 { return fmt.Errorf("no spot pricing found") } + + totalOfferings := 0 for it, zoneData := range prices { if _, ok := p.spotPrices[it]; !ok { p.spotPrices[it] = newZonalPricing(0) @@ -367,9 +392,11 @@ func (p *Provider) UpdateSpotPricing(ctx context.Context) error { func (p *Provider) LivenessProbe(_ *http.Request) error { // ensure we don't deadlock and nolint for the empty critical section - p.mu.Lock() + p.muOnDemand.Lock() + p.muSpot.Lock() //nolint: staticcheck - p.mu.Unlock() + p.muOnDemand.Unlock() + p.muSpot.Unlock() return nil }