Skip to content

Commit

Permalink
fix: Move mutex in pricing before AWS api calls, refactor spot and on…
Browse files Browse the repository at this point in the history
…Demand (#5751)
  • Loading branch information
tvonhacht-apple authored Mar 6, 2024
1 parent 68e736d commit e798c27
Showing 1 changed file with 75 additions and 48 deletions.
123 changes: 75 additions & 48 deletions pkg/providers/pricing/pricing.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ type Provider struct {
region string
cm *pretty.ChangeMonitor

mu sync.RWMutex
onDemandPrices map[string]float64
muOnDemand sync.RWMutex
onDemandPrices map[string]float64

muSpot sync.RWMutex
spotPrices map[string]zonal
spotPricingUpdated bool
}
Expand Down Expand Up @@ -109,16 +111,18 @@ func NewProvider(_ context.Context, pricing pricingiface.PricingAPI, ec2Api ec2i

// InstanceTypes returns the list of all instance types for which either a spot or on-demand price is known.
func (p *Provider) InstanceTypes() []string {
p.mu.RLock()
defer p.mu.RUnlock()
p.muOnDemand.RLock()
p.muSpot.RLock()
defer p.muOnDemand.RUnlock()
defer p.muSpot.RUnlock()
return lo.Union(lo.Keys(p.onDemandPrices), lo.Keys(p.spotPrices))
}

// OnDemandPrice returns the last known on-demand price for a given instance type, returning an error if there is no
// known on-demand pricing for the instance type.
func (p *Provider) OnDemandPrice(instanceType string) (float64, bool) {
p.mu.RLock()
defer p.mu.RUnlock()
p.muOnDemand.RLock()
defer p.muOnDemand.RUnlock()
price, ok := p.onDemandPrices[instanceType]
if !ok {
return 0.0, false
Expand All @@ -129,8 +133,8 @@ func (p *Provider) OnDemandPrice(instanceType string) (float64, bool) {
// SpotPrice returns the last known spot price for a given instance type and zone, returning an error
// if there is no known spot pricing for that instance type or zone
func (p *Provider) SpotPrice(instanceType string, zone string) (float64, bool) {
p.mu.RLock()
defer p.mu.RUnlock()
p.muSpot.RLock()
defer p.muSpot.RUnlock()
if val, ok := p.spotPrices[instanceType]; ok {
if !p.spotPricingUpdated {
return val.defaultPrice, true
Expand Down Expand Up @@ -158,6 +162,9 @@ func (p *Provider) UpdateOnDemandPricing(ctx context.Context) error {
return nil
}

p.muOnDemand.Lock()
defer p.muOnDemand.Unlock()

wg.Add(1)
go func() {
defer wg.Done()
Expand Down Expand Up @@ -193,8 +200,6 @@ func (p *Provider) UpdateOnDemandPricing(ctx context.Context) error {

wg.Wait()

p.mu.Lock()
defer p.mu.Unlock()
err := multierr.Append(onDemandErr, onDemandMetalErr)
if err != nil {
return fmt.Errorf("retreiving on-demand pricing data, %w", err)
Expand Down Expand Up @@ -245,18 +250,51 @@ func (p *Provider) fetchOnDemandPricing(ctx context.Context, additionalFilters .
Value: aws.String("OnDemand"),
}},
additionalFilters...)
if err := p.pricing.GetProductsPagesWithContext(ctx, &pricing.GetProductsInput{
Filters: filters,
ServiceCode: aws.String("AmazonEC2")}, p.onDemandPage(prices)); err != nil {

err := p.pricing.GetProductsPagesWithContext(
ctx,
&pricing.GetProductsInput{
Filters: filters,
ServiceCode: aws.String("AmazonEC2"),
},
p.onDemandPage(ctx, prices),
)
if err != nil {
return nil, err
}

return prices, nil
}

func (p *Provider) spotPage(ctx context.Context, prices map[string]map[string]float64) func(output *ec2.DescribeSpotPriceHistoryOutput, b bool) bool {
return func(output *ec2.DescribeSpotPriceHistoryOutput, b bool) bool {
for _, sph := range output.SpotPriceHistory {
spotPriceStr := aws.StringValue(sph.SpotPrice)
spotPrice, err := strconv.ParseFloat(spotPriceStr, 64)
// these errors shouldn't occur, but if pricing API does have an error, we ignore the record
if err != nil {
logging.FromContext(ctx).Debugf("unable to parse price record %#v", sph)
continue
}
if sph.Timestamp == nil {
continue
}
instanceType := aws.StringValue(sph.InstanceType)
az := aws.StringValue(sph.AvailabilityZone)
_, ok := prices[instanceType]
if !ok {
prices[instanceType] = map[string]float64{}
}
prices[instanceType][az] = spotPrice
}
return true
}
}

// turning off cyclo here, it measures as a 12 due to all of the type checks of the pricing data which returns a deeply
// nested map[string]interface{}
// nolint: gocyclo
func (p *Provider) onDemandPage(prices map[string]float64) func(output *pricing.GetProductsOutput, b bool) bool {
func (p *Provider) onDemandPage(ctx context.Context, prices map[string]float64) func(output *pricing.GetProductsOutput, b bool) bool {
// this isn't the full pricing struct, just the portions we care about
type priceItem struct {
Product struct {
Expand All @@ -282,12 +320,12 @@ func (p *Provider) onDemandPage(prices map[string]float64) func(output *pricing.
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
if err := enc.Encode(outer); err != nil {
logging.FromContext(context.Background()).Errorf("encoding %s", err)
logging.FromContext(ctx).Errorf("encoding %s", err)
}
dec := json.NewDecoder(&buf)
var pItem priceItem
if err := dec.Decode(&pItem); err != nil {
logging.FromContext(context.Background()).Errorf("decoding %s", err)
logging.FromContext(ctx).Errorf("decoding %s", err)
}
if pItem.Product.Attributes.InstanceType == "" {
continue
Expand All @@ -308,44 +346,31 @@ func (p *Provider) onDemandPage(prices map[string]float64) func(output *pricing.

// nolint: gocyclo
func (p *Provider) UpdateSpotPricing(ctx context.Context) error {
totalOfferings := 0

prices := map[string]map[string]float64{}
err := p.ec2.DescribeSpotPriceHistoryPagesWithContext(ctx, &ec2.DescribeSpotPriceHistoryInput{
ProductDescriptions: []*string{aws.String("Linux/UNIX"), aws.String("Linux/UNIX (Amazon VPC)")},
// get the latest spot price for each instance type
StartTime: aws.Time(time.Now()),
}, func(output *ec2.DescribeSpotPriceHistoryOutput, b bool) bool {
for _, sph := range output.SpotPriceHistory {
spotPriceStr := aws.StringValue(sph.SpotPrice)
spotPrice, err := strconv.ParseFloat(spotPriceStr, 64)
// these errors shouldn't occur, but if pricing API does have an error, we ignore the record
if err != nil {
logging.FromContext(ctx).Debugf("unable to parse price record %#v", sph)
continue
}
if sph.Timestamp == nil {
continue
}
instanceType := aws.StringValue(sph.InstanceType)
az := aws.StringValue(sph.AvailabilityZone)
_, ok := prices[instanceType]
if !ok {
prices[instanceType] = map[string]float64{}
}
prices[instanceType][az] = spotPrice
}
return true
})
p.mu.Lock()
defer p.mu.Unlock()

p.muSpot.Lock()
defer p.muSpot.Unlock()
err := p.ec2.DescribeSpotPriceHistoryPagesWithContext(
ctx,
&ec2.DescribeSpotPriceHistoryInput{
ProductDescriptions: []*string{
aws.String("Linux/UNIX"),
aws.String("Linux/UNIX (Amazon VPC)"),
},
// get the latest spot price for each instance type
StartTime: aws.Time(time.Now()),
},
p.spotPage(ctx, prices),
)

if err != nil {
return fmt.Errorf("retrieving spot pricing data, %w", err)
}
if len(prices) == 0 {
return fmt.Errorf("no spot pricing found")
}

totalOfferings := 0
for it, zoneData := range prices {
if _, ok := p.spotPrices[it]; !ok {
p.spotPrices[it] = newZonalPricing(0)
Expand All @@ -367,9 +392,11 @@ func (p *Provider) UpdateSpotPricing(ctx context.Context) error {

func (p *Provider) LivenessProbe(_ *http.Request) error {
// ensure we don't deadlock and nolint for the empty critical section
p.mu.Lock()
p.muOnDemand.Lock()
p.muSpot.Lock()
//nolint: staticcheck
p.mu.Unlock()
p.muOnDemand.Unlock()
p.muSpot.Unlock()
return nil
}

Expand Down

0 comments on commit e798c27

Please sign in to comment.