Skip to content

Commit

Permalink
changes for injection cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jmdeal committed Oct 14, 2023
1 parent a0bfa7e commit 1bea4fb
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 101 deletions.
4 changes: 3 additions & 1 deletion pkg/cloudprovider/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ import (

corev1beta1 "github.com/aws/karpenter-core/pkg/apis/v1beta1"
"github.com/aws/karpenter-core/pkg/events"
coreoptions "github.com/aws/karpenter-core/pkg/operator/options"
"github.com/aws/karpenter-core/pkg/utils/functional"
machineutil "github.com/aws/karpenter-core/pkg/utils/machine"
nodepoolutil "github.com/aws/karpenter-core/pkg/utils/nodepool"
"github.com/aws/karpenter/pkg/apis"
"github.com/aws/karpenter/pkg/apis/v1alpha1"
"github.com/aws/karpenter/pkg/apis/v1beta1"
"github.com/aws/karpenter/pkg/operator/options"
"github.com/aws/karpenter/pkg/utils"
nodeclassutil "github.com/aws/karpenter/pkg/utils/nodeclass"

Expand Down Expand Up @@ -63,7 +65,7 @@ func init() {
v1alpha5.NormalizedLabels = lo.Assign(v1alpha5.NormalizedLabels, map[string]string{"topology.ebs.csi.aws.com/zone": v1.LabelTopologyZone})
corev1beta1.NormalizedLabels = lo.Assign(corev1beta1.NormalizedLabels, map[string]string{"topology.ebs.csi.aws.com/zone": v1.LabelTopologyZone})
coreapis.Settings = append(coreapis.Settings, apis.Settings...)
coreapis.Options = append(coreapis.Options, apis.Options...)
coreoptions.Injectables = append(coreoptions.Injectables, &options.Options{})
}

var _ cloudprovider.CloudProvider = (*CloudProvider)(nil)
Expand Down
143 changes: 56 additions & 87 deletions pkg/operator/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ import (
"flag"
"fmt"
"os"
"strings"
"time"

"k8s.io/apimachinery/pkg/util/sets"

coresettings "github.com/aws/karpenter-core/pkg/apis/settings"
"github.com/aws/karpenter-core/pkg/utils/env"
"github.com/aws/karpenter/pkg/apis/settings"
)

type optionsKey struct{}

type Fields struct {
type Options struct {
AssumeRoleARN string
AssumeRoleDuration time.Duration
ClusterCABundle string
Expand All @@ -40,106 +40,82 @@ type Fields struct {
IsolatedVPC bool
VMMemoryOverheadPercent float64
InterruptionQueueName string
}

type setFlags struct {
AssumeRoleARNSet bool
AssumeRoleDurationSet bool
ClusterCABundleSet bool
ClusterNameSet bool
ClusterEndpointSet bool
IsolatedVPCSet bool
VMMemoryOverheadPercentSet bool
InterruptionQueueNameSet bool
setFlags map[string]bool
}

type Options struct {
*flag.FlagSet
Fields
setFlags
func (o *Options) AddFlags(fs *flag.FlagSet) {
fs.StringVar(&o.AssumeRoleARN, "aws-assume-role-arn", env.WithDefaultString("AWS_ASSUME_ROLE_ARN", ""), "Role to assume for calling AWS services.")
fs.DurationVar(&o.AssumeRoleDuration, "aws-assume-role-duration", env.WithDefaultDuration("AWS_ASSUME_ROLE_DURATION", 15*time.Minute), "Duration of assumed credentials in minutes. Default value is 15 minutes. Not used unless aws.assumeRole set.")
fs.StringVar(&o.ClusterCABundle, "aws-cluster-ca-bundle", env.WithDefaultString("AWS_CLUSTER_CA_BUNDLE", ""), "Cluster CA bundle for nodes to use for TLS connections with the API server. If not set, this is taken from the controller's TLS configuration.")
fs.StringVar(&o.ClusterName, "aws-cluster-name", env.WithDefaultString("AWS_CLUSTER_NAME", ""), "[REQUIRED] The kubernetes cluster name for resource discovery.")
fs.StringVar(&o.ClusterEndpoint, "aws-cluster-endpoint", env.WithDefaultString("AWS_CLUSTER_ENDPOINT", ""), "The external kubernetes cluster endpoint for new nodes to connect with. If not specified, will discover the cluster endpoint using DescribeCluster API.")
fs.BoolVar(&o.IsolatedVPC, "aws-isolated-vpc", env.WithDefaultBool("AWS_ISOLATED_VPC", false), "If true, then assume we can't reach AWS services which don't have a VPC endpoint. This also has the effect of disabling look-ups to the AWS pricing endpoint.")
fs.Float64Var(&o.VMMemoryOverheadPercent, "aws-vm-memory-overhead-percent", env.WithDefaultFloat64("AWS_VM_MEMORY_OVERHEAD_PERCENT", 0.075), "The VM memory overhead as a percent that will be subtracted from the total memory for all instance types.")
fs.StringVar(&o.InterruptionQueueName, "aws-interruption-queue-name", env.WithDefaultString("AWS_INTERRUPTION_QUEUE_NAME", ""), "aws.interruptionQueueName is disabled if not specified. Enabling interruption handling may require additional permissions on the controller service account. Additional permissions are outlined in the docs.")
}

func New() *Options {
opts := &Options{}
f := flag.NewFlagSet("aws-karpenter", flag.ContinueOnError)
opts.FlagSet = f

f.StringVar(&opts.AssumeRoleARN, "aws-assume-role-arn", env.WithDefaultString("AWS_ASSUME_ROLE_ARN", ""), "Role to assume for calling AWS services.")
f.DurationVar(&opts.AssumeRoleDuration, "aws-assume-role-duration", env.WithDefaultDuration("AWS_ASSUME_ROLE_DURATION", 15*time.Minute), "Duration of assumed credentials in minutes. Default value is 15 minutes. Not used unless aws.assumeRole set.")
f.StringVar(&opts.ClusterCABundle, "aws-cluster-ca-bundle", env.WithDefaultString("AWS_CLUSTER_CA_BUNDLE", ""), "Cluster CA bundle for nodes to use for TLS connections with the API server. If not set, this is taken from the controller's TLS configuration.")
f.StringVar(&opts.ClusterName, "aws-cluster-name", env.WithDefaultString("AWS_CLUSTER_NAME", ""), "[REQUIRED] The kubernetes cluster name for resource discovery.")
f.StringVar(&opts.ClusterEndpoint, "aws-cluster-endpoint", env.WithDefaultString("AWS_CLUSTER_ENDPOINT", ""), "The external kubernetes cluster endpoint for new nodes to connect with. If not specified, will discover the cluster endpoint using DescribeCluster API.")
f.BoolVar(&opts.IsolatedVPC, "aws-isolated-vpc", env.WithDefaultBool("AWS_ISOLATED_VPC", false), "If true, then assume we can't reach AWS services which don't have a VPC endpoint. This also has the effect of disabling look-ups to the AWS pricing endpoint.")
f.Float64Var(&opts.VMMemoryOverheadPercent, "aws-vm-memory-overhead-percent", env.WithDefaultFloat64("AWS_VM_MEMORY_OVERHEAD_PERCENT", 0.075), "The VM memory overhead as a percent that will be subtracted from the total memory for all instance types.")
f.StringVar(&opts.InterruptionQueueName, "aws-interruption-queue-name", env.WithDefaultString("AWS_INTERRUPTION_QUEUE_NAME", ""), "aws.interruptionQueueName is disabled if not specified. Enabling interruption handling may require additional permissions on the controller service account. Additional permissions are outlined in the docs.")

return opts
}

func (*Options) Inject(ctx context.Context, args ...string) (context.Context, error) {
o := New()
if err := o.Parse(args); err != nil {
func (o *Options) Parse(fs *flag.FlagSet, args ...string) error {
if err := fs.Parse(args); err != nil {
if errors.Is(err, flag.ErrHelp) {
os.Exit(0)
}
return ctx, fmt.Errorf("failed to parse cli flags, %w", err)
return fmt.Errorf("failed to parse cli flags, %w", err)
}

if err := o.Validate(); err != nil {
return ctx, fmt.Errorf("failed to validate options, %w", err)
return fmt.Errorf("failed to validate options, %w", err)
}

// Check if each option has been set. This is a little brute force and better options might exist,
// but this only needs to be here for one version
o.setFlags = make(map[string]bool)
cliFlags := sets.New[string]()
o.Visit(func(f *flag.Flag) {
fs.Visit(func(f *flag.Flag) {
cliFlags.Insert(f.Name)
})
for _, entry := range []struct {
setFlag *bool
cliName string
envName string
}{
{&o.AssumeRoleARNSet, "aws-assume-role-arn", "AWS_ASSUME_ROLE_ARN"},
{&o.AssumeRoleDurationSet, "aws-assume-role-duration", "AWS_ASSUME_ROLE_DURATION"},
{&o.ClusterCABundleSet, "aws-cluster-ca-bundle", "AWS_CLUSTER_CA_BUNDLE"},
{&o.ClusterNameSet, "aws-cluster-name", "AWS_CLUSTER_NAME"},
{&o.ClusterEndpointSet, "aws-cluster-endpoint", "AWS_CLUSTER_ENDPOINT"},
{&o.IsolatedVPCSet, "aws-isolated-vpc", "AWS_ISOLATED_VPC"},
{&o.VMMemoryOverheadPercentSet, "aws-vm-memory-overhead-percent", "AWS_VM_MEMORY_OVERHEAD_PERCENT"},
{&o.InterruptionQueueNameSet, "aws-interruption-queue-name", "AWS_INTERRUPTION_QUEUE_NAME"},
} {
if cliFlags.Has(entry.cliName) {
*entry.setFlag = true
}
if _, ok := os.LookupEnv(entry.envName); ok {
*entry.setFlag = true
}
}
fs.VisitAll(func(f *flag.Flag) {
envName := strings.ReplaceAll(strings.ToUpper(f.Name), "-", "_")
_, ok := os.LookupEnv(envName)
o.setFlags[f.Name] = ok || cliFlags.Has(f.Name)
})

ctx = ToContext(ctx, o)
return ctx, nil
return nil
}

func (*Options) MergeSettings(ctx context.Context, injectables ...coresettings.Injectable) context.Context {
for _, in := range injectables {
_, ok := in.(*settings.Settings)
if !ok {
continue
}
s := in.FromContext(ctx).(*settings.Settings)
o := FromContext(ctx)
mergeField(&o.AssumeRoleARN, s.AssumeRoleARN, o.AssumeRoleARNSet)
mergeField(&o.AssumeRoleDuration, s.AssumeRoleDuration, o.AssumeRoleDurationSet)
mergeField(&o.ClusterCABundle, s.ClusterCABundle, o.ClusterCABundleSet)
mergeField(&o.ClusterName, s.ClusterName, o.ClusterNameSet)
mergeField(&o.ClusterEndpoint, s.ClusterEndpoint, o.ClusterEndpointSet)
mergeField(&o.IsolatedVPC, s.IsolatedVPC, o.IsolatedVPCSet)
mergeField(&o.VMMemoryOverheadPercent, s.VMMemoryOverheadPercent, o.VMMemoryOverheadPercentSet)
mergeField(&o.InterruptionQueueName, s.InterruptionQueueName, o.InterruptionQueueNameSet)
ctx = ToContext(ctx, o)
func (o *Options) ToContext(ctx context.Context) context.Context {
return ToContext(ctx, o)
}

func (o *Options) MergeSettings(ctx context.Context) {
s := settings.FromContext(ctx)
if s == nil {
return
}
if !o.setFlags["aws-assume-role-arn"] {
o.AssumeRoleARN = s.AssumeRoleARN
}
if !o.setFlags["aws-assume-role-duration"] {
o.AssumeRoleDuration = s.AssumeRoleDuration
}
if !o.setFlags["aws-cluster-ca-bundle"] {
o.ClusterCABundle = s.ClusterCABundle
}
if !o.setFlags["aws-cluster-name"] {
o.ClusterName = s.ClusterName
}
if !o.setFlags["aws-cluster-endpoint"] {
o.ClusterEndpoint = s.ClusterEndpoint
}
if !o.setFlags["aws-isolated-vpc"] {
o.IsolatedVPC = s.IsolatedVPC
}
if !o.setFlags["aws-vm-memory-overhead-percent"] {
o.VMMemoryOverheadPercent = s.VMMemoryOverheadPercent
}
if !o.setFlags["aws-interruption-queue-name"] {
o.InterruptionQueueName = s.InterruptionQueueName
}
return ctx
}

func ToContext(ctx context.Context, opts *Options) context.Context {
Expand All @@ -153,10 +129,3 @@ func FromContext(ctx context.Context) *Options {
}
return retval.(*Options)
}

func mergeField[T any](dest *T, val T, isSet bool) {
if isSet {
return
}
*dest = val
}
9 changes: 6 additions & 3 deletions pkg/operator/options/options_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ func (o Options) Validate() (errs *apis.FieldError) {
}

func (o Options) validateAssumeRoleDuration() (errs *apis.FieldError) {
if !o.AssumeRoleDurationSet {
// TODO: Remove with karpenter-global-settings
if !o.setFlags["aws-assume-role-arn"] {
return nil
}
if o.AssumeRoleDuration < time.Minute*15 {
Expand All @@ -42,7 +43,8 @@ func (o Options) validateAssumeRoleDuration() (errs *apis.FieldError) {
}

func (o Options) validateClusterName() (errs *apis.FieldError) {
if !o.ClusterNameSet {
// TODO: Remove with karpenter-global-settings
if !o.setFlags["aws-cluster-name"] {
return nil
}
if o.ClusterName == "" {
Expand All @@ -65,7 +67,8 @@ func (o Options) validateEndpoint() (errs *apis.FieldError) {
}

func (o Options) validateVMMemoryOverheadPercent() (errs *apis.FieldError) {
if !o.VMMemoryOverheadPercentSet {
// TODO: Remove with karpenter-global-settings
if !o.setFlags["aws-vm-memory-overhead-percent"] {
return nil
}
if o.VMMemoryOverheadPercent < 0 {
Expand Down
18 changes: 8 additions & 10 deletions pkg/test/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,13 @@ func Options(overrides ...OptionsFields) *options.Options {
}
}
return &options.Options{
Fields: options.Fields{
AssumeRoleARN: lo.FromPtrOr(opts.AssumeRoleARN, ""),
AssumeRoleDuration: lo.FromPtrOr(opts.AssumeRoleDuration, 15*time.Minute),
ClusterCABundle: lo.FromPtrOr(opts.ClusterCABundle, ""),
ClusterName: lo.FromPtrOr(opts.ClusterName, "test-cluster"),
ClusterEndpoint: lo.FromPtrOr(opts.ClusterEndpoint, "https://test-cluster"),
IsolatedVPC: lo.FromPtrOr(opts.IsolatedVPC, false),
VMMemoryOverheadPercent: lo.FromPtrOr(opts.VMMemoryOverheadPercent, 0.075),
InterruptionQueueName: lo.FromPtrOr(opts.InterruptionQueueName, ""),
},
AssumeRoleARN: lo.FromPtrOr(opts.AssumeRoleARN, ""),
AssumeRoleDuration: lo.FromPtrOr(opts.AssumeRoleDuration, 15*time.Minute),
ClusterCABundle: lo.FromPtrOr(opts.ClusterCABundle, ""),
ClusterName: lo.FromPtrOr(opts.ClusterName, "test-cluster"),
ClusterEndpoint: lo.FromPtrOr(opts.ClusterEndpoint, "https://test-cluster"),
IsolatedVPC: lo.FromPtrOr(opts.IsolatedVPC, false),
VMMemoryOverheadPercent: lo.FromPtrOr(opts.VMMemoryOverheadPercent, 0.075),
InterruptionQueueName: lo.FromPtrOr(opts.InterruptionQueueName, ""),
}
}

0 comments on commit 1bea4fb

Please sign in to comment.