diff --git a/pkg/cloudprovider/cloudprovider.go b/pkg/cloudprovider/cloudprovider.go index f6248c68aabf..b2f8231da176 100644 --- a/pkg/cloudprovider/cloudprovider.go +++ b/pkg/cloudprovider/cloudprovider.go @@ -27,12 +27,14 @@ import ( corev1beta1 "github.com/aws/karpenter-core/pkg/apis/v1beta1" "github.com/aws/karpenter-core/pkg/events" + coreoptions "github.com/aws/karpenter-core/pkg/operator/options" "github.com/aws/karpenter-core/pkg/utils/functional" machineutil "github.com/aws/karpenter-core/pkg/utils/machine" nodepoolutil "github.com/aws/karpenter-core/pkg/utils/nodepool" "github.com/aws/karpenter/pkg/apis" "github.com/aws/karpenter/pkg/apis/v1alpha1" "github.com/aws/karpenter/pkg/apis/v1beta1" + "github.com/aws/karpenter/pkg/operator/options" "github.com/aws/karpenter/pkg/utils" nodeclassutil "github.com/aws/karpenter/pkg/utils/nodeclass" @@ -63,7 +65,7 @@ func init() { v1alpha5.NormalizedLabels = lo.Assign(v1alpha5.NormalizedLabels, map[string]string{"topology.ebs.csi.aws.com/zone": v1.LabelTopologyZone}) corev1beta1.NormalizedLabels = lo.Assign(corev1beta1.NormalizedLabels, map[string]string{"topology.ebs.csi.aws.com/zone": v1.LabelTopologyZone}) coreapis.Settings = append(coreapis.Settings, apis.Settings...) - coreapis.Options = append(coreapis.Options, apis.Options...) + coreoptions.Injectables = append(coreoptions.Injectables, &options.Options{}) } var _ cloudprovider.CloudProvider = (*CloudProvider)(nil) diff --git a/pkg/operator/options/options.go b/pkg/operator/options/options.go index b43597b4e6d3..cba0fbd8b12c 100644 --- a/pkg/operator/options/options.go +++ b/pkg/operator/options/options.go @@ -20,18 +20,18 @@ import ( "flag" "fmt" "os" + "strings" "time" "k8s.io/apimachinery/pkg/util/sets" - coresettings "github.com/aws/karpenter-core/pkg/apis/settings" "github.com/aws/karpenter-core/pkg/utils/env" "github.com/aws/karpenter/pkg/apis/settings" ) type optionsKey struct{} -type Fields struct { +type Options struct { AssumeRoleARN string AssumeRoleDuration time.Duration ClusterCABundle string @@ -40,106 +40,82 @@ type Fields struct { IsolatedVPC bool VMMemoryOverheadPercent float64 InterruptionQueueName string -} -type setFlags struct { - AssumeRoleARNSet bool - AssumeRoleDurationSet bool - ClusterCABundleSet bool - ClusterNameSet bool - ClusterEndpointSet bool - IsolatedVPCSet bool - VMMemoryOverheadPercentSet bool - InterruptionQueueNameSet bool + setFlags map[string]bool } -type Options struct { - *flag.FlagSet - Fields - setFlags +func (o *Options) AddFlags(fs *flag.FlagSet) { + fs.StringVar(&o.AssumeRoleARN, "aws-assume-role-arn", env.WithDefaultString("AWS_ASSUME_ROLE_ARN", ""), "Role to assume for calling AWS services.") + fs.DurationVar(&o.AssumeRoleDuration, "aws-assume-role-duration", env.WithDefaultDuration("AWS_ASSUME_ROLE_DURATION", 15*time.Minute), "Duration of assumed credentials in minutes. Default value is 15 minutes. Not used unless aws.assumeRole set.") + fs.StringVar(&o.ClusterCABundle, "aws-cluster-ca-bundle", env.WithDefaultString("AWS_CLUSTER_CA_BUNDLE", ""), "Cluster CA bundle for nodes to use for TLS connections with the API server. If not set, this is taken from the controller's TLS configuration.") + fs.StringVar(&o.ClusterName, "aws-cluster-name", env.WithDefaultString("AWS_CLUSTER_NAME", ""), "[REQUIRED] The kubernetes cluster name for resource discovery.") + fs.StringVar(&o.ClusterEndpoint, "aws-cluster-endpoint", env.WithDefaultString("AWS_CLUSTER_ENDPOINT", ""), "The external kubernetes cluster endpoint for new nodes to connect with. If not specified, will discover the cluster endpoint using DescribeCluster API.") + fs.BoolVar(&o.IsolatedVPC, "aws-isolated-vpc", env.WithDefaultBool("AWS_ISOLATED_VPC", false), "If true, then assume we can't reach AWS services which don't have a VPC endpoint. This also has the effect of disabling look-ups to the AWS pricing endpoint.") + fs.Float64Var(&o.VMMemoryOverheadPercent, "aws-vm-memory-overhead-percent", env.WithDefaultFloat64("AWS_VM_MEMORY_OVERHEAD_PERCENT", 0.075), "The VM memory overhead as a percent that will be subtracted from the total memory for all instance types.") + fs.StringVar(&o.InterruptionQueueName, "aws-interruption-queue-name", env.WithDefaultString("AWS_INTERRUPTION_QUEUE_NAME", ""), "aws.interruptionQueueName is disabled if not specified. Enabling interruption handling may require additional permissions on the controller service account. Additional permissions are outlined in the docs.") } -func New() *Options { - opts := &Options{} - f := flag.NewFlagSet("aws-karpenter", flag.ContinueOnError) - opts.FlagSet = f - - f.StringVar(&opts.AssumeRoleARN, "aws-assume-role-arn", env.WithDefaultString("AWS_ASSUME_ROLE_ARN", ""), "Role to assume for calling AWS services.") - f.DurationVar(&opts.AssumeRoleDuration, "aws-assume-role-duration", env.WithDefaultDuration("AWS_ASSUME_ROLE_DURATION", 15*time.Minute), "Duration of assumed credentials in minutes. Default value is 15 minutes. Not used unless aws.assumeRole set.") - f.StringVar(&opts.ClusterCABundle, "aws-cluster-ca-bundle", env.WithDefaultString("AWS_CLUSTER_CA_BUNDLE", ""), "Cluster CA bundle for nodes to use for TLS connections with the API server. If not set, this is taken from the controller's TLS configuration.") - f.StringVar(&opts.ClusterName, "aws-cluster-name", env.WithDefaultString("AWS_CLUSTER_NAME", ""), "[REQUIRED] The kubernetes cluster name for resource discovery.") - f.StringVar(&opts.ClusterEndpoint, "aws-cluster-endpoint", env.WithDefaultString("AWS_CLUSTER_ENDPOINT", ""), "The external kubernetes cluster endpoint for new nodes to connect with. If not specified, will discover the cluster endpoint using DescribeCluster API.") - f.BoolVar(&opts.IsolatedVPC, "aws-isolated-vpc", env.WithDefaultBool("AWS_ISOLATED_VPC", false), "If true, then assume we can't reach AWS services which don't have a VPC endpoint. This also has the effect of disabling look-ups to the AWS pricing endpoint.") - f.Float64Var(&opts.VMMemoryOverheadPercent, "aws-vm-memory-overhead-percent", env.WithDefaultFloat64("AWS_VM_MEMORY_OVERHEAD_PERCENT", 0.075), "The VM memory overhead as a percent that will be subtracted from the total memory for all instance types.") - f.StringVar(&opts.InterruptionQueueName, "aws-interruption-queue-name", env.WithDefaultString("AWS_INTERRUPTION_QUEUE_NAME", ""), "aws.interruptionQueueName is disabled if not specified. Enabling interruption handling may require additional permissions on the controller service account. Additional permissions are outlined in the docs.") - - return opts -} - -func (*Options) Inject(ctx context.Context, args ...string) (context.Context, error) { - o := New() - if err := o.Parse(args); err != nil { +func (o *Options) Parse(fs *flag.FlagSet, args ...string) error { + if err := fs.Parse(args); err != nil { if errors.Is(err, flag.ErrHelp) { os.Exit(0) } - return ctx, fmt.Errorf("failed to parse cli flags, %w", err) + return fmt.Errorf("failed to parse cli flags, %w", err) } if err := o.Validate(); err != nil { - return ctx, fmt.Errorf("failed to validate options, %w", err) + return fmt.Errorf("failed to validate options, %w", err) } // Check if each option has been set. This is a little brute force and better options might exist, // but this only needs to be here for one version + o.setFlags = make(map[string]bool) cliFlags := sets.New[string]() - o.Visit(func(f *flag.Flag) { + fs.Visit(func(f *flag.Flag) { cliFlags.Insert(f.Name) }) - for _, entry := range []struct { - setFlag *bool - cliName string - envName string - }{ - {&o.AssumeRoleARNSet, "aws-assume-role-arn", "AWS_ASSUME_ROLE_ARN"}, - {&o.AssumeRoleDurationSet, "aws-assume-role-duration", "AWS_ASSUME_ROLE_DURATION"}, - {&o.ClusterCABundleSet, "aws-cluster-ca-bundle", "AWS_CLUSTER_CA_BUNDLE"}, - {&o.ClusterNameSet, "aws-cluster-name", "AWS_CLUSTER_NAME"}, - {&o.ClusterEndpointSet, "aws-cluster-endpoint", "AWS_CLUSTER_ENDPOINT"}, - {&o.IsolatedVPCSet, "aws-isolated-vpc", "AWS_ISOLATED_VPC"}, - {&o.VMMemoryOverheadPercentSet, "aws-vm-memory-overhead-percent", "AWS_VM_MEMORY_OVERHEAD_PERCENT"}, - {&o.InterruptionQueueNameSet, "aws-interruption-queue-name", "AWS_INTERRUPTION_QUEUE_NAME"}, - } { - if cliFlags.Has(entry.cliName) { - *entry.setFlag = true - } - if _, ok := os.LookupEnv(entry.envName); ok { - *entry.setFlag = true - } - } + fs.VisitAll(func(f *flag.Flag) { + envName := strings.ReplaceAll(strings.ToUpper(f.Name), "-", "_") + _, ok := os.LookupEnv(envName) + o.setFlags[f.Name] = ok || cliFlags.Has(f.Name) + }) - ctx = ToContext(ctx, o) - return ctx, nil + return nil } -func (*Options) MergeSettings(ctx context.Context, injectables ...coresettings.Injectable) context.Context { - for _, in := range injectables { - _, ok := in.(*settings.Settings) - if !ok { - continue - } - s := in.FromContext(ctx).(*settings.Settings) - o := FromContext(ctx) - mergeField(&o.AssumeRoleARN, s.AssumeRoleARN, o.AssumeRoleARNSet) - mergeField(&o.AssumeRoleDuration, s.AssumeRoleDuration, o.AssumeRoleDurationSet) - mergeField(&o.ClusterCABundle, s.ClusterCABundle, o.ClusterCABundleSet) - mergeField(&o.ClusterName, s.ClusterName, o.ClusterNameSet) - mergeField(&o.ClusterEndpoint, s.ClusterEndpoint, o.ClusterEndpointSet) - mergeField(&o.IsolatedVPC, s.IsolatedVPC, o.IsolatedVPCSet) - mergeField(&o.VMMemoryOverheadPercent, s.VMMemoryOverheadPercent, o.VMMemoryOverheadPercentSet) - mergeField(&o.InterruptionQueueName, s.InterruptionQueueName, o.InterruptionQueueNameSet) - ctx = ToContext(ctx, o) +func (o *Options) ToContext(ctx context.Context) context.Context { + return ToContext(ctx, o) +} + +func (o *Options) MergeSettings(ctx context.Context) { + s := settings.FromContext(ctx) + if s == nil { + return + } + if !o.setFlags["aws-assume-role-arn"] { + o.AssumeRoleARN = s.AssumeRoleARN + } + if !o.setFlags["aws-assume-role-duration"] { + o.AssumeRoleDuration = s.AssumeRoleDuration + } + if !o.setFlags["aws-cluster-ca-bundle"] { + o.ClusterCABundle = s.ClusterCABundle + } + if !o.setFlags["aws-cluster-name"] { + o.ClusterName = s.ClusterName + } + if !o.setFlags["aws-cluster-endpoint"] { + o.ClusterEndpoint = s.ClusterEndpoint + } + if !o.setFlags["aws-isolated-vpc"] { + o.IsolatedVPC = s.IsolatedVPC + } + if !o.setFlags["aws-vm-memory-overhead-percent"] { + o.VMMemoryOverheadPercent = s.VMMemoryOverheadPercent + } + if !o.setFlags["aws-interruption-queue-name"] { + o.InterruptionQueueName = s.InterruptionQueueName } - return ctx } func ToContext(ctx context.Context, opts *Options) context.Context { @@ -153,10 +129,3 @@ func FromContext(ctx context.Context) *Options { } return retval.(*Options) } - -func mergeField[T any](dest *T, val T, isSet bool) { - if isSet { - return - } - *dest = val -} diff --git a/pkg/operator/options/options_validation.go b/pkg/operator/options/options_validation.go index f6d325046ef1..480a1fb0bea8 100644 --- a/pkg/operator/options/options_validation.go +++ b/pkg/operator/options/options_validation.go @@ -32,7 +32,8 @@ func (o Options) Validate() (errs *apis.FieldError) { } func (o Options) validateAssumeRoleDuration() (errs *apis.FieldError) { - if !o.AssumeRoleDurationSet { + // TODO: Remove with karpenter-global-settings + if !o.setFlags["aws-assume-role-arn"] { return nil } if o.AssumeRoleDuration < time.Minute*15 { @@ -42,7 +43,8 @@ func (o Options) validateAssumeRoleDuration() (errs *apis.FieldError) { } func (o Options) validateClusterName() (errs *apis.FieldError) { - if !o.ClusterNameSet { + // TODO: Remove with karpenter-global-settings + if !o.setFlags["aws-cluster-name"] { return nil } if o.ClusterName == "" { @@ -65,7 +67,8 @@ func (o Options) validateEndpoint() (errs *apis.FieldError) { } func (o Options) validateVMMemoryOverheadPercent() (errs *apis.FieldError) { - if !o.VMMemoryOverheadPercentSet { + // TODO: Remove with karpenter-global-settings + if !o.setFlags["aws-vm-memory-overhead-percent"] { return nil } if o.VMMemoryOverheadPercent < 0 { diff --git a/pkg/test/options.go b/pkg/test/options.go index 1669a213c91b..32bd9d9c1b2a 100644 --- a/pkg/test/options.go +++ b/pkg/test/options.go @@ -43,15 +43,13 @@ func Options(overrides ...OptionsFields) *options.Options { } } return &options.Options{ - Fields: options.Fields{ - AssumeRoleARN: lo.FromPtrOr(opts.AssumeRoleARN, ""), - AssumeRoleDuration: lo.FromPtrOr(opts.AssumeRoleDuration, 15*time.Minute), - ClusterCABundle: lo.FromPtrOr(opts.ClusterCABundle, ""), - ClusterName: lo.FromPtrOr(opts.ClusterName, "test-cluster"), - ClusterEndpoint: lo.FromPtrOr(opts.ClusterEndpoint, "https://test-cluster"), - IsolatedVPC: lo.FromPtrOr(opts.IsolatedVPC, false), - VMMemoryOverheadPercent: lo.FromPtrOr(opts.VMMemoryOverheadPercent, 0.075), - InterruptionQueueName: lo.FromPtrOr(opts.InterruptionQueueName, ""), - }, + AssumeRoleARN: lo.FromPtrOr(opts.AssumeRoleARN, ""), + AssumeRoleDuration: lo.FromPtrOr(opts.AssumeRoleDuration, 15*time.Minute), + ClusterCABundle: lo.FromPtrOr(opts.ClusterCABundle, ""), + ClusterName: lo.FromPtrOr(opts.ClusterName, "test-cluster"), + ClusterEndpoint: lo.FromPtrOr(opts.ClusterEndpoint, "https://test-cluster"), + IsolatedVPC: lo.FromPtrOr(opts.IsolatedVPC, false), + VMMemoryOverheadPercent: lo.FromPtrOr(opts.VMMemoryOverheadPercent, 0.075), + InterruptionQueueName: lo.FromPtrOr(opts.InterruptionQueueName, ""), } }