Skip to content

Commit

Permalink
Create Subnet controller
Browse files Browse the repository at this point in the history
  • Loading branch information
engedaam committed Apr 25, 2024
1 parent f724f6e commit 1c3acff
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 56 deletions.
2 changes: 2 additions & 0 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ const (
InstanceTypesAndZonesTTL = 5 * time.Minute
// InstanceProfileTTL is the time before we refresh checking instance profile existence at IAM
InstanceProfileTTL = 15 * time.Minute
// SubnetTTL is the time before cached subnets expire and are removed from the cache
SubnetTTL = 5 * time.Minute
)

const (
Expand Down
14 changes: 5 additions & 9 deletions pkg/cloudprovider/drift.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (c *CloudProvider) isNodeClassDrifted(ctx context.Context, nodeClaim *corev
if err != nil {
return "", fmt.Errorf("calculating securitygroup drift, %w", err)
}
subnetDrifted, err := c.isSubnetDrifted(ctx, instance, nodeClass)
subnetDrifted, err := c.isSubnetDrifted(instance, nodeClass)
if err != nil {
return "", fmt.Errorf("calculating subnet drift, %w", err)
}
Expand Down Expand Up @@ -96,18 +96,14 @@ func (c *CloudProvider) isAMIDrifted(ctx context.Context, nodeClaim *corev1beta1

// Checks if the subnets are drifted, by comparing the subnets resolved for the EC2NodeClass
// to the subnet of the ec2 instance
func (c *CloudProvider) isSubnetDrifted(ctx context.Context, instance *instance.Instance, nodeClass *v1beta1.EC2NodeClass) (cloudprovider.DriftReason, error) {
subnets, err := c.subnetProvider.List(ctx, nodeClass)
if err != nil {
return "", err
}
func (c *CloudProvider) isSubnetDrifted(instance *instance.Instance, nodeClass *v1beta1.EC2NodeClass) (cloudprovider.DriftReason, error) {
// subnets need to be found to check for drift
if len(subnets) == 0 {
if len(nodeClass.Status.Subnets) == 0 {
return "", fmt.Errorf("no subnets are discovered")
}

_, found := lo.Find(subnets, func(subnet *ec2.Subnet) bool {
return aws.StringValue(subnet.SubnetId) == instance.SubnetID
_, found := lo.Find(nodeClass.Status.Subnets, func(subnet v1beta1.Subnet) bool {
return subnet.ID == instance.SubnetID
})

if !found {
Expand Down
2 changes: 1 addition & 1 deletion pkg/operator/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func NewOperator(ctx context.Context, operator *operator.Operator) (context.Cont
}

unavailableOfferingsCache := awscache.NewUnavailableOfferings()
subnetProvider := subnet.NewDefaultProvider(ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval))
subnetProvider := subnet.NewDefaultProvider(ec2api, cache.New(awscache.SubnetTTL, awscache.DefaultCleanupInterval))
securityGroupProvider := securitygroup.NewDefaultProvider(ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval))
instanceProfileProvider := instanceprofile.NewDefaultProvider(*sess.Config.Region, iam.New(sess), cache.New(awscache.InstanceProfileTTL, awscache.DefaultCleanupInterval))
pricingProvider := pricing.NewDefaultProvider(
Expand Down
8 changes: 4 additions & 4 deletions pkg/providers/instance/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func (p *DefaultProvider) checkODFallback(nodeClaim *corev1beta1.NodeClaim, inst
}

func (p *DefaultProvider) getLaunchTemplateConfigs(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, nodeClaim *corev1beta1.NodeClaim,
instanceTypes []*cloudprovider.InstanceType, zonalSubnets map[string]*ec2.Subnet, capacityType string, tags map[string]string) ([]*ec2.FleetLaunchTemplateConfigRequest, error) {
instanceTypes []*cloudprovider.InstanceType, zonalSubnets map[string]v1beta1.Subnet, capacityType string, tags map[string]string) ([]*ec2.FleetLaunchTemplateConfigRequest, error) {
var launchTemplateConfigs []*ec2.FleetLaunchTemplateConfigRequest
launchTemplates, err := p.launchTemplateProvider.EnsureAll(ctx, nodeClass, nodeClaim, instanceTypes, capacityType, tags)
if err != nil {
Expand All @@ -311,7 +311,7 @@ func (p *DefaultProvider) getLaunchTemplateConfigs(ctx context.Context, nodeClas

// getOverrides creates and returns launch template overrides for the cross product of InstanceTypes and subnets (with subnets being constrained by
// zones and the offerings in InstanceTypes)
func (p *DefaultProvider) getOverrides(instanceTypes []*cloudprovider.InstanceType, zonalSubnets map[string]*ec2.Subnet, zones *scheduling.Requirement, capacityType string, image string) []*ec2.FleetLaunchTemplateOverridesRequest {
func (p *DefaultProvider) getOverrides(instanceTypes []*cloudprovider.InstanceType, zonalSubnets map[string]v1beta1.Subnet, zones *scheduling.Requirement, capacityType string, image string) []*ec2.FleetLaunchTemplateOverridesRequest {
// Unwrap all the offerings to a flat slice that includes a pointer
// to the parent instance type name
type offeringWithParentName struct {
Expand Down Expand Up @@ -343,11 +343,11 @@ func (p *DefaultProvider) getOverrides(instanceTypes []*cloudprovider.InstanceTy
}
overrides = append(overrides, &ec2.FleetLaunchTemplateOverridesRequest{
InstanceType: aws.String(offering.parentInstanceTypeName),
SubnetId: subnet.SubnetId,
SubnetId: lo.ToPtr(subnet.ID),
ImageId: aws.String(image),
// This is technically redundant, but is useful if we have to parse insufficient capacity errors from
// CreateFleet so that we can figure out the zone rather than making additional API calls to look up the subnet
AvailabilityZone: subnet.AvailabilityZone,
AvailabilityZone: lo.ToPtr(subnet.Zone),
})
}
return overrides
Expand Down
55 changes: 19 additions & 36 deletions pkg/providers/subnet/subnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ type Provider interface {
LivenessProbe(*http.Request) error
List(context.Context, *v1beta1.EC2NodeClass) ([]*ec2.Subnet, error)
CheckAnyPublicIPAssociations(context.Context, *v1beta1.EC2NodeClass) (bool, error)
ZonalSubnetsForLaunch(context.Context, *v1beta1.EC2NodeClass, []*cloudprovider.InstanceType, string) (map[string]*ec2.Subnet, error)
UpdateInflightIPs(*ec2.CreateFleetInput, *ec2.CreateFleetOutput, []*cloudprovider.InstanceType, []*ec2.Subnet, string)
ZonalSubnetsForLaunch(context.Context, *v1beta1.EC2NodeClass, []*cloudprovider.InstanceType, string) (map[string]v1beta1.Subnet, error)
UpdateInflightIPs(*ec2.CreateFleetInput, *ec2.CreateFleetOutput, []*cloudprovider.InstanceType, []v1beta1.Subnet, string)
}

type DefaultProvider struct {
Expand Down Expand Up @@ -87,7 +87,7 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1beta1.EC2NodeCl
}
for i := range output.Subnets {
subnets[lo.FromPtr(output.Subnets[i].SubnetId)] = output.Subnets[i]
delete(p.inflightIPs, lo.FromPtr(output.Subnets[i].SubnetId)) // remove any previously tracked IP addresses since we just refreshed from EC2
p.inflightIPs[lo.FromPtr(output.Subnets[i].SubnetId)] = lo.FromPtr(output.Subnets[i].AvailableIpAddressCount) // reset the tracked IP count to the value EC2 just reported, replacing any previously tracked value
}
}
p.cache.SetDefault(fmt.Sprint(hash), lo.Values(subnets))
Expand All @@ -114,47 +114,30 @@ func (p *DefaultProvider) CheckAnyPublicIPAssociations(ctx context.Context, node
}

// ZonalSubnetsForLaunch returns a mapping of zone to the subnet with the most available IP addresses and deducts the passed ips from the available count
func (p *DefaultProvider) ZonalSubnetsForLaunch(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, instanceTypes []*cloudprovider.InstanceType, capacityType string) (map[string]*ec2.Subnet, error) {
subnets, err := p.List(ctx, nodeClass)
if err != nil {
return nil, err
}
if len(subnets) == 0 {
func (p *DefaultProvider) ZonalSubnetsForLaunch(ctx context.Context, nodeClass *v1beta1.EC2NodeClass, instanceTypes []*cloudprovider.InstanceType, capacityType string) (map[string]v1beta1.Subnet, error) {
if len(nodeClass.Status.Subnets) == 0 {
return nil, fmt.Errorf("no subnets matched selector %v", nodeClass.Spec.SubnetSelectorTerms)
}
p.Lock()
defer p.Unlock()
// sort subnets in ascending order of available IP addresses and populate map with most available subnet per AZ
zonalSubnets := map[string]*ec2.Subnet{}
sort.Slice(subnets, func(i, j int) bool {
iIPs := aws.Int64Value(subnets[i].AvailableIpAddressCount)
jIPs := aws.Int64Value(subnets[j].AvailableIpAddressCount)
// override ip count from ec2.Subnet if we've tracked launches
if ips, ok := p.inflightIPs[*subnets[i].SubnetId]; ok {
iIPs = ips
}
if ips, ok := p.inflightIPs[*subnets[j].SubnetId]; ok {
jIPs = ips
}
return iIPs < jIPs
zonalSubnets := map[string]v1beta1.Subnet{}
sort.Slice(nodeClass.Status.Subnets, func(i, j int) bool {
return p.inflightIPs[nodeClass.Status.Subnets[i].ID] < p.inflightIPs[nodeClass.Status.Subnets[j].ID]
})
for _, subnet := range subnets {
zonalSubnets[*subnet.AvailabilityZone] = subnet
for _, subnet := range nodeClass.Status.Subnets {
zonalSubnets[subnet.Zone] = subnet
}
for _, subnet := range zonalSubnets {
predictedIPsUsed := p.minPods(instanceTypes, *subnet.AvailabilityZone, capacityType)
prevIPs := *subnet.AvailableIpAddressCount
if trackedIPs, ok := p.inflightIPs[*subnet.SubnetId]; ok {
prevIPs = trackedIPs
}
p.inflightIPs[*subnet.SubnetId] = prevIPs - predictedIPsUsed
predictedIPsUsed := p.minPods(instanceTypes, subnet.Zone, capacityType)
p.inflightIPs[subnet.ID] = p.inflightIPs[subnet.ID] - predictedIPsUsed
}
return zonalSubnets, nil
}

// UpdateInflightIPs is used to refresh the in-memory IP usage by adding back unused IPs after a CreateFleet response is returned
func (p *DefaultProvider) UpdateInflightIPs(createFleetInput *ec2.CreateFleetInput, createFleetOutput *ec2.CreateFleetOutput, instanceTypes []*cloudprovider.InstanceType,
subnets []*ec2.Subnet, capacityType string) {
subnets []v1beta1.Subnet, capacityType string) {
p.Lock()
defer p.Unlock()

Expand Down Expand Up @@ -193,19 +176,19 @@ func (p *DefaultProvider) UpdateInflightIPs(createFleetInput *ec2.CreateFleetInp
if !lo.Contains(subnetIDsToAddBackIPs, *cachedSubnet.SubnetId) {
continue
}
originalSubnet, ok := lo.Find(subnets, func(subnet *ec2.Subnet) bool {
return *subnet.SubnetId == *cachedSubnet.SubnetId
originalSubnet, ok := lo.Find(subnets, func(subnet v1beta1.Subnet) bool {
return subnet.ID == *cachedSubnet.SubnetId
})
if !ok {
continue
}
// If the cached subnet IP address count hasn't changed from the original subnet used to
// launch the instance, then we need to update the tracked IPs
if *originalSubnet.AvailableIpAddressCount == *cachedSubnet.AvailableIpAddressCount {
if p.inflightIPs[originalSubnet.ID] == *cachedSubnet.AvailableIpAddressCount {
// other IPs deducted were opportunistic and need to be re-added since Fleet didn't pick those subnets to launch into
if ips, ok := p.inflightIPs[*originalSubnet.SubnetId]; ok {
minPods := p.minPods(instanceTypes, *originalSubnet.AvailabilityZone, capacityType)
p.inflightIPs[*originalSubnet.SubnetId] = ips + minPods
if ips, ok := p.inflightIPs[originalSubnet.ID]; ok {
minPods := p.minPods(instanceTypes, originalSubnet.Zone, capacityType)
p.inflightIPs[originalSubnet.ID] = ips + minPods
}
}
}
Expand Down
28 changes: 22 additions & 6 deletions pkg/providers/subnet/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,16 @@ var _ = Describe("SubnetProvider", func() {
})
Context("Provider Cache", func() {
It("should resolve subnets from cache that are filtered by id", func() {
expectedSubnets := awsEnv.EC2API.DescribeSubnetsOutput.Clone().Subnets
for _, subnet := range expectedSubnets {
expectedSubnets, err := awsEnv.EC2API.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{
Filters: []*ec2.Filter{
{
Name: lo.ToPtr("tag-key"),
Values: []*string{lo.ToPtr("*")},
},
},
})
Expect(err).To(BeNil())
for _, subnet := range expectedSubnets.Subnets {
nodeClass.Spec.SubnetSelectorTerms = []v1beta1.SubnetSelectorTerm{
{
ID: *subnet.SubnetId,
Expand All @@ -257,12 +265,20 @@ var _ = Describe("SubnetProvider", func() {
for _, cachedObject := range awsEnv.SubnetCache.Items() {
cachedSubnet := cachedObject.Object.([]*ec2.Subnet)
Expect(cachedSubnet).To(HaveLen(1))
lo.Contains(expectedSubnets, cachedSubnet[0])
lo.Contains(expectedSubnets.Subnets, cachedSubnet[0])
}
})
It("should resolve subnets from cache that are filtered by tags", func() {
expectedSubnets := awsEnv.EC2API.DescribeSubnetsOutput.Clone().Subnets
tagSet := lo.Map(expectedSubnets, func(subnet *ec2.Subnet, _ int) map[string]string {
expectedSubnets, err := awsEnv.EC2API.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{
Filters: []*ec2.Filter{
{
Name: lo.ToPtr("tag-key"),
Values: []*string{lo.ToPtr("*")},
},
},
})
Expect(err).To(BeNil())
tagSet := lo.Map(expectedSubnets.Subnets, func(subnet *ec2.Subnet, _ int) map[string]string {
tag, _ := lo.Find(subnet.Tags, func(tag *ec2.Tag) bool {
return lo.FromPtr(tag.Key) == "Name"
})
Expand All @@ -282,7 +298,7 @@ var _ = Describe("SubnetProvider", func() {
for _, cachedObject := range awsEnv.SubnetCache.Items() {
cachedSubnet := cachedObject.Object.([]*ec2.Subnet)
Expect(cachedSubnet).To(HaveLen(1))
lo.Contains(expectedSubnets, cachedSubnet[0])
lo.Contains(expectedSubnets.Subnets, cachedSubnet[0])
}
})
})
Expand Down

0 comments on commit 1c3acff

Please sign in to comment.