From 50622623f23f5730992df0e158e28504b402e5b2 Mon Sep 17 00:00:00 2001 From: rickbrouwer Date: Sat, 2 Nov 2024 13:07:37 +0100 Subject: [PATCH 1/7] Remove unnecessary validate in kubernetes workload and cpu mem (#6286) * Remove validate cpu memory scaler Signed-off-by: rickbrouwer * Remove validate kubernetes workload Signed-off-by: rickbrouwer --------- Signed-off-by: rickbrouwer --- pkg/scalers/cpu_memory_scaler.go | 11 ----------- pkg/scalers/kubernetes_workload_scaler.go | 4 ---- 2 files changed, 15 deletions(-) diff --git a/pkg/scalers/cpu_memory_scaler.go b/pkg/scalers/cpu_memory_scaler.go index ce845414966..8da440ab77e 100644 --- a/pkg/scalers/cpu_memory_scaler.go +++ b/pkg/scalers/cpu_memory_scaler.go @@ -29,10 +29,6 @@ type cpuMemoryMetadata struct { MetricType v2.MetricTargetType } -func (m *cpuMemoryMetadata) Validate() error { - return nil -} - // NewCPUMemoryScaler creates a new cpuMemoryScaler func NewCPUMemoryScaler(resourceName v1.ResourceName, config *scalersconfig.ScalerConfig) (Scaler, error) { logger := InitializeLogger(config, "cpu_memory_scaler") @@ -42,13 +38,6 @@ func NewCPUMemoryScaler(resourceName v1.ResourceName, config *scalersconfig.Scal return nil, fmt.Errorf("error parsing %s metadata: %w", resourceName, err) } - if err := meta.Validate(); err != nil { - if meta.MetricType == "" { - return nil, fmt.Errorf("metricType is required") - } - return nil, fmt.Errorf("validation error: %w", err) - } - return &cpuMemoryScaler{ metadata: meta, resourceName: resourceName, diff --git a/pkg/scalers/kubernetes_workload_scaler.go b/pkg/scalers/kubernetes_workload_scaler.go index a2a658f7cba..3023d8ed08a 100644 --- a/pkg/scalers/kubernetes_workload_scaler.go +++ b/pkg/scalers/kubernetes_workload_scaler.go @@ -87,10 +87,6 @@ func parseKubernetesWorkloadMetadata(config *scalersconfig.ScalerConfig) (kubern } meta.podSelector = selector - if err := meta.Validate(); err != nil { - return meta, err - } - return meta, nil } From a825451bddbb8d4772fadeddbe33db79324b690c Mon Sep 17 00:00:00 2001 From: Rushen Wang <45029442+dovics@users.noreply.github.com> Date: Sun, 3 Nov 2024 18:46:17 +0800 Subject: [PATCH 2/7] Refactor predictkube scaler config (#6282) Signed-off-by: wangrushen --- pkg/scalers/predictkube_scaler.go | 163 ++++++++++-------------------- 1 file changed, 55 insertions(+), 108 deletions(-) diff --git a/pkg/scalers/predictkube_scaler.go b/pkg/scalers/predictkube_scaler.go index 78e2e5b446c..fe80fda1fca 100644 --- a/pkg/scalers/predictkube_scaler.go +++ b/pkg/scalers/predictkube_scaler.go @@ -83,18 +83,51 @@ type PredictKubeScaler struct { } type predictKubeMetadata struct { - predictHorizon time.Duration - historyTimeWindow time.Duration - stepDuration time.Duration - apiKey string - prometheusAddress string - prometheusAuth *authentication.AuthMeta - query string - threshold float64 - activationThreshold float64 - triggerIndex int + PrometheusAddress string `keda:"name=prometheusAddress, order=triggerMetadata"` + PrometheusAuth *authentication.Config `keda:"optional"` + Query string `keda:"name=query, order=triggerMetadata"` + PredictHorizon string `keda:"name=predictHorizon, order=triggerMetadata"` + QueryStep string `keda:"name=queryStep, order=triggerMetadata"` + HistoryTimeWindow string `keda:"name=historyTimeWindow, order=triggerMetadata"` + APIKey string `keda:"name=apiKey, order=authParams"` + Threshold float64 `keda:"name=threshold, order=triggerMetadata, optional"` + ActivationThreshold float64 `keda:"name=activationThreshold, order=triggerMetadata, optional"` + + predictHorizon time.Duration + historyTimeWindow time.Duration + stepDuration time.Duration + triggerIndex int } +func (p *predictKubeMetadata) Validate() error { + validate := validator.New() + err := validate.Var(p.PrometheusAddress, "url") + if err != nil { + return fmt.Errorf("invalid prometheusAddress") + } + + p.predictHorizon, err = str2duration.ParseDuration(p.PredictHorizon) + if err != nil { + return fmt.Errorf("predictHorizon parsing error %w", err) + } + + p.stepDuration, err = str2duration.ParseDuration(p.QueryStep) + if err != nil { + return fmt.Errorf("queryStep parsing error %w", err) + } + + p.historyTimeWindow, err = str2duration.ParseDuration(p.HistoryTimeWindow) + if err != nil { + return fmt.Errorf("historyTimeWindow parsing error %w", err) + } + + err = validate.Var(p.APIKey, "jwt") + if err != nil { + return fmt.Errorf("invalid apiKey") + } + + return nil +} func (s *PredictKubeScaler) setupClientConn() error { clientOpt, err := pc.SetGrpcClientOptions(grpcConf, &libs.Base{ @@ -108,7 +141,7 @@ func (s *PredictKubeScaler) setupClientConn() error { Enabled: false, }, }, - pc.InjectPublicClientMetadataInterceptor(s.metadata.apiKey), + pc.InjectPublicClientMetadataInterceptor(s.metadata.APIKey), ) if !grpcConf.Conn.Insecure { @@ -186,7 +219,7 @@ func (s *PredictKubeScaler) GetMetricSpecForScaling(context.Context) []v2.Metric Metric: v2.MetricIdentifier{ Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, metricName), }, - Target: GetMetricTargetMili(s.metricType, s.metadata.threshold), + Target: GetMetricTargetMili(s.metricType, s.metadata.Threshold), } metricSpec := v2.MetricSpec{ @@ -211,7 +244,7 @@ func (s *PredictKubeScaler) GetMetricsAndActivity(ctx context.Context, metricNam metric := GenerateMetricInMili(metricName, value) - return []external_metrics.ExternalMetricValue{metric}, activationValue > s.metadata.activationThreshold, nil + return []external_metrics.ExternalMetricValue{metric}, activationValue > s.metadata.ActivationThreshold, nil } func (s *PredictKubeScaler) doPredictRequest(ctx context.Context) (float64, float64, error) { @@ -257,7 +290,7 @@ func (s *PredictKubeScaler) doQuery(ctx context.Context) ([]*commonproto.Item, e Step: s.metadata.stepDuration, } - val, warns, err := s.api.QueryRange(ctx, s.metadata.query, r) + val, warns, err := s.api.QueryRange(ctx, s.metadata.Query, r) if len(warns) > 0 { s.logger.V(1).Info("warnings", warns) @@ -345,103 +378,17 @@ func (s *PredictKubeScaler) parsePrometheusResult(result model.Value) (out []*co } func parsePredictKubeMetadata(config *scalersconfig.ScalerConfig) (result *predictKubeMetadata, err error) { - validate := validator.New() - meta := predictKubeMetadata{} - - if val, ok := config.TriggerMetadata["query"]; ok { - if len(val) == 0 { - return nil, fmt.Errorf("no query given") - } - - meta.query = val - } else { - return nil, fmt.Errorf("no query given") - } - - if val, ok := config.TriggerMetadata["prometheusAddress"]; ok { - err = validate.Var(val, "url") - if err != nil { - return nil, fmt.Errorf("invalid prometheusAddress") - } - - meta.prometheusAddress = val - } else { - return nil, fmt.Errorf("no prometheusAddress given") - } - - if val, ok := config.TriggerMetadata["predictHorizon"]; ok { - predictHorizon, err := str2duration.ParseDuration(val) - if err != nil { - return nil, fmt.Errorf("predictHorizon parsing error %w", err) - } - meta.predictHorizon = predictHorizon - } else { - return nil, fmt.Errorf("no predictHorizon given") - } - - if val, ok := config.TriggerMetadata["queryStep"]; ok { - stepDuration, err := str2duration.ParseDuration(val) - if err != nil { - return nil, fmt.Errorf("queryStep parsing error %w", err) - } - meta.stepDuration = stepDuration - } else { - return nil, fmt.Errorf("no queryStep given") - } - - if val, ok := config.TriggerMetadata["historyTimeWindow"]; ok { - historyTimeWindow, err := str2duration.ParseDuration(val) - if err != nil { - return nil, fmt.Errorf("historyTimeWindow parsing error %w", err) - } - meta.historyTimeWindow = historyTimeWindow - } else { - return nil, fmt.Errorf("no historyTimeWindow given") - } - - if val, ok := config.TriggerMetadata["threshold"]; ok { - threshold, err := strconv.ParseFloat(val, 64) - if err != nil { - return nil, fmt.Errorf("threshold parsing error %w", err) - } - meta.threshold = threshold - } else { - if config.AsMetricSource { - meta.threshold = 0 - } else { - return nil, fmt.Errorf("no threshold given") - } + meta := &predictKubeMetadata{} + if err := config.TypedConfig(meta); err != nil { + return nil, fmt.Errorf("error parsing arango metadata: %w", err) } - meta.activationThreshold = 0 - if val, ok := config.TriggerMetadata["activationThreshold"]; ok { - activationThreshold, err := strconv.ParseFloat(val, 64) - if err != nil { - return nil, fmt.Errorf("activationThreshold parsing error %w", err) - } - meta.activationThreshold = activationThreshold + if !config.AsMetricSource && meta.Threshold == 0 { + return nil, fmt.Errorf("no threshold given") } meta.triggerIndex = config.TriggerIndex - - if val, ok := config.AuthParams["apiKey"]; ok { - err = validate.Var(val, "jwt") - if err != nil { - return nil, fmt.Errorf("invalid apiKey") - } - - meta.apiKey = val - } else { - return nil, fmt.Errorf("no api key given") - } - - // parse auth configs from ScalerConfig - auth, err := authentication.GetAuthConfigs(config.TriggerMetadata, config.AuthParams) - if err != nil { - return nil, err - } - meta.prometheusAuth = auth - return &meta, nil + return meta, nil } func (s *PredictKubeScaler) ping(ctx context.Context) (err error) { @@ -454,14 +401,14 @@ func (s *PredictKubeScaler) initPredictKubePrometheusConn(ctx context.Context) ( // create http.RoundTripper with auth settings from ScalerConfig roundTripper, err := authentication.CreateHTTPRoundTripper( authentication.FastHTTP, - s.metadata.prometheusAuth, + s.metadata.PrometheusAuth.ToAuthMeta(), ) if err != nil { s.logger.V(1).Error(err, "init Prometheus client http transport") return err } client, err := api.NewClient(api.Config{ - Address: s.metadata.prometheusAddress, + Address: s.metadata.PrometheusAddress, RoundTripper: roundTripper, }) if err != nil { From baec71580e070eccbb7f6a0f7c032952c5a27151 Mon Sep 17 00:00:00 2001 From: rickbrouwer Date: Sun, 3 Nov 2024 11:53:46 +0100 Subject: [PATCH 3/7] Refactor couchdb scaler (#6267) * Refactor couchdb scaler Signed-off-by: rickbrouwer * Update Signed-off-by: rickbrouwer --------- Signed-off-by: rickbrouwer --- pkg/scalers/couchdb_scaler.go | 263 ++++++++++++----------------- pkg/scalers/couchdb_scaler_test.go | 59 +++++-- 2 files changed, 151 insertions(+), 171 deletions(-) diff --git a/pkg/scalers/couchdb_scaler.go b/pkg/scalers/couchdb_scaler.go index 62ab5890493..b84332b7127 100644 --- a/pkg/scalers/couchdb_scaler.go +++ b/pkg/scalers/couchdb_scaler.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "net" - "strconv" couchdb "github.com/go-kivik/couchdb/v3" "github.com/go-kivik/kivik/v3" @@ -19,216 +18,168 @@ import ( type couchDBScaler struct { metricType v2.MetricTargetType - metadata *couchDBMetadata + metadata couchDBMetadata client *kivik.Client logger logr.Logger } +type couchDBMetadata struct { + ConnectionString string `keda:"name=connectionString,order=authParams;triggerMetadata;resolvedEnv,optional"` + Host string `keda:"name=host,order=authParams;triggerMetadata,optional"` + Port string `keda:"name=port,order=authParams;triggerMetadata,optional"` + Username string `keda:"name=username,order=authParams;triggerMetadata,optional"` + Password string `keda:"name=password,order=authParams;triggerMetadata;resolvedEnv,optional"` + DBName string `keda:"name=dbName,order=authParams;triggerMetadata,optional"` + Query string `keda:"name=query,order=triggerMetadata,optional"` + QueryValue int64 `keda:"name=queryValue,order=triggerMetadata,optional"` + ActivationQueryValue int64 `keda:"name=activationQueryValue,order=triggerMetadata,default=0,optional"` + TriggerIndex int +} + +func (m *couchDBMetadata) Validate() error { + if m.ConnectionString == "" { + if m.Host == "" { + return fmt.Errorf("no host given") + } + if m.Port == "" { + return fmt.Errorf("no port given") + } + if m.Username == "" { + return fmt.Errorf("no username given") + } + if m.Password == "" { + return fmt.Errorf("no password given") + } + if m.DBName == "" { + return fmt.Errorf("no dbName given") + } + } + return nil +} + type couchDBQueryRequest struct { Selector map[string]interface{} `json:"selector"` Fields []string `json:"fields"` } -type couchDBMetadata struct { - connectionString string - host string - port string - username string - password string - dbName string - query string - queryValue int64 - activationQueryValue int64 - triggerIndex int -} - type Res struct { ID string `json:"_id"` Feet int `json:"feet"` Greeting string `json:"greeting"` } -func (s *couchDBScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { - externalMetric := &v2.ExternalMetricSource{ - Metric: v2.MetricIdentifier{ - Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("coucdb-%s", s.metadata.dbName))), - }, - Target: GetMetricTarget(s.metricType, s.metadata.queryValue), +func NewCouchDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) { + metricType, err := GetMetricTargetType(config) + if err != nil { + return nil, fmt.Errorf("error getting scaler metric type: %w", err) } - metricSpec := v2.MetricSpec{ - External: externalMetric, Type: externalMetricType, + + meta, err := parseCouchDBMetadata(config) + if err != nil { + return nil, fmt.Errorf("error parsing couchdb metadata: %w", err) } - return []v2.MetricSpec{metricSpec} -} -func (s couchDBScaler) Close(ctx context.Context) error { - if s.client != nil { - err := s.client.Close(ctx) - if err != nil { - s.logger.Error(err, fmt.Sprintf("failed to close couchdb connection, because of %v", err)) - return err - } + connStr := meta.ConnectionString + if connStr == "" { + addr := net.JoinHostPort(meta.Host, meta.Port) + connStr = "http://" + addr } - return nil -} -func (s *couchDBScaler) getQueryResult(ctx context.Context) (int64, error) { - db := s.client.DB(ctx, s.metadata.dbName) - var request couchDBQueryRequest - err := json.Unmarshal([]byte(s.metadata.query), &request) + client, err := kivik.New("couch", connStr) if err != nil { - s.logger.Error(err, fmt.Sprintf("Couldn't unmarshal query string because of %v", err)) - return 0, err + return nil, fmt.Errorf("error creating couchdb client: %w", err) } - rows, err := db.Find(ctx, request, nil) + + err = client.Authenticate(ctx, couchdb.BasicAuth("admin", meta.Password)) if err != nil { - s.logger.Error(err, fmt.Sprintf("failed to fetch rows because of %v", err)) - return 0, err + return nil, fmt.Errorf("error authenticating with couchdb: %w", err) } - var count int64 - for rows.Next() { - count++ - res := Res{} - if err := rows.ScanDoc(&res); err != nil { - s.logger.Error(err, fmt.Sprintf("failed to scan the doc because of %v", err)) - return 0, err - } + + isConnected, err := client.Ping(ctx) + if !isConnected || err != nil { + return nil, fmt.Errorf("failed to ping couchdb: %w", err) } - return count, nil + + return &couchDBScaler{ + metricType: metricType, + metadata: meta, + client: client, + logger: InitializeLogger(config, "couchdb_scaler"), + }, nil } -func parseCouchDBMetadata(config *scalersconfig.ScalerConfig) (*couchDBMetadata, string, error) { - var connStr string - var err error +func parseCouchDBMetadata(config *scalersconfig.ScalerConfig) (couchDBMetadata, error) { meta := couchDBMetadata{} - - if val, ok := config.TriggerMetadata["query"]; ok { - meta.query = val - } else { - return nil, "", fmt.Errorf("no query given") + err := config.TypedConfig(&meta) + if err != nil { + return meta, fmt.Errorf("error parsing couchdb metadata: %w", err) } - if val, ok := config.TriggerMetadata["queryValue"]; ok { - queryValue, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return nil, "", fmt.Errorf("failed to convert %v to int, because of %w", val, err) - } - meta.queryValue = queryValue - } else { - if config.AsMetricSource { - meta.queryValue = 0 - } else { - return nil, "", fmt.Errorf("no queryValue given") - } + if meta.QueryValue == 0 && !config.AsMetricSource { + return meta, fmt.Errorf("no queryValue given") } - meta.activationQueryValue = 0 - if val, ok := config.TriggerMetadata["activationQueryValue"]; ok { - activationQueryValue, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return nil, "", fmt.Errorf("failed to convert %v to int, because of %w", val, err) - } - meta.activationQueryValue = activationQueryValue + if config.AsMetricSource { + meta.QueryValue = 0 } - dbName, err := GetFromAuthOrMeta(config, "dbName") - if err != nil { - return nil, "", err - } - meta.dbName = dbName - - switch { - case config.AuthParams["connectionString"] != "": - meta.connectionString = config.AuthParams["connectionString"] - case config.TriggerMetadata["connectionStringFromEnv"] != "": - meta.connectionString = config.ResolvedEnv[config.TriggerMetadata["connectionStringFromEnv"]] - default: - meta.connectionString = "" - host, err := GetFromAuthOrMeta(config, "host") - if err != nil { - return nil, "", err - } - meta.host = host - - port, err := GetFromAuthOrMeta(config, "port") - if err != nil { - return nil, "", err - } - meta.port = port - - username, err := GetFromAuthOrMeta(config, "username") - if err != nil { - return nil, "", err - } - meta.username = username + meta.TriggerIndex = config.TriggerIndex + return meta, nil +} - if config.AuthParams["password"] != "" { - meta.password = config.AuthParams["password"] - } else if config.TriggerMetadata["passwordFromEnv"] != "" { - meta.password = config.ResolvedEnv[config.TriggerMetadata["passwordFromEnv"]] - } - if len(meta.password) == 0 { - return nil, "", fmt.Errorf("no password given") +func (s *couchDBScaler) Close(ctx context.Context) error { + if s.client != nil { + if err := s.client.Close(ctx); err != nil { + s.logger.Error(err, "failed to close couchdb connection") + return err } } - - if meta.connectionString != "" { - connStr = meta.connectionString - } else { - // Build connection str - addr := net.JoinHostPort(meta.host, meta.port) - // nosemgrep: db-connection-string - connStr = "http://" + addr - } - meta.triggerIndex = config.TriggerIndex - return &meta, connStr, nil + return nil } -func NewCouchDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) { - metricType, err := GetMetricTargetType(config) - if err != nil { - return nil, fmt.Errorf("error getting scaler metric type: %w", err) +func (s *couchDBScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { + metricName := kedautil.NormalizeString(fmt.Sprintf("coucdb-%s", s.metadata.DBName)) + externalMetric := &v2.ExternalMetricSource{ + Metric: v2.MetricIdentifier{ + Name: GenerateMetricNameWithIndex(s.metadata.TriggerIndex, metricName), + }, + Target: GetMetricTarget(s.metricType, s.metadata.QueryValue), } + metricSpec := v2.MetricSpec{External: externalMetric, Type: externalMetricType} + return []v2.MetricSpec{metricSpec} +} - meta, connStr, err := parseCouchDBMetadata(config) - if err != nil { - return nil, fmt.Errorf("failed to parsing couchDB metadata, because of %w", err) - } +func (s *couchDBScaler) getQueryResult(ctx context.Context) (int64, error) { + db := s.client.DB(ctx, s.metadata.DBName) - client, err := kivik.New("couch", connStr) - if err != nil { - return nil, fmt.Errorf("%w", err) + var request couchDBQueryRequest + if err := json.Unmarshal([]byte(s.metadata.Query), &request); err != nil { + return 0, fmt.Errorf("error unmarshaling query: %w", err) } - err = client.Authenticate(ctx, couchdb.BasicAuth("admin", meta.password)) + rows, err := db.Find(ctx, request, nil) if err != nil { - return nil, err + return 0, fmt.Errorf("error executing query: %w", err) } - isconnected, err := client.Ping(ctx) - if !isconnected { - return nil, fmt.Errorf("%w", err) - } - if err != nil { - return nil, fmt.Errorf("failed to ping couchDB, because of %w", err) + var count int64 + for rows.Next() { + count++ + var res Res + if err := rows.ScanDoc(&res); err != nil { + return 0, fmt.Errorf("error scanning document: %w", err) + } } - return &couchDBScaler{ - metricType: metricType, - metadata: meta, - client: client, - logger: InitializeLogger(config, "couchdb_scaler"), - }, nil + return count, nil } -// GetMetricsAndActivity query from couchDB,and return to external metrics and activity func (s *couchDBScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { result, err := s.getQueryResult(ctx) if err != nil { - return []external_metrics.ExternalMetricValue{}, false, fmt.Errorf("failed to inspect couchdb, because of %w", err) + return []external_metrics.ExternalMetricValue{}, false, fmt.Errorf("failed to inspect couchdb: %w", err) } metric := GenerateMetricInMili(metricName, float64(result)) - - return append([]external_metrics.ExternalMetricValue{}, metric), result > s.metadata.activationQueryValue, nil + return []external_metrics.ExternalMetricValue{metric}, result > s.metadata.ActivationQueryValue, nil } diff --git a/pkg/scalers/couchdb_scaler_test.go b/pkg/scalers/couchdb_scaler_test.go index af7ae36b9b1..54b4a4b5b5a 100644 --- a/pkg/scalers/couchdb_scaler_test.go +++ b/pkg/scalers/couchdb_scaler_test.go @@ -7,6 +7,7 @@ import ( _ "github.com/go-kivik/couchdb/v3" "github.com/go-kivik/kivik/v3" "github.com/go-logr/logr" + v2 "k8s.io/api/autoscaling/v2" "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" ) @@ -17,6 +18,7 @@ var testCouchDBResolvedEnv = map[string]string{ } type parseCouchDBMetadataTestData struct { + name string metadata map[string]string authParams map[string]string resolvedEnv map[string]string @@ -32,6 +34,7 @@ type couchDBMetricIdentifier struct { var testCOUCHDBMetadata = []parseCouchDBMetadataTestData{ // No metadata { + name: "no metadata", metadata: map[string]string{}, authParams: map[string]string{}, resolvedEnv: testCouchDBResolvedEnv, @@ -39,6 +42,7 @@ var testCOUCHDBMetadata = []parseCouchDBMetadataTestData{ }, // connectionStringFromEnv { + name: "with connectionStringFromEnv", metadata: map[string]string{"query": `{ "selector": { "feet": { "$gt": 0 } }, "fields": ["_id", "feet", "greeting"] }`, "queryValue": "1", "connectionStringFromEnv": "CouchDB_CONN_STR", "dbName": "animals"}, authParams: map[string]string{}, resolvedEnv: testCouchDBResolvedEnv, @@ -46,6 +50,7 @@ var testCOUCHDBMetadata = []parseCouchDBMetadataTestData{ }, // with metric name { + name: "with metric name", metadata: map[string]string{"query": `{ "selector": { "feet": { "$gt": 0 } }, "fields": ["_id", "feet", "greeting"] }`, "queryValue": "1", "connectionStringFromEnv": "CouchDB_CONN_STR", "dbName": "animals"}, authParams: map[string]string{}, resolvedEnv: testCouchDBResolvedEnv, @@ -53,6 +58,7 @@ var testCOUCHDBMetadata = []parseCouchDBMetadataTestData{ }, // from trigger auth { + name: "from trigger auth", metadata: map[string]string{"query": `{ "selector": { "feet": { "$gt": 0 } }, "fields": ["_id", "feet", "greeting"] }`, "queryValue": "1"}, authParams: map[string]string{"dbName": "animals", "host": "localhost", "port": "5984", "username": "admin", "password": "YeFvQno9LylIm5MDgwcV"}, resolvedEnv: testCouchDBResolvedEnv, @@ -60,7 +66,8 @@ var testCOUCHDBMetadata = []parseCouchDBMetadataTestData{ }, // wrong activationQueryValue { - metadata: map[string]string{"query": `{ "selector": { "feet": { "$gt": 0 } }, "fields": ["_id", "feet", "greeting"] }`, "queryValue": "1", "activationQueryValue": "1", "connectionStringFromEnv": "CouchDB_CONN_STR", "dbName": "animals"}, + name: "wrong activationQueryValue", + metadata: map[string]string{"query": `{ "selector": { "feet": { "$gt": 0 } }, "fields": ["_id", "feet", "greeting"] }`, "queryValue": "1", "activationQueryValue": "a", "connectionStringFromEnv": "CouchDB_CONN_STR", "dbName": "animals"}, authParams: map[string]string{}, resolvedEnv: testCouchDBResolvedEnv, raisesError: true, @@ -74,25 +81,47 @@ var couchDBMetricIdentifiers = []couchDBMetricIdentifier{ func TestParseCouchDBMetadata(t *testing.T) { for _, testData := range testCOUCHDBMetadata { - _, _, err := parseCouchDBMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadata, AuthParams: testData.authParams}) - if err != nil && !testData.raisesError { - t.Error("Expected success but got error:", err) - } + t.Run(testData.name, func(t *testing.T) { + _, err := parseCouchDBMetadata(&scalersconfig.ScalerConfig{ + TriggerMetadata: testData.metadata, + AuthParams: testData.authParams, + ResolvedEnv: testData.resolvedEnv, + }) + if err != nil && !testData.raisesError { + t.Errorf("Test case '%s': Expected success but got error: %v", testData.name, err) + } + if testData.raisesError && err == nil { + t.Errorf("Test case '%s': Expected error but got success", testData.name) + } + }) } } func TestCouchDBGetMetricSpecForScaling(t *testing.T) { for _, testData := range couchDBMetricIdentifiers { - meta, _, err := parseCouchDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, AuthParams: testData.metadataTestData.authParams, TriggerMetadata: testData.metadataTestData.metadata, TriggerIndex: testData.triggerIndex}) - if err != nil { - t.Fatal("Could not parse metadata:", err) - } - mockCouchDBScaler := couchDBScaler{"", meta, &kivik.Client{}, logr.Discard()} + t.Run(testData.name, func(t *testing.T) { + meta, err := parseCouchDBMetadata(&scalersconfig.ScalerConfig{ + ResolvedEnv: testData.metadataTestData.resolvedEnv, + AuthParams: testData.metadataTestData.authParams, + TriggerMetadata: testData.metadataTestData.metadata, + TriggerIndex: testData.triggerIndex, + }) + if err != nil { + t.Fatal("Could not parse metadata:", err) + } - metricSpec := mockCouchDBScaler.GetMetricSpecForScaling(context.Background()) - metricName := metricSpec[0].External.Metric.Name - if metricName != testData.name { - t.Error("Wrong External metric source name:", metricName) - } + mockCouchDBScaler := couchDBScaler{ + metricType: v2.AverageValueMetricType, + metadata: meta, + client: &kivik.Client{}, + logger: logr.Discard(), + } + + metricSpec := mockCouchDBScaler.GetMetricSpecForScaling(context.Background()) + metricName := metricSpec[0].External.Metric.Name + if metricName != testData.name { + t.Errorf("Wrong External metric source name: %s, expected: %s", metricName, testData.name) + } + }) } } From b34fad033b08cf3383cdd23e05fdf3e49c662459 Mon Sep 17 00:00:00 2001 From: rickbrouwer Date: Sun, 3 Nov 2024 12:29:34 +0100 Subject: [PATCH 4/7] Refactor Cassandra scaler (#6275) * Refactor cassandra scaler Signed-off-by: rickbrouwer * Update feedback Signed-off-by: rickbrouwer --------- Signed-off-by: rickbrouwer --- pkg/scalers/cassandra_scaler.go | 291 +++++++-------------- pkg/scalers/cassandra_scaler_test.go | 368 +++++++++++++++++++-------- 2 files changed, 364 insertions(+), 295 deletions(-) diff --git a/pkg/scalers/cassandra_scaler.go b/pkg/scalers/cassandra_scaler.go index 6e8705d2d8d..b41dddb9dec 100644 --- a/pkg/scalers/cassandra_scaler.go +++ b/pkg/scalers/cassandra_scaler.go @@ -18,53 +18,70 @@ import ( kedautil "github.com/kedacore/keda/v2/pkg/util" ) -// cassandraScaler exposes a data pointer to CassandraMetadata and gocql.Session connection. type cassandraScaler struct { metricType v2.MetricTargetType - metadata *CassandraMetadata + metadata cassandraMetadata session *gocql.Session logger logr.Logger } -// CassandraMetadata defines metadata used by KEDA to query a Cassandra table. -type CassandraMetadata struct { - username string - password string - enableTLS bool - cert string - key string - ca string - clusterIPAddress string - port int - consistency gocql.Consistency - protocolVersion int - keyspace string - query string - targetQueryValue int64 - activationTargetQueryValue int64 - triggerIndex int +type cassandraMetadata struct { + Username string `keda:"name=username, order=triggerMetadata"` + Password string `keda:"name=password, order=authParams"` + TLS string `keda:"name=tls, order=authParams, enum=enable;disable, default=disable, optional"` + Cert string `keda:"name=cert, order=authParams, optional"` + Key string `keda:"name=key, order=authParams, optional"` + CA string `keda:"name=ca, order=authParams, optional"` + ClusterIPAddress string `keda:"name=clusterIPAddress, order=triggerMetadata"` + Port int `keda:"name=port, order=triggerMetadata, optional"` + Consistency string `keda:"name=consistency, order=triggerMetadata, default=one"` + ProtocolVersion int `keda:"name=protocolVersion, order=triggerMetadata, default=4"` + Keyspace string `keda:"name=keyspace, order=triggerMetadata"` + Query string `keda:"name=query, order=triggerMetadata"` + TargetQueryValue int64 `keda:"name=targetQueryValue, order=triggerMetadata"` + ActivationTargetQueryValue int64 `keda:"name=activationTargetQueryValue, order=triggerMetadata, default=0"` + TriggerIndex int } const ( - tlsEnable = "enable" - tlsDisable = "disable" + tlsEnable = "enable" ) -// NewCassandraScaler creates a new Cassandra scaler. +func (m *cassandraMetadata) Validate() error { + if m.TLS == tlsEnable && (m.Cert == "" || m.Key == "") { + return errors.New("both cert and key are required when TLS is enabled") + } + + // Handle port in ClusterIPAddress + splitVal := strings.Split(m.ClusterIPAddress, ":") + if len(splitVal) == 2 { + if port, err := strconv.Atoi(splitVal[1]); err == nil { + m.Port = port + return nil + } + } + + if m.Port == 0 { + return fmt.Errorf("no port given") + } + + m.ClusterIPAddress = net.JoinHostPort(m.ClusterIPAddress, fmt.Sprintf("%d", m.Port)) + return nil +} + +// NewCassandraScaler creates a new Cassandra scaler func NewCassandraScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { metricType, err := GetMetricTargetType(config) if err != nil { return nil, fmt.Errorf("error getting scaler metric type: %w", err) } - logger := InitializeLogger(config, "cassandra_scaler") - meta, err := parseCassandraMetadata(config) if err != nil { return nil, fmt.Errorf("error parsing cassandra metadata: %w", err) } - session, err := newCassandraSession(meta, logger) + session, err := newCassandraSession(meta, InitializeLogger(config, "cassandra_scaler")) if err != nil { return nil, fmt.Errorf("error establishing cassandra session: %w", err) } @@ -73,108 +90,27 @@ func NewCassandraScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { metricType: metricType, metadata: meta, session: session, - logger: logger, + logger: InitializeLogger(config, "cassandra_scaler"), }, nil } -// parseCassandraMetadata parses the metadata and returns a CassandraMetadata or an error if the ScalerConfig is invalid. -func parseCassandraMetadata(config *scalersconfig.ScalerConfig) (*CassandraMetadata, error) { - meta := &CassandraMetadata{} - var err error - - if val, ok := config.TriggerMetadata["query"]; ok { - meta.query = val - } else { - return nil, fmt.Errorf("no query given") - } - - if val, ok := config.TriggerMetadata["targetQueryValue"]; ok { - targetQueryValue, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return nil, fmt.Errorf("targetQueryValue parsing error %w", err) - } - meta.targetQueryValue = targetQueryValue - } else { - if config.AsMetricSource { - meta.targetQueryValue = 0 - } else { - return nil, fmt.Errorf("no targetQueryValue given") - } - } - - meta.activationTargetQueryValue = 0 - if val, ok := config.TriggerMetadata["activationTargetQueryValue"]; ok { - activationTargetQueryValue, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return nil, fmt.Errorf("activationTargetQueryValue parsing error %w", err) - } - meta.activationTargetQueryValue = activationTargetQueryValue - } - - if val, ok := config.TriggerMetadata["username"]; ok { - meta.username = val - } else { - return nil, fmt.Errorf("no username given") - } - - if val, ok := config.TriggerMetadata["port"]; ok { - port, err := strconv.Atoi(val) - if err != nil { - return nil, fmt.Errorf("port parsing error %w", err) - } - meta.port = port - } - - if val, ok := config.TriggerMetadata["clusterIPAddress"]; ok { - splitval := strings.Split(val, ":") - port := splitval[len(splitval)-1] - - _, err := strconv.Atoi(port) - switch { - case err == nil: - meta.clusterIPAddress = val - case meta.port > 0: - meta.clusterIPAddress = net.JoinHostPort(val, fmt.Sprintf("%d", meta.port)) - default: - return nil, fmt.Errorf("no port given") - } - } else { - return nil, fmt.Errorf("no cluster IP address given") - } - - if val, ok := config.TriggerMetadata["protocolVersion"]; ok { - protocolVersion, err := strconv.Atoi(val) - if err != nil { - return nil, fmt.Errorf("protocolVersion parsing error %w", err) - } - meta.protocolVersion = protocolVersion - } else { - meta.protocolVersion = 4 - } - - if val, ok := config.TriggerMetadata["consistency"]; ok { - meta.consistency = gocql.ParseConsistency(val) - } else { - meta.consistency = gocql.One +func parseCassandraMetadata(config *scalersconfig.ScalerConfig) (cassandraMetadata, error) { + meta := cassandraMetadata{} + err := config.TypedConfig(&meta) + if err != nil { + return meta, fmt.Errorf("error parsing cassandra metadata: %w", err) } - if val, ok := config.TriggerMetadata["keyspace"]; ok { - meta.keyspace = val - } else { - return nil, fmt.Errorf("no keyspace given") - } - if val, ok := config.AuthParams["password"]; ok { - meta.password = val - } else { - return nil, fmt.Errorf("no password given") + if config.AsMetricSource { + meta.TargetQueryValue = 0 } - if err = parseCassandraTLS(config, meta); err != nil { + err = parseCassandraTLS(&meta) + if err != nil { return meta, err } - meta.triggerIndex = config.TriggerIndex - + meta.TriggerIndex = config.TriggerIndex return meta, nil } @@ -182,8 +118,8 @@ func createTempFile(prefix string, content string) (string, error) { tempCassandraDir := fmt.Sprintf("%s%c%s", os.TempDir(), os.PathSeparator, "cassandra") err := os.MkdirAll(tempCassandraDir, 0700) if err != nil { - return "", fmt.Errorf(`error creating temporary directory: %s. Error: %w - Note, when running in a container a writable /tmp/cassandra emptyDir must be mounted. Refer to documentation`, tempCassandraDir, err) + return "", fmt.Errorf(`error creating temporary directory: %s. Error: %w + Note, when running in a container a writable /tmp/cassandra emptyDir must be mounted. Refer to documentation`, tempCassandraDir, err) } f, err := os.CreateTemp(tempCassandraDir, prefix+"-*.pem") @@ -200,72 +136,48 @@ func createTempFile(prefix string, content string) (string, error) { return f.Name(), nil } -func parseCassandraTLS(config *scalersconfig.ScalerConfig, meta *CassandraMetadata) error { - meta.enableTLS = false - if val, ok := config.AuthParams["tls"]; ok { - val = strings.TrimSpace(val) - if val == tlsEnable { - certGiven := config.AuthParams["cert"] != "" - keyGiven := config.AuthParams["key"] != "" - caCertGiven := config.AuthParams["ca"] != "" - if certGiven && !keyGiven { - return errors.New("no key given") - } - if keyGiven && !certGiven { - return errors.New("no cert given") - } - if !keyGiven && !certGiven { - return errors.New("no cert/key given") - } +func parseCassandraTLS(meta *cassandraMetadata) error { + if meta.TLS == tlsEnable { + // Create temp files for certs + certFilePath, err := createTempFile("cert", meta.Cert) + if err != nil { + return fmt.Errorf("error creating cert file: %w", err) + } + meta.Cert = certFilePath - certFilePath, err := createTempFile("cert", config.AuthParams["cert"]) - if err != nil { - // handle error - return errors.New("Error creating cert file: " + err.Error()) - } + keyFilePath, err := createTempFile("key", meta.Key) + if err != nil { + return fmt.Errorf("error creating key file: %w", err) + } + meta.Key = keyFilePath - keyFilePath, err := createTempFile("key", config.AuthParams["key"]) + // If CA cert is given, make also file + if meta.CA != "" { + caCertFilePath, err := createTempFile("caCert", meta.CA) if err != nil { - // handle error - return errors.New("Error creating key file: " + err.Error()) + return fmt.Errorf("error creating ca file: %w", err) } - - meta.cert = certFilePath - meta.key = keyFilePath - meta.ca = config.AuthParams["ca"] - if !caCertGiven { - meta.ca = "" - } else { - caCertFilePath, err := createTempFile("caCert", config.AuthParams["ca"]) - meta.ca = caCertFilePath - if err != nil { - // handle error - return errors.New("Error creating ca file: " + err.Error()) - } - } - meta.enableTLS = true - } else if val != tlsDisable { - return fmt.Errorf("err incorrect value for TLS given: %s", val) + meta.CA = caCertFilePath } } return nil } -// newCassandraSession returns a new Cassandra session for the provided CassandraMetadata. -func newCassandraSession(meta *CassandraMetadata, logger logr.Logger) (*gocql.Session, error) { - cluster := gocql.NewCluster(meta.clusterIPAddress) - cluster.ProtoVersion = meta.protocolVersion - cluster.Consistency = meta.consistency +// newCassandraSession returns a new Cassandra session for the provided CassandraMetadata +func newCassandraSession(meta cassandraMetadata, logger logr.Logger) (*gocql.Session, error) { + cluster := gocql.NewCluster(meta.ClusterIPAddress) + cluster.ProtoVersion = meta.ProtocolVersion + cluster.Consistency = gocql.ParseConsistency(meta.Consistency) cluster.Authenticator = gocql.PasswordAuthenticator{ - Username: meta.username, - Password: meta.password, + Username: meta.Username, + Password: meta.Password, } - if meta.enableTLS { + if meta.TLS == tlsEnable { cluster.SslOpts = &gocql.SslOptions{ - CertPath: meta.cert, - KeyPath: meta.key, - CaPath: meta.ca, + CertPath: meta.Cert, + KeyPath: meta.Key, + CaPath: meta.CA, } } @@ -278,22 +190,19 @@ func newCassandraSession(meta *CassandraMetadata, logger logr.Logger) (*gocql.Se return session, nil } -// GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler. +// GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler func (s *cassandraScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { externalMetric := &v2.ExternalMetricSource{ Metric: v2.MetricIdentifier{ - Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("cassandra-%s", s.metadata.keyspace))), + Name: GenerateMetricNameWithIndex(s.metadata.TriggerIndex, kedautil.NormalizeString(fmt.Sprintf("cassandra-%s", s.metadata.Keyspace))), }, - Target: GetMetricTarget(s.metricType, s.metadata.targetQueryValue), - } - metricSpec := v2.MetricSpec{ - External: externalMetric, Type: externalMetricType, + Target: GetMetricTarget(s.metricType, s.metadata.TargetQueryValue), } - + metricSpec := v2.MetricSpec{External: externalMetric, Type: externalMetricType} return []v2.MetricSpec{metricSpec} } -// GetMetricsAndActivity returns a value for a supported metric or an error if there is a problem getting the metric. +// GetMetricsAndActivity returns a value for a supported metric or an error if there is a problem getting the metric func (s *cassandraScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { num, err := s.GetQueryResult(ctx) if err != nil { @@ -301,38 +210,36 @@ func (s *cassandraScaler) GetMetricsAndActivity(ctx context.Context, metricName } metric := GenerateMetricInMili(metricName, float64(num)) - - return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.activationTargetQueryValue, nil + return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.ActivationTargetQueryValue, nil } -// GetQueryResult returns the result of the scaler query. +// GetQueryResult returns the result of the scaler query func (s *cassandraScaler) GetQueryResult(ctx context.Context) (int64, error) { var value int64 - if err := s.session.Query(s.metadata.query).WithContext(ctx).Scan(&value); err != nil { + if err := s.session.Query(s.metadata.Query).WithContext(ctx).Scan(&value); err != nil { if err != gocql.ErrNotFound { s.logger.Error(err, "query failed") return 0, err } } - return value, nil } -// Close closes the Cassandra session connection. +// Close closes the Cassandra session connection func (s *cassandraScaler) Close(_ context.Context) error { // clean up any temporary files - if strings.TrimSpace(s.metadata.cert) != "" { - if err := os.Remove(s.metadata.cert); err != nil { + if s.metadata.Cert != "" { + if err := os.Remove(s.metadata.Cert); err != nil { return err } } - if strings.TrimSpace(s.metadata.key) != "" { - if err := os.Remove(s.metadata.key); err != nil { + if s.metadata.Key != "" { + if err := os.Remove(s.metadata.Key); err != nil { return err } } - if strings.TrimSpace(s.metadata.ca) != "" { - if err := os.Remove(s.metadata.ca); err != nil { + if s.metadata.CA != "" { + if err := os.Remove(s.metadata.CA); err != nil { return err } } diff --git a/pkg/scalers/cassandra_scaler_test.go b/pkg/scalers/cassandra_scaler_test.go index 39930946a56..d2e892b8c32 100644 --- a/pkg/scalers/cassandra_scaler_test.go +++ b/pkg/scalers/cassandra_scaler_test.go @@ -2,156 +2,318 @@ package scalers import ( "context" - "fmt" "os" "testing" "github.com/go-logr/logr" "github.com/gocql/gocql" + "github.com/stretchr/testify/assert" + v2 "k8s.io/api/autoscaling/v2" "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" ) type parseCassandraMetadataTestData struct { + name string metadata map[string]string - isError bool authParams map[string]string + isError bool } type parseCassandraTLSTestData struct { + name string authParams map[string]string isError bool - enableTLS bool + tlsEnabled bool } type cassandraMetricIdentifier struct { + name string metadataTestData *parseCassandraMetadataTestData triggerIndex int - name string + metricName string } var testCassandraMetadata = []parseCassandraMetadataTestData{ - // nothing passed - {map[string]string{}, true, map[string]string{}}, - // everything is passed in verbatim - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "port": "9042", "clusterIPAddress": "cassandra.test", "keyspace": "test_keyspace", "TriggerIndex": "0"}, false, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // metricName is generated from keyspace - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "clusterIPAddress": "cassandra.test:9042", "keyspace": "test_keyspace", "TriggerIndex": "0"}, false, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no query passed - {map[string]string{"targetQueryValue": "1", "username": "cassandra", "clusterIPAddress": "cassandra.test:9042", "keyspace": "test_keyspace", "TriggerIndex": "0"}, true, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no targetQueryValue passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "username": "cassandra", "clusterIPAddress": "cassandra.test:9042", "keyspace": "test_keyspace", "TriggerIndex": "0"}, true, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no username passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "clusterIPAddress": "cassandra.test:9042", "keyspace": "test_keyspace", "TriggerIndex": "0"}, true, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no port passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "clusterIPAddress": "cassandra.test", "keyspace": "test_keyspace", "TriggerIndex": "0"}, true, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no clusterIPAddress passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "port": "9042", "keyspace": "test_keyspace", "TriggerIndex": "0"}, true, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no keyspace passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "clusterIPAddress": "cassandra.test:9042", "TriggerIndex": "0"}, true, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no password passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "clusterIPAddress": "cassandra.test:9042", "keyspace": "test_keyspace", "TriggerIndex": "0"}, true, map[string]string{}}, - // fix issue[4110] passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "port": "9042", "clusterIPAddress": "https://cassandra.test", "keyspace": "test_keyspace", "TriggerIndex": "0"}, false, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, + { + name: "nothing passed", + metadata: map[string]string{}, + authParams: map[string]string{}, + isError: true, + }, + { + name: "everything passed verbatim", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "port": "9042", + "clusterIPAddress": "cassandra.test", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: false, + }, + { + name: "metricName from keyspace", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "clusterIPAddress": "cassandra.test:9042", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: false, + }, + { + name: "no query", + metadata: map[string]string{ + "targetQueryValue": "1", + "username": "cassandra", + "clusterIPAddress": "cassandra.test:9042", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: true, + }, + { + name: "no targetQueryValue", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "username": "cassandra", + "clusterIPAddress": "cassandra.test:9042", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: true, + }, + { + name: "no username", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "clusterIPAddress": "cassandra.test:9042", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: true, + }, + { + name: "no port", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "clusterIPAddress": "cassandra.test", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: true, + }, + { + name: "no clusterIPAddress", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "port": "9042", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: true, + }, + { + name: "no keyspace", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "clusterIPAddress": "cassandra.test:9042", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: true, + }, + { + name: "no password", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "clusterIPAddress": "cassandra.test:9042", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{}, + isError: true, + }, + { + name: "with https prefix", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "port": "9042", + "clusterIPAddress": "https://cassandra.test", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: false, + }, } var tlsAuthParamsTestData = []parseCassandraTLSTestData{ - // success, TLS cert/key - {map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "password": "Y2Fzc2FuZHJhCg=="}, false, true}, - // failure, TLS missing cert - {map[string]string{"tls": "enable", "key": "keey", "password": "Y2Fzc2FuZHJhCg=="}, true, false}, - // failure, TLS missing key - {map[string]string{"tls": "enable", "cert": "ceert", "password": "Y2Fzc2FuZHJhCg=="}, true, false}, - // failure, TLS invalid - {map[string]string{"tls": "yes", "cert": "ceert", "key": "keeey", "password": "Y2Fzc2FuZHJhCg=="}, true, false}, + { + name: "success with cert/key", + authParams: map[string]string{ + "tls": "enable", + "cert": "test-cert", + "key": "test-key", + "password": "Y2Fzc2FuZHJhCg==", + }, + isError: false, + tlsEnabled: true, + }, + { + name: "failure missing cert", + authParams: map[string]string{ + "tls": "enable", + "key": "test-key", + "password": "Y2Fzc2FuZHJhCg==", + }, + isError: true, + tlsEnabled: false, + }, + { + name: "failure missing key", + authParams: map[string]string{ + "tls": "enable", + "cert": "test-cert", + "password": "Y2Fzc2FuZHJhCg==", + }, + isError: true, + tlsEnabled: false, + }, + { + name: "failure invalid tls value", + authParams: map[string]string{ + "tls": "yes", + "cert": "test-cert", + "key": "test-key", + "password": "Y2Fzc2FuZHJhCg==", + }, + isError: true, + tlsEnabled: false, + }, } var cassandraMetricIdentifiers = []cassandraMetricIdentifier{ - {&testCassandraMetadata[1], 0, "s0-cassandra-test_keyspace"}, - {&testCassandraMetadata[2], 1, "s1-cassandra-test_keyspace"}, + { + name: "everything passed verbatim", + metadataTestData: &testCassandraMetadata[1], + triggerIndex: 0, + metricName: "s0-cassandra-test_keyspace", + }, + { + name: "metricName from keyspace", + metadataTestData: &testCassandraMetadata[2], + triggerIndex: 1, + metricName: "s1-cassandra-test_keyspace", + }, +} + +var successMetaData = map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "clusterIPAddress": "cassandra.test:9042", + "keyspace": "test_keyspace", } func TestCassandraParseMetadata(t *testing.T) { - testCaseNum := 1 for _, testData := range testCassandraMetadata { - _, err := parseCassandraMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadata, AuthParams: testData.authParams}) - if err != nil && !testData.isError { - t.Errorf("Expected success but got error for unit test # %v", testCaseNum) - } - if testData.isError && err == nil { - t.Errorf("Expected error but got success for unit test # %v", testCaseNum) - } - testCaseNum++ + t.Run(testData.name, func(t *testing.T) { + _, err := parseCassandraMetadata(&scalersconfig.ScalerConfig{ + TriggerMetadata: testData.metadata, + AuthParams: testData.authParams, + }) + if err != nil && !testData.isError { + t.Error("Expected success but got error", err) + } + if testData.isError && err == nil { + t.Error("Expected error but got success") + } + }) } } func TestCassandraGetMetricSpecForScaling(t *testing.T) { for _, testData := range cassandraMetricIdentifiers { - meta, err := parseCassandraMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, TriggerIndex: testData.triggerIndex, AuthParams: testData.metadataTestData.authParams}) - if err != nil { - t.Fatal("Could not parse metadata:", err) - } - cluster := gocql.NewCluster(meta.clusterIPAddress) - session, _ := cluster.CreateSession() - mockCassandraScaler := cassandraScaler{"", meta, session, logr.Discard()} - - metricSpec := mockCassandraScaler.GetMetricSpecForScaling(context.Background()) - metricName := metricSpec[0].External.Metric.Name - if metricName != testData.name { - t.Errorf("Wrong External metric source name: %s, expected: %s", metricName, testData.name) - } - } -} + t.Run(testData.name, func(t *testing.T) { + meta, err := parseCassandraMetadata(&scalersconfig.ScalerConfig{ + TriggerMetadata: testData.metadataTestData.metadata, + TriggerIndex: testData.triggerIndex, + AuthParams: testData.metadataTestData.authParams, + }) + if err != nil { + t.Fatal("Could not parse metadata:", err) + } + mockCassandraScaler := cassandraScaler{ + metricType: v2.AverageValueMetricType, + metadata: meta, + session: &gocql.Session{}, + logger: logr.Discard(), + } -func assertCertContents(testData parseCassandraTLSTestData, meta *CassandraMetadata, prop string) error { - if testData.authParams[prop] != "" { - var path string - switch prop { - case "cert": - path = meta.cert - case "key": - path = meta.key - } - data, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("expected to find '%v' file at %v", prop, path) - } - contents := string(data) - if contents != testData.authParams[prop] { - return fmt.Errorf("expected value: '%v' but got '%v'", testData.authParams[prop], contents) - } + metricSpec := mockCassandraScaler.GetMetricSpecForScaling(context.Background()) + metricName := metricSpec[0].External.Metric.Name + assert.Equal(t, testData.metricName, metricName) + }) } - return nil } -var successMetaData = map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "clusterIPAddress": "cassandra.test:9042", "keyspace": "test_keyspace", "TriggerIndex": "0"} - func TestParseCassandraTLS(t *testing.T) { for _, testData := range tlsAuthParamsTestData { - meta, err := parseCassandraMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: successMetaData, AuthParams: testData.authParams}) - - if err != nil && !testData.isError { - t.Error("Expected success but got error", err) - } - if testData.isError && err == nil { - t.Error("Expected error but got success") - } - if meta.enableTLS != testData.enableTLS { - t.Errorf("Expected enableTLS to be set to %v but got %v\n", testData.enableTLS, meta.enableTLS) - } - if meta.enableTLS { - if meta.cert != testData.authParams["cert"] { - err := assertCertContents(testData, meta, "cert") - if err != nil { - t.Errorf(err.Error()) - } - } - if meta.key != testData.authParams["key"] { - err := assertCertContents(testData, meta, "key") - if err != nil { - t.Errorf(err.Error()) + t.Run(testData.name, func(t *testing.T) { + meta, err := parseCassandraMetadata(&scalersconfig.ScalerConfig{ + TriggerMetadata: successMetaData, + AuthParams: testData.authParams, + }) + + if testData.isError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, testData.tlsEnabled, meta.TLS == "enable") + + if meta.TLS == "enable" { + // Verify cert contents + if testData.authParams["cert"] != "" { + data, err := os.ReadFile(meta.Cert) + assert.NoError(t, err) + assert.Equal(t, testData.authParams["cert"], string(data)) + // Cleanup + defer os.Remove(meta.Cert) + } + + // Verify key contents + if testData.authParams["key"] != "" { + data, err := os.ReadFile(meta.Key) + assert.NoError(t, err) + assert.Equal(t, testData.authParams["key"], string(data)) + // Cleanup + defer os.Remove(meta.Key) + } + + // Verify CA contents if present + if testData.authParams["ca"] != "" { + data, err := os.ReadFile(meta.CA) + assert.NoError(t, err) + assert.Equal(t, testData.authParams["ca"], string(data)) + // Cleanup + defer os.Remove(meta.CA) + } } } - } + }) } } From 0e7801d834f679bf693dc835dcd441d2865cbff1 Mon Sep 17 00:00:00 2001 From: rickbrouwer Date: Sun, 3 Nov 2024 12:39:03 +0100 Subject: [PATCH 5/7] Refactor mongo scaler (#6261) Signed-off-by: rickbrouwer --- pkg/scalers/mongo_scaler.go | 251 ++++++++++--------------------- pkg/scalers/mongo_scaler_test.go | 15 +- 2 files changed, 87 insertions(+), 179 deletions(-) diff --git a/pkg/scalers/mongo_scaler.go b/pkg/scalers/mongo_scaler.go index f30b8fb97ec..25d2a62ede9 100644 --- a/pkg/scalers/mongo_scaler.go +++ b/pkg/scalers/mongo_scaler.go @@ -6,8 +6,6 @@ import ( "fmt" "net" "net/url" - "strconv" - "strings" "time" "github.com/go-logr/logr" @@ -22,60 +20,45 @@ import ( kedautil "github.com/kedacore/keda/v2/pkg/util" ) -// mongoDBScaler is support for mongoDB in keda. type mongoDBScaler struct { metricType v2.MetricTargetType - metadata *mongoDBMetadata + metadata mongoDBMetadata client *mongo.Client logger logr.Logger } -// mongoDBMetadata specify mongoDB scaler params. type mongoDBMetadata struct { - // The string is used by connected with mongoDB. - // +optional - connectionString string - // Specify the prefix to connect to the mongoDB server, default value `mongodb`, if the connectionString be provided, don't need to specify this param. - // +optional - scheme string - // Specify the host to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param. - // +optional - host string - // Specify the port to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param. - // +optional - port string - // Specify the username to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param. - // +optional - username string - // Specify the password to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param. - // +optional - password string - - // The name of the database to be queried. - // +required - dbName string - // The name of the collection to be queried. - // +required - collection string - // A mongoDB filter doc,used by specify DB. - // +required - query string - // A threshold that is used as targetAverageValue in HPA - // +required - queryValue int64 - // A threshold that is used to check if scaler is active - // +optional - activationQueryValue int64 - - // The index of the scaler inside the ScaledObject - // +internal - triggerIndex int + ConnectionString string `keda:"name=connectionString,order=authParams;triggerMetadata;resolvedEnv,optional"` + Scheme string `keda:"name=scheme,order=authParams;triggerMetadata,default=mongodb,optional"` + Host string `keda:"name=host,order=authParams;triggerMetadata,optional"` + Port string `keda:"name=port,order=authParams;triggerMetadata,optional"` + Username string `keda:"name=username,order=authParams;triggerMetadata,optional"` + Password string `keda:"name=password,order=authParams;triggerMetadata;resolvedEnv,optional"` + DBName string `keda:"name=dbName,order=authParams;triggerMetadata"` + Collection string `keda:"name=collection,order=triggerMetadata"` + Query string `keda:"name=query,order=triggerMetadata"` + QueryValue int64 `keda:"name=queryValue,order=triggerMetadata"` + ActivationQueryValue int64 `keda:"name=activationQueryValue,order=triggerMetadata,default=0"` + TriggerIndex int } -// Default variables and settings -const ( - mongoDBDefaultTimeOut = 10 * time.Second -) +func (m *mongoDBMetadata) Validate() error { + if m.ConnectionString == "" { + if m.Host == "" { + return fmt.Errorf("no host given") + } + if m.Port == "" && m.Scheme != "mongodb+srv" { + return fmt.Errorf("no port given") + } + if m.Username == "" { + return fmt.Errorf("no username given") + } + if m.Password == "" { + return fmt.Errorf("no password given") + } + } + return nil +} // NewMongoDBScaler creates a new mongoDB scaler func NewMongoDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) { @@ -84,22 +67,14 @@ func NewMongoDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) ( return nil, fmt.Errorf("error getting scaler metric type: %w", err) } - ctx, cancel := context.WithTimeout(ctx, mongoDBDefaultTimeOut) - defer cancel() - - meta, connStr, err := parseMongoDBMetadata(config) + meta, err := parseMongoDBMetadata(config) if err != nil { - return nil, fmt.Errorf("failed to parsing mongoDB metadata, because of %w", err) + return nil, fmt.Errorf("error parsing mongodb metadata: %w", err) } - opt := options.Client().ApplyURI(connStr) - client, err := mongo.Connect(ctx, opt) + client, err := createMongoDBClient(ctx, meta) if err != nil { - return nil, fmt.Errorf("failed to establish connection with mongoDB, because of %w", err) - } - - if err = client.Ping(ctx, readpref.Primary()); err != nil { - return nil, fmt.Errorf("failed to ping mongoDB, because of %w", err) + return nil, fmt.Errorf("error creating mongodb client: %w", err) } return &mongoDBScaler{ @@ -110,171 +85,101 @@ func NewMongoDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) ( }, nil } -func parseMongoDBMetadata(config *scalersconfig.ScalerConfig) (*mongoDBMetadata, string, error) { - var connStr string - var err error - // setting default metadata +func parseMongoDBMetadata(config *scalersconfig.ScalerConfig) (mongoDBMetadata, error) { meta := mongoDBMetadata{} - - // parse metaData from ScaledJob config - if val, ok := config.TriggerMetadata["collection"]; ok { - meta.collection = val - } else { - return nil, "", fmt.Errorf("no collection given") + err := config.TypedConfig(&meta) + if err != nil { + return meta, fmt.Errorf("error parsing mongodb metadata: %w", err) } - if val, ok := config.TriggerMetadata["query"]; ok { - meta.query = val - } else { - return nil, "", fmt.Errorf("no query given") - } + meta.TriggerIndex = config.TriggerIndex + return meta, nil +} - if val, ok := config.TriggerMetadata["queryValue"]; ok { - queryValue, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return nil, "", fmt.Errorf("failed to convert %v to int, because of %w", val, err) - } - meta.queryValue = queryValue +func createMongoDBClient(ctx context.Context, meta mongoDBMetadata) (*mongo.Client, error) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + var connString string + if meta.ConnectionString != "" { + connString = meta.ConnectionString } else { - if config.AsMetricSource { - meta.queryValue = 0 - } else { - return nil, "", fmt.Errorf("no queryValue given") + host := meta.Host + if meta.Scheme != "mongodb+srv" { + host = net.JoinHostPort(meta.Host, meta.Port) } - } - - meta.activationQueryValue = 0 - if val, ok := config.TriggerMetadata["activationQueryValue"]; ok { - activationQueryValue, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return nil, "", fmt.Errorf("failed to convert %v to int, because of %w", val, err) + u := &url.URL{ + Scheme: meta.Scheme, + User: url.UserPassword(meta.Username, meta.Password), + Host: host, + Path: meta.DBName, } - meta.activationQueryValue = activationQueryValue + connString = u.String() } - dbName, err := GetFromAuthOrMeta(config, "dbName") + client, err := mongo.Connect(ctx, options.Client().ApplyURI(connString)) if err != nil { - return nil, "", err + return nil, fmt.Errorf("failed to create mongodb client: %w", err) } - meta.dbName = dbName - - // Resolve connectionString - switch { - case config.AuthParams["connectionString"] != "": - meta.connectionString = config.AuthParams["connectionString"] - case config.TriggerMetadata["connectionStringFromEnv"] != "": - meta.connectionString = config.ResolvedEnv[config.TriggerMetadata["connectionStringFromEnv"]] - default: - meta.connectionString = "" - scheme, err := GetFromAuthOrMeta(config, "scheme") - if err != nil { - meta.scheme = "mongodb" - } else { - meta.scheme = scheme - } - - host, err := GetFromAuthOrMeta(config, "host") - if err != nil { - return nil, "", err - } - meta.host = host - - if !strings.Contains(scheme, "mongodb+srv") { - port, err := GetFromAuthOrMeta(config, "port") - if err != nil { - return nil, "", err - } - meta.port = port - } - username, err := GetFromAuthOrMeta(config, "username") - if err != nil { - return nil, "", err - } - meta.username = username - - if config.AuthParams["password"] != "" { - meta.password = config.AuthParams["password"] - } else if config.TriggerMetadata["passwordFromEnv"] != "" { - meta.password = config.ResolvedEnv[config.TriggerMetadata["passwordFromEnv"]] - } - if len(meta.password) == 0 { - return nil, "", fmt.Errorf("no password given") - } - } - - switch { - case meta.connectionString != "": - connStr = meta.connectionString - case meta.scheme == "mongodb+srv": - // nosemgrep: db-connection-string - connStr = fmt.Sprintf("%s://%s:%s@%s/%s", meta.scheme, url.QueryEscape(meta.username), url.QueryEscape(meta.password), meta.host, meta.dbName) - default: - addr := net.JoinHostPort(meta.host, meta.port) - // nosemgrep: db-connection-string - connStr = fmt.Sprintf("%s://%s:%s@%s/%s", meta.scheme, url.QueryEscape(meta.username), url.QueryEscape(meta.password), addr, meta.dbName) + err = client.Ping(ctx, readpref.Primary()) + if err != nil { + return nil, fmt.Errorf("failed to ping mongodb: %w", err) } - meta.triggerIndex = config.TriggerIndex - return &meta, connStr, nil + return client, nil } -// Close disposes of mongoDB connections func (s *mongoDBScaler) Close(ctx context.Context) error { if s.client != nil { err := s.client.Disconnect(ctx) if err != nil { - s.logger.Error(err, fmt.Sprintf("failed to close mongoDB connection, because of %v", err)) + s.logger.Error(err, "Error closing mongodb connection") return err } } - return nil } -// getQueryResult query mongoDB by meta.query func (s *mongoDBScaler) getQueryResult(ctx context.Context) (int64, error) { - ctx, cancel := context.WithTimeout(ctx, mongoDBDefaultTimeOut) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - filter, err := json2BsonDoc(s.metadata.query) + collection := s.client.Database(s.metadata.DBName).Collection(s.metadata.Collection) + + filter, err := json2BsonDoc(s.metadata.Query) if err != nil { - s.logger.Error(err, fmt.Sprintf("failed to convert query param to bson.Doc, because of %v", err)) - return 0, err + return 0, fmt.Errorf("failed to parse query: %w", err) } - docsNum, err := s.client.Database(s.metadata.dbName).Collection(s.metadata.collection).CountDocuments(ctx, filter) + count, err := collection.CountDocuments(ctx, filter) if err != nil { - s.logger.Error(err, fmt.Sprintf("failed to query %v in %v, because of %v", s.metadata.dbName, s.metadata.collection, err)) - return 0, err + return 0, fmt.Errorf("failed to execute query: %w", err) } - return docsNum, nil + return count, nil } -// GetMetricsAndActivity query from mongoDB,and return to external metrics func (s *mongoDBScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { num, err := s.getQueryResult(ctx) if err != nil { - return []external_metrics.ExternalMetricValue{}, false, fmt.Errorf("failed to inspect momgoDB, because of %w", err) + return []external_metrics.ExternalMetricValue{}, false, fmt.Errorf("failed to inspect mongodb: %w", err) } metric := GenerateMetricInMili(metricName, float64(num)) - return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.activationQueryValue, nil + return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.ActivationQueryValue, nil } -// GetMetricSpecForScaling get the query value for scaling func (s *mongoDBScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { + metricName := kedautil.NormalizeString(fmt.Sprintf("mongodb-%s", s.metadata.Collection)) externalMetric := &v2.ExternalMetricSource{ Metric: v2.MetricIdentifier{ - Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("mongodb-%s", s.metadata.collection))), + Name: GenerateMetricNameWithIndex(s.metadata.TriggerIndex, metricName), }, - Target: GetMetricTarget(s.metricType, s.metadata.queryValue), - } - metricSpec := v2.MetricSpec{ - External: externalMetric, Type: externalMetricType, + Target: GetMetricTarget(s.metricType, s.metadata.QueryValue), } + metricSpec := v2.MetricSpec{External: externalMetric, Type: externalMetricType} return []v2.MetricSpec{metricSpec} } diff --git a/pkg/scalers/mongo_scaler_test.go b/pkg/scalers/mongo_scaler_test.go index fd9f54f8337..c749b9f7ae4 100644 --- a/pkg/scalers/mongo_scaler_test.go +++ b/pkg/scalers/mongo_scaler_test.go @@ -5,8 +5,8 @@ import ( "testing" "github.com/go-logr/logr" - "github.com/stretchr/testify/assert" "go.mongodb.org/mongo-driver/mongo" + v2 "k8s.io/api/autoscaling/v2" "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" ) @@ -100,7 +100,7 @@ var mongoDBMetricIdentifiers = []mongoDBMetricIdentifier{ func TestParseMongoDBMetadata(t *testing.T) { for _, testData := range testMONGODBMetadata { - _, _, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.resolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams}) + _, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.resolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams}) if err != nil && !testData.raisesError { t.Error("Expected success but got error:", err) } @@ -112,21 +112,24 @@ func TestParseMongoDBMetadata(t *testing.T) { func TestParseMongoDBConnectionString(t *testing.T) { for _, testData := range mongoDBConnectionStringTestDatas { - _, connStr, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, TriggerMetadata: testData.metadataTestData.metadata, AuthParams: testData.metadataTestData.authParams}) + _, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ + ResolvedEnv: testData.metadataTestData.resolvedEnv, + TriggerMetadata: testData.metadataTestData.metadata, + AuthParams: testData.metadataTestData.authParams, + }) if err != nil { t.Error("Expected success but got error:", err) } - assert.Equal(t, testData.connectionString, connStr) } } func TestMongoDBGetMetricSpecForScaling(t *testing.T) { for _, testData := range mongoDBMetricIdentifiers { - meta, _, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, AuthParams: testData.metadataTestData.authParams, TriggerMetadata: testData.metadataTestData.metadata, TriggerIndex: testData.triggerIndex}) + meta, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, AuthParams: testData.metadataTestData.authParams, TriggerMetadata: testData.metadataTestData.metadata, TriggerIndex: testData.triggerIndex}) if err != nil { t.Fatal("Could not parse metadata:", err) } - mockMongoDBScaler := mongoDBScaler{"", meta, &mongo.Client{}, logr.Discard()} + mockMongoDBScaler := mongoDBScaler{metricType: v2.AverageValueMetricType, metadata: meta, client: &mongo.Client{}, logger: logr.Discard()} metricSpec := mockMongoDBScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name From 977cab863c1ec50430f3855c9ab552a0631bde41 Mon Sep 17 00:00:00 2001 From: Rushen Wang <45029442+dovics@users.noreply.github.com> Date: Sun, 3 Nov 2024 19:40:14 +0800 Subject: [PATCH 6/7] Refactor rabbitmq scaler config (#6257) Signed-off-by: wangrushen --- pkg/scalers/rabbitmq_scaler.go | 536 ++++++++++------------------ pkg/scalers/rabbitmq_scaler_test.go | 64 ++-- 2 files changed, 221 insertions(+), 379 deletions(-) diff --git a/pkg/scalers/rabbitmq_scaler.go b/pkg/scalers/rabbitmq_scaler.go index 860228d3a99..f7ed365f52d 100644 --- a/pkg/scalers/rabbitmq_scaler.go +++ b/pkg/scalers/rabbitmq_scaler.go @@ -8,9 +8,7 @@ import ( "net/http" "net/url" "path" - "reflect" "regexp" - "strconv" "strings" "time" @@ -36,12 +34,14 @@ const ( rabbitModeTriggerConfigName = "mode" rabbitValueTriggerConfigName = "value" rabbitActivationValueTriggerConfigName = "activationValue" + rabbitModeUnknown = "Unknown" rabbitModeQueueLength = "QueueLength" rabbitModeMessageRate = "MessageRate" defaultRabbitMQQueueLength = 20 rabbitMetricType = "External" rabbitRootVhostPath = "/%2F" rmqTLSEnable = "enable" + rmqTLSDisable = "disable" ) const ( @@ -69,37 +69,155 @@ type rabbitMQScaler struct { } type rabbitMQMetadata struct { - queueName string - connectionName string // name used for the AMQP connection - mode string // QueueLength or MessageRate - value float64 // trigger value (queue length or publish/sec. rate) - activationValue float64 // activation value - host string // connection string for either HTTP or AMQP protocol - protocol string // either http or amqp protocol - vhostName string // override the vhost from the connection info - useRegex bool // specify if the queueName contains a rexeg - excludeUnacknowledged bool // specify if the QueueLength value should exclude Unacknowledged messages (Ready messages only) - pageSize int64 // specify the page size if useRegex is enabled - operation string // specify the operation to apply in case of multiples queues - timeout time.Duration // custom http timeout for a specific trigger - triggerIndex int // scaler index - - username string - password string + connectionName string // name used for the AMQP connection + triggerIndex int // scaler index + + QueueName string `keda:"name=queueName, order=triggerMetadata"` + // QueueLength or MessageRate + Mode string `keda:"name=mode, order=triggerMetadata, optional, default=Unknown"` + // + QueueLength float64 `keda:"name=queueLength, order=triggerMetadata, optional"` + // trigger value (queue length or publish/sec. rate) + Value float64 `keda:"name=value, order=triggerMetadata, optional"` + // activation value + ActivationValue float64 `keda:"name=activationValue, order=triggerMetadata, optional"` + // connection string for either HTTP or AMQP protocol + Host string `keda:"name=host, order=triggerMetadata;authParams;resolvedEnv"` + // either http or amqp protocol + Protocol string `keda:"name=protocol, order=triggerMetadata;authParams, optional, default=auto"` + // override the vhost from the connection info + VhostName string `keda:"name=vhostName, order=triggerMetadata, optional"` + // specify if the queueName contains a rexeg + UseRegex bool `keda:"name=useRegex, order=triggerMetadata, optional"` + // specify if the QueueLength value should exclude Unacknowledged messages (Ready messages only) + ExcludeUnacknowledged bool `keda:"name=excludeUnacknowledged, order=triggerMetadata, optional"` + // specify the page size if useRegex is enabled + PageSize int64 `keda:"name=pageSize, order=triggerMetadata, optional, default=100"` + // specify the operation to apply in case of multiples queues + Operation string `keda:"name=operation, order=triggerMetadata, optional, default=sum"` + // custom http timeout for a specific trigger + TimeoutMs int `keda:"name=timeout, order=triggerMetadata, optional"` + + Username string `keda:"name=username, order=authParams;resolvedEnv, optional"` + Password string `keda:"name=password, order=authParams;resolvedEnv, optional"` // TLS - ca string - cert string - key string - keyPassword string - enableTLS bool - unsafeSsl bool + Ca string `keda:"name=ca, order=authParams, optional"` + Cert string `keda:"name=cert, order=authParams, optional"` + Key string `keda:"name=key, order=authParams, optional"` + KeyPassword string `keda:"name=keyPassword, order=authParams, optional"` + EnableTLS string `keda:"name=tls, order=authParams, optional, default=disable"` + UnsafeSsl bool `keda:"name=unsafeSsl, order=triggerMetadata, optional"` // token provider for azure AD + WorkloadIdentityResource string `keda:"name=workloadIdentityResource, order=authParams, optional"` workloadIdentityClientID string workloadIdentityTenantID string workloadIdentityAuthorityHost string - workloadIdentityResource string +} + +func (r *rabbitMQMetadata) Validate() error { + if r.Protocol != amqpProtocol && r.Protocol != httpProtocol && r.Protocol != autoProtocol { + return fmt.Errorf("the protocol has to be either `%s`, `%s`, or `%s` but is `%s`", + amqpProtocol, httpProtocol, autoProtocol, r.Protocol) + } + + if r.EnableTLS != rmqTLSEnable && r.EnableTLS != rmqTLSDisable { + return fmt.Errorf("err incorrect value for TLS given: %s", r.EnableTLS) + } + + certGiven := r.Cert != "" + keyGiven := r.Key != "" + if certGiven != keyGiven { + return fmt.Errorf("both key and cert must be provided") + } + + if r.PageSize < 1 { + return fmt.Errorf("pageSize should be 1 or greater than 1") + } + + if (r.Username != "" || r.Password != "") && (r.Username == "" || r.Password == "") { + return fmt.Errorf("username and password must be given together") + } + + // If the protocol is auto, check the host scheme. + if r.Protocol == autoProtocol { + parsedURL, err := url.Parse(r.Host) + if err != nil { + return fmt.Errorf("can't parse host to find protocol: %w", err) + } + switch parsedURL.Scheme { + case "amqp", "amqps": + r.Protocol = amqpProtocol + case "http", "https": + r.Protocol = httpProtocol + default: + return fmt.Errorf("unknown host URL scheme `%s`", parsedURL.Scheme) + } + } + + if r.Protocol == amqpProtocol && r.WorkloadIdentityResource != "" { + return fmt.Errorf("workload identity is not supported for amqp protocol currently") + } + + if r.UseRegex && r.Protocol != httpProtocol { + return fmt.Errorf("configure only useRegex with http protocol") + } + + if r.ExcludeUnacknowledged && r.Protocol != httpProtocol { + return fmt.Errorf("configure excludeUnacknowledged=true with http protocol only") + } + + if err := r.validateTrigger(); err != nil { + return err + } + + return nil +} + +func (r *rabbitMQMetadata) validateTrigger() error { + // If nothing is specified for the trigger then return the default + if r.QueueLength == 0 && r.Mode == rabbitModeUnknown && r.Value == 0 { + r.Mode = rabbitModeQueueLength + r.Value = defaultRabbitMQQueueLength + return nil + } + + if r.QueueLength != 0 && (r.Mode != rabbitModeUnknown || r.Value != 0) { + return fmt.Errorf("queueLength is deprecated; configure only %s and %s", rabbitModeTriggerConfigName, rabbitValueTriggerConfigName) + } + + if r.QueueLength != 0 { + r.Mode = rabbitModeQueueLength + r.Value = r.QueueLength + + return nil + } + + if r.Mode == rabbitModeUnknown { + return fmt.Errorf("%s must be specified", rabbitModeTriggerConfigName) + } + + if r.Value == 0 { + return fmt.Errorf("%s must be specified", rabbitValueTriggerConfigName) + } + + if r.Mode != rabbitModeQueueLength && r.Mode != rabbitModeMessageRate { + return fmt.Errorf("trigger mode %s must be one of %s, %s", r.Mode, rabbitModeQueueLength, rabbitModeMessageRate) + } + + if r.Mode == rabbitModeMessageRate && r.Protocol != httpProtocol { + return fmt.Errorf("protocol %s not supported; must be http to use mode %s", r.Protocol, rabbitModeMessageRate) + } + + if r.Protocol == amqpProtocol && r.TimeoutMs != 0 { + return fmt.Errorf("amqp protocol doesn't support custom timeouts: %d", r.TimeoutMs) + } + + if r.TimeoutMs < 0 { + return fmt.Errorf("timeout must be greater than 0: %d", r.TimeoutMs) + } + return nil } type queueInfo struct { @@ -139,32 +257,40 @@ func NewRabbitMQScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { if err != nil { return nil, fmt.Errorf("error parsing rabbitmq metadata: %w", err) } + s.metadata = meta - s.httpClient = kedautil.CreateHTTPClient(meta.timeout, meta.unsafeSsl) - if meta.enableTLS { - tlsConfig, tlsErr := kedautil.NewTLSConfigWithPassword(meta.cert, meta.key, meta.keyPassword, meta.ca, meta.unsafeSsl) + var timeout time.Duration + if s.metadata.TimeoutMs != 0 { + timeout = time.Duration(s.metadata.TimeoutMs) * time.Millisecond + } else { + timeout = config.GlobalHTTPTimeout + } + + s.httpClient = kedautil.CreateHTTPClient(timeout, meta.UnsafeSsl) + if meta.EnableTLS == rmqTLSEnable { + tlsConfig, tlsErr := kedautil.NewTLSConfigWithPassword(meta.Cert, meta.Key, meta.KeyPassword, meta.Ca, meta.UnsafeSsl) if tlsErr != nil { return nil, tlsErr } s.httpClient.Transport = kedautil.CreateHTTPTransportWithTLSConfig(tlsConfig) } - if meta.protocol == amqpProtocol { + if meta.Protocol == amqpProtocol { // Override vhost if requested. - host := meta.host - if meta.vhostName != "" || (meta.username != "" && meta.password != "") { + host := meta.Host + if meta.VhostName != "" || (meta.Username != "" && meta.Password != "") { hostURI, err := amqp.ParseURI(host) if err != nil { return nil, fmt.Errorf("error parsing rabbitmq connection string: %w", err) } - if meta.vhostName != "" { - hostURI.Vhost = meta.vhostName + if meta.VhostName != "" { + hostURI.Vhost = meta.VhostName } - if meta.username != "" && meta.password != "" { - hostURI.Username = meta.username - hostURI.Password = meta.password + if meta.Username != "" && meta.Password != "" { + hostURI.Username = meta.Username + hostURI.Password = meta.Password } host = hostURI.String() @@ -181,308 +307,24 @@ func NewRabbitMQScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { return s, nil } -func resolveProtocol(config *scalersconfig.ScalerConfig, meta *rabbitMQMetadata) error { - meta.protocol = defaultProtocol - if val, ok := config.AuthParams["protocol"]; ok { - meta.protocol = val - } - if val, ok := config.TriggerMetadata["protocol"]; ok { - meta.protocol = val - } - if meta.protocol != amqpProtocol && meta.protocol != httpProtocol && meta.protocol != autoProtocol { - return fmt.Errorf("the protocol has to be either `%s`, `%s`, or `%s` but is `%s`", amqpProtocol, httpProtocol, autoProtocol, meta.protocol) - } - return nil -} - -func resolveHostValue(config *scalersconfig.ScalerConfig, meta *rabbitMQMetadata) error { - switch { - case config.AuthParams["host"] != "": - meta.host = config.AuthParams["host"] - case config.TriggerMetadata["host"] != "": - meta.host = config.TriggerMetadata["host"] - case config.TriggerMetadata["hostFromEnv"] != "": - meta.host = config.ResolvedEnv[config.TriggerMetadata["hostFromEnv"]] - default: - return fmt.Errorf("no host setting given") - } - return nil -} - -func resolveTimeout(config *scalersconfig.ScalerConfig, meta *rabbitMQMetadata) error { - if val, ok := config.TriggerMetadata["timeout"]; ok { - timeoutMS, err := strconv.Atoi(val) - if err != nil { - return fmt.Errorf("unable to parse timeout: %w", err) - } - if meta.protocol == amqpProtocol { - return fmt.Errorf("amqp protocol doesn't support custom timeouts: %w", err) - } - if timeoutMS <= 0 { - return fmt.Errorf("timeout must be greater than 0: %w", err) - } - meta.timeout = time.Duration(timeoutMS) * time.Millisecond - } else { - meta.timeout = config.GlobalHTTPTimeout - } - return nil -} - -func resolveTLSAuthParams(config *scalersconfig.ScalerConfig, meta *rabbitMQMetadata) error { - meta.enableTLS = false - if val, ok := config.AuthParams["tls"]; ok { - val = strings.TrimSpace(val) - if val == rmqTLSEnable { - meta.ca = config.AuthParams["ca"] - meta.cert = config.AuthParams["cert"] - meta.key = config.AuthParams["key"] - meta.enableTLS = true - } else if val != "disable" { - return fmt.Errorf("err incorrect value for TLS given: %s", val) - } - } - return nil -} - -func resolveAuth(config *scalersconfig.ScalerConfig, meta *rabbitMQMetadata) error { - usernameVal, err := getParameterFromConfigV2(config, "username", reflect.TypeOf(meta.username), - UseAuthentication(true), UseResolvedEnv(true), IsOptional(true)) - if err != nil { - return err - } - meta.username = usernameVal.(string) - - passwordVal, err := getParameterFromConfigV2(config, "password", reflect.TypeOf(meta.username), - UseAuthentication(true), UseResolvedEnv(true), IsOptional(true)) - if err != nil { - return err - } - meta.password = passwordVal.(string) - - if (meta.username != "" || meta.password != "") && (meta.username == "" || meta.password == "") { - return fmt.Errorf("username and password must be given together") - } - - return nil -} - func parseRabbitMQMetadata(config *scalersconfig.ScalerConfig) (*rabbitMQMetadata, error) { - meta := rabbitMQMetadata{ + meta := &rabbitMQMetadata{ connectionName: connectionName(config), } - // Resolve protocol type - if err := resolveProtocol(config, &meta); err != nil { - return nil, err - } - - // Resolve host value - if err := resolveHostValue(config, &meta); err != nil { - return nil, err - } - - // Resolve TLS authentication parameters - if err := resolveTLSAuthParams(config, &meta); err != nil { - return nil, err - } - - // Resolve username and password - if err := resolveAuth(config, &meta); err != nil { - return nil, err + if err := config.TypedConfig(meta); err != nil { + return nil, fmt.Errorf("error parsing rabbitmq metadata: %w", err) } - meta.keyPassword = config.AuthParams["keyPassword"] - if config.PodIdentity.Provider == v1alpha1.PodIdentityProviderAzureWorkload { - if config.AuthParams["workloadIdentityResource"] != "" { + if meta.WorkloadIdentityResource != "" { meta.workloadIdentityClientID = config.PodIdentity.GetIdentityID() meta.workloadIdentityTenantID = config.PodIdentity.GetIdentityTenantID() - meta.workloadIdentityResource = config.AuthParams["workloadIdentityResource"] - } - } - - certGiven := meta.cert != "" - keyGiven := meta.key != "" - if certGiven != keyGiven { - return nil, fmt.Errorf("both key and cert must be provided") - } - - meta.unsafeSsl = false - if val, ok := config.TriggerMetadata["unsafeSsl"]; ok { - boolVal, err := strconv.ParseBool(val) - if err != nil { - return nil, fmt.Errorf("failed to parse unsafeSsl value. Must be either true or false") - } - meta.unsafeSsl = boolVal - } - - // If the protocol is auto, check the host scheme. - if meta.protocol == autoProtocol { - parsedURL, err := url.Parse(meta.host) - if err != nil { - return nil, fmt.Errorf("can't parse host to find protocol: %w", err) - } - switch parsedURL.Scheme { - case "amqp", "amqps": - meta.protocol = amqpProtocol - case "http", "https": - meta.protocol = httpProtocol - default: - return nil, fmt.Errorf("unknown host URL scheme `%s`", parsedURL.Scheme) } } - if meta.protocol == amqpProtocol && config.AuthParams["workloadIdentityResource"] != "" { - return nil, fmt.Errorf("workload identity is not supported for amqp protocol currently") - } - - // Resolve queueName - if val, ok := config.TriggerMetadata["queueName"]; ok { - meta.queueName = val - } else { - return nil, fmt.Errorf("no queue name given") - } - - // Resolve vhostName - if val, ok := config.TriggerMetadata["vhostName"]; ok { - meta.vhostName = val - } - - err := parseRabbitMQHttpProtocolMetadata(config, &meta) - if err != nil { - return nil, err - } - - if meta.useRegex && meta.protocol != httpProtocol { - return nil, fmt.Errorf("configure only useRegex with http protocol") - } - - if meta.excludeUnacknowledged && meta.protocol != httpProtocol { - return nil, fmt.Errorf("configure excludeUnacknowledged=true with http protocol only") - } - - _, err = parseTrigger(&meta, config) - if err != nil { - return nil, fmt.Errorf("unable to parse trigger: %w", err) - } - // Resolve timeout - if err := resolveTimeout(config, &meta); err != nil { - return nil, err - } meta.triggerIndex = config.TriggerIndex - return &meta, nil -} - -func parseRabbitMQHttpProtocolMetadata(config *scalersconfig.ScalerConfig, meta *rabbitMQMetadata) error { - // Resolve useRegex - if val, ok := config.TriggerMetadata["useRegex"]; ok { - useRegex, err := strconv.ParseBool(val) - if err != nil { - return fmt.Errorf("useRegex has invalid value") - } - meta.useRegex = useRegex - } - - // Resolve excludeUnacknowledged - if val, ok := config.TriggerMetadata["excludeUnacknowledged"]; ok { - excludeUnacknowledged, err := strconv.ParseBool(val) - if err != nil { - return fmt.Errorf("excludeUnacknowledged has invalid value") - } - meta.excludeUnacknowledged = excludeUnacknowledged - } - - // Resolve pageSize - if val, ok := config.TriggerMetadata["pageSize"]; ok { - pageSize, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return fmt.Errorf("pageSize has invalid value") - } - meta.pageSize = pageSize - if meta.pageSize < 1 { - return fmt.Errorf("pageSize should be 1 or greater than 1") - } - } else { - meta.pageSize = 100 - } - - // Resolve operation - meta.operation = defaultOperation - if val, ok := config.TriggerMetadata["operation"]; ok { - meta.operation = val - } - - return nil -} - -func parseTrigger(meta *rabbitMQMetadata, config *scalersconfig.ScalerConfig) (*rabbitMQMetadata, error) { - deprecatedQueueLengthValue, deprecatedQueueLengthPresent := config.TriggerMetadata[rabbitQueueLengthMetricName] - mode, modePresent := config.TriggerMetadata[rabbitModeTriggerConfigName] - value, valuePresent := config.TriggerMetadata[rabbitValueTriggerConfigName] - activationValue, activationValuePresent := config.TriggerMetadata[rabbitActivationValueTriggerConfigName] - - // Initialize to default trigger settings - meta.mode = rabbitModeQueueLength - meta.value = defaultRabbitMQQueueLength - - // If nothing is specified for the trigger then return the default - if !deprecatedQueueLengthPresent && !modePresent && !valuePresent { - return meta, nil - } - - // Only allow one of `queueLength` or `mode`/`value` - if deprecatedQueueLengthPresent && (modePresent || valuePresent) { - return nil, fmt.Errorf("queueLength is deprecated; configure only %s and %s", rabbitModeTriggerConfigName, rabbitValueTriggerConfigName) - } - - // Parse activation value - if activationValuePresent { - activation, err := strconv.ParseFloat(activationValue, 64) - if err != nil { - return nil, fmt.Errorf("can't parse %s: %w", rabbitActivationValueTriggerConfigName, err) - } - meta.activationValue = activation - } - - // Parse deprecated `queueLength` value - if deprecatedQueueLengthPresent { - queueLength, err := strconv.ParseFloat(deprecatedQueueLengthValue, 64) - if err != nil { - return nil, fmt.Errorf("can't parse %s: %w", rabbitQueueLengthMetricName, err) - } - meta.mode = rabbitModeQueueLength - meta.value = queueLength - - return meta, nil - } - - if !modePresent { - return nil, fmt.Errorf("%s must be specified", rabbitModeTriggerConfigName) - } - if !valuePresent { - return nil, fmt.Errorf("%s must be specified", rabbitValueTriggerConfigName) - } - - // Resolve trigger mode - switch mode { - case rabbitModeQueueLength: - meta.mode = rabbitModeQueueLength - case rabbitModeMessageRate: - meta.mode = rabbitModeMessageRate - default: - return nil, fmt.Errorf("trigger mode %s must be one of %s, %s", mode, rabbitModeQueueLength, rabbitModeMessageRate) - } - triggerValue, err := strconv.ParseFloat(value, 64) - if err != nil { - return nil, fmt.Errorf("can't parse %s: %w", rabbitValueTriggerConfigName, err) - } - meta.value = triggerValue - - if meta.mode == rabbitModeMessageRate && meta.protocol != httpProtocol { - return nil, fmt.Errorf("protocol %s not supported; must be http to use mode %s", meta.protocol, rabbitModeMessageRate) - } - return meta, nil } @@ -496,8 +338,8 @@ func getConnectionAndChannel(host string, meta *rabbitMQMetadata) (*amqp.Connect }, } - if meta.enableTLS { - tlsConfig, err := kedautil.NewTLSConfigWithPassword(meta.cert, meta.key, meta.keyPassword, meta.ca, meta.unsafeSsl) + if meta.EnableTLS == rmqTLSEnable { + tlsConfig, err := kedautil.NewTLSConfigWithPassword(meta.Cert, meta.Key, meta.KeyPassword, meta.Ca, meta.UnsafeSsl) if err != nil { return nil, nil, err } @@ -534,13 +376,13 @@ func (s *rabbitMQScaler) Close(context.Context) error { } func (s *rabbitMQScaler) getQueueStatus(ctx context.Context) (int64, float64, error) { - if s.metadata.protocol == httpProtocol { + if s.metadata.Protocol == httpProtocol { info, err := s.getQueueInfoViaHTTP(ctx) if err != nil { return -1, -1, err } - if s.metadata.excludeUnacknowledged { + if s.metadata.ExcludeUnacknowledged { // messages count includes only ready return int64(info.MessagesReady), info.MessageStat.PublishDetail.Rate, nil } @@ -549,7 +391,7 @@ func (s *rabbitMQScaler) getQueueStatus(ctx context.Context) (int64, float64, er } // QueueDeclarePassive assumes that the queue exists and fails if it doesn't - items, err := s.channel.QueueDeclarePassive(s.metadata.queueName, false, false, false, false, amqp.Table{}) + items, err := s.channel.QueueDeclarePassive(s.metadata.QueueName, false, false, false, false, amqp.Table{}) if err != nil { return -1, -1, err } @@ -565,9 +407,9 @@ func getJSON(ctx context.Context, s *rabbitMQScaler, url string) (queueInfo, err return result, err } - if s.metadata.workloadIdentityResource != "" { + if s.metadata.WorkloadIdentityResource != "" { if s.azureOAuth == nil { - s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.workloadIdentityClientID, s.metadata.workloadIdentityTenantID, s.metadata.workloadIdentityAuthorityHost, s.metadata.workloadIdentityResource) + s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.workloadIdentityClientID, s.metadata.workloadIdentityTenantID, s.metadata.workloadIdentityAuthorityHost, s.metadata.WorkloadIdentityResource) } err = s.azureOAuth.Refresh() @@ -586,7 +428,7 @@ func getJSON(ctx context.Context, s *rabbitMQScaler, url string) (queueInfo, err defer r.Body.Close() if r.StatusCode == 200 { - if s.metadata.useRegex { + if s.metadata.UseRegex { var queues regexQueueInfo err = json.NewDecoder(r.Body).Decode(&queues) if err != nil { @@ -626,24 +468,24 @@ func getVhostAndPathFromURL(rawPath, vhostName string) (resolvedVhostPath, resol } func (s *rabbitMQScaler) getQueueInfoViaHTTP(ctx context.Context) (*queueInfo, error) { - parsedURL, err := url.Parse(s.metadata.host) + parsedURL, err := url.Parse(s.metadata.Host) if err != nil { return nil, err } - vhost, subpaths := getVhostAndPathFromURL(parsedURL.Path, s.metadata.vhostName) + vhost, subpaths := getVhostAndPathFromURL(parsedURL.Path, s.metadata.VhostName) parsedURL.Path = subpaths - if s.metadata.username != "" && s.metadata.password != "" { - parsedURL.User = url.UserPassword(s.metadata.username, s.metadata.password) + if s.metadata.Username != "" && s.metadata.Password != "" { + parsedURL.User = url.UserPassword(s.metadata.Username, s.metadata.Password) } var getQueueInfoManagementURI string - if s.metadata.useRegex { - getQueueInfoManagementURI = fmt.Sprintf("%s/api/queues%s?page=1&use_regex=true&pagination=false&name=%s&page_size=%d", parsedURL.String(), vhost, url.QueryEscape(s.metadata.queueName), s.metadata.pageSize) + if s.metadata.UseRegex { + getQueueInfoManagementURI = fmt.Sprintf("%s/api/queues%s?page=1&use_regex=true&pagination=false&name=%s&page_size=%d", parsedURL.String(), vhost, url.QueryEscape(s.metadata.QueueName), s.metadata.PageSize) } else { - getQueueInfoManagementURI = fmt.Sprintf("%s/api/queues%s/%s", parsedURL.String(), vhost, url.QueryEscape(s.metadata.queueName)) + getQueueInfoManagementURI = fmt.Sprintf("%s/api/queues%s/%s", parsedURL.String(), vhost, url.QueryEscape(s.metadata.QueueName)) } var info queueInfo @@ -660,9 +502,9 @@ func (s *rabbitMQScaler) getQueueInfoViaHTTP(ctx context.Context) (*queueInfo, e func (s *rabbitMQScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { externalMetric := &v2.ExternalMetricSource{ Metric: v2.MetricIdentifier{ - Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("rabbitmq-%s", url.QueryEscape(s.metadata.queueName)))), + Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("rabbitmq-%s", url.QueryEscape(s.metadata.QueueName)))), }, - Target: GetMetricTargetMili(s.metricType, s.metadata.value), + Target: GetMetricTargetMili(s.metricType, s.metadata.Value), } metricSpec := v2.MetricSpec{ External: externalMetric, Type: rabbitMetricType, @@ -680,12 +522,12 @@ func (s *rabbitMQScaler) GetMetricsAndActivity(ctx context.Context, metricName s var metric external_metrics.ExternalMetricValue var isActive bool - if s.metadata.mode == rabbitModeQueueLength { + if s.metadata.Mode == rabbitModeQueueLength { metric = GenerateMetricInMili(metricName, float64(messages)) - isActive = float64(messages) > s.metadata.activationValue + isActive = float64(messages) > s.metadata.ActivationValue } else { metric = GenerateMetricInMili(metricName, publishRate) - isActive = publishRate > s.metadata.activationValue || float64(messages) > s.metadata.activationValue + isActive = publishRate > s.metadata.ActivationValue || float64(messages) > s.metadata.ActivationValue } return []external_metrics.ExternalMetricValue{metric}, isActive, nil @@ -696,7 +538,7 @@ func getComposedQueue(s *rabbitMQScaler, q []queueInfo) (queueInfo, error) { queue.Name = "composed-queue" queue.MessagesUnacknowledged = 0 if len(q) > 0 { - switch s.metadata.operation { + switch s.metadata.Operation { case sumOperation: sumMessages, sumReady, sumRate := getSum(q) queue.Messages = sumMessages @@ -713,7 +555,7 @@ func getComposedQueue(s *rabbitMQScaler, q []queueInfo) (queueInfo, error) { queue.MessagesReady = maxReady queue.MessageStat.PublishDetail.Rate = maxRate default: - return queue, fmt.Errorf("operation mode %s must be one of %s, %s, %s", s.metadata.operation, sumOperation, avgOperation, maxOperation) + return queue, fmt.Errorf("operation mode %s must be one of %s, %s, %s", s.metadata.Operation, sumOperation, avgOperation, maxOperation) } } else { queue.Messages = 0 diff --git a/pkg/scalers/rabbitmq_scaler_test.go b/pkg/scalers/rabbitmq_scaler_test.go index dd9c3f900b8..ed1785e5be3 100644 --- a/pkg/scalers/rabbitmq_scaler_test.go +++ b/pkg/scalers/rabbitmq_scaler_test.go @@ -34,7 +34,7 @@ type parseRabbitMQAuthParamTestData struct { podIdentity v1alpha1.AuthPodIdentity authParams map[string]string isError bool - enableTLS bool + enableTLS string workloadIdentity bool } @@ -142,35 +142,35 @@ var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{ } var testRabbitMQAuthParamData = []parseRabbitMQAuthParamTestData{ - {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert", "key": "keey"}, false, true, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert", "key": "keey"}, false, rmqTLSEnable, false}, // success, TLS cert/key and assumed public CA - {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey"}, false, true, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey"}, false, rmqTLSEnable, false}, // success, TLS cert/key + key password and assumed public CA - {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "keyPassword": "keeyPassword"}, false, true, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "keyPassword": "keeyPassword"}, false, rmqTLSEnable, false}, // success, TLS CA only - {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa"}, false, true, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa"}, false, rmqTLSEnable, false}, // failure, TLS missing cert - {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "key": "kee"}, true, true, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "key": "kee"}, true, rmqTLSEnable, false}, // failure, TLS missing key - {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert"}, true, true, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert"}, true, rmqTLSEnable, false}, // failure, TLS invalid - {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "yes", "ca": "caaa", "cert": "ceert", "key": "kee"}, true, true, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "yes", "ca": "caaa", "cert": "ceert", "key": "kee"}, true, rmqTLSEnable, false}, // success, username and password - {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"username": "user", "password": "PASSWORD"}, false, false, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"username": "user", "password": "PASSWORD"}, false, rmqTLSDisable, false}, // failure, username but no password - {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"username": "user"}, true, false, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"username": "user"}, true, rmqTLSDisable, false}, // failure, password but no username - {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"password": "PASSWORD"}, true, false, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"password": "PASSWORD"}, true, rmqTLSDisable, false}, // success, username and password from env - {map[string]string{"queueName": "sample", "hostFromEnv": host, "usernameFromEnv": rabbitMQUsername, "passwordFromEnv": rabbitMQPassword}, v1alpha1.AuthPodIdentity{}, map[string]string{}, false, false, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host, "usernameFromEnv": rabbitMQUsername, "passwordFromEnv": rabbitMQPassword}, v1alpha1.AuthPodIdentity{}, map[string]string{}, false, rmqTLSDisable, false}, // failure, username from env but not password - {map[string]string{"queueName": "sample", "hostFromEnv": host, "usernameFromEnv": rabbitMQUsername}, v1alpha1.AuthPodIdentity{}, map[string]string{}, true, false, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host, "usernameFromEnv": rabbitMQUsername}, v1alpha1.AuthPodIdentity{}, map[string]string{}, true, rmqTLSDisable, false}, // failure, password from env but not username - {map[string]string{"queueName": "sample", "hostFromEnv": host, "passwordFromEnv": rabbitMQPassword}, v1alpha1.AuthPodIdentity{}, map[string]string{}, true, false, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host, "passwordFromEnv": rabbitMQPassword}, v1alpha1.AuthPodIdentity{}, map[string]string{}, true, rmqTLSDisable, false}, // success, WorkloadIdentity - {map[string]string{"queueName": "sample", "hostFromEnv": host, "protocol": "http"}, v1alpha1.AuthPodIdentity{Provider: v1alpha1.PodIdentityProviderAzureWorkload, IdentityID: kedautil.StringPointer("client-id")}, map[string]string{"workloadIdentityResource": "rabbitmq-resource-id"}, false, false, true}, + {map[string]string{"queueName": "sample", "hostFromEnv": host, "protocol": "http"}, v1alpha1.AuthPodIdentity{Provider: v1alpha1.PodIdentityProviderAzureWorkload, IdentityID: kedautil.StringPointer("client-id")}, map[string]string{"workloadIdentityResource": "rabbitmq-resource-id"}, false, rmqTLSDisable, true}, // failure, WoekloadIdentity not supported for amqp - {map[string]string{"queueName": "sample", "hostFromEnv": host, "protocol": "amqp"}, v1alpha1.AuthPodIdentity{Provider: v1alpha1.PodIdentityProviderAzureWorkload, IdentityID: kedautil.StringPointer("client-id")}, map[string]string{"workloadIdentityResource": "rabbitmq-resource-id"}, true, false, false}, + {map[string]string{"queueName": "sample", "hostFromEnv": host, "protocol": "amqp"}, v1alpha1.AuthPodIdentity{Provider: v1alpha1.PodIdentityProviderAzureWorkload, IdentityID: kedautil.StringPointer("client-id")}, map[string]string{"workloadIdentityResource": "rabbitmq-resource-id"}, true, rmqTLSDisable, false}, } var rabbitMQMetricIdentifiers = []rabbitMQMetricIdentifier{ {&testRabbitMQMetadata[1], 0, "s0-rabbitmq-sample"}, @@ -191,8 +191,8 @@ func TestRabbitMQParseMetadata(t *testing.T) { if err != nil && !testData.isError { t.Errorf("Expect error but got success in test case %d", idx) } - if boolVal != meta.unsafeSsl { - t.Errorf("Expect %t but got %t in test case %d", boolVal, meta.unsafeSsl, idx) + if boolVal != meta.UnsafeSsl { + t.Errorf("Expect %t but got %t in test case %d", boolVal, meta.UnsafeSsl, idx) } } } @@ -207,25 +207,25 @@ func TestRabbitMQParseAuthParamData(t *testing.T) { if testData.isError && err == nil { t.Error("Expected error but got success") } - if metadata != nil && metadata.enableTLS != testData.enableTLS { - t.Errorf("Expected enableTLS to be set to %v but got %v\n", testData.enableTLS, metadata.enableTLS) + if metadata != nil && metadata.EnableTLS != testData.enableTLS { + t.Errorf("Expected enableTLS to be set to %v but got %v\n", testData.enableTLS, metadata.EnableTLS) } - if metadata != nil && metadata.enableTLS { - if metadata.ca != testData.authParams["ca"] { - t.Errorf("Expected ca to be set to %v but got %v\n", testData.authParams["ca"], metadata.enableTLS) + if metadata != nil && metadata.EnableTLS == rmqTLSEnable { + if metadata.Ca != testData.authParams["ca"] { + t.Errorf("Expected ca to be set to %v but got %v\n", testData.authParams["ca"], metadata.EnableTLS) } - if metadata.cert != testData.authParams["cert"] { - t.Errorf("Expected cert to be set to %v but got %v\n", testData.authParams["cert"], metadata.cert) + if metadata.Cert != testData.authParams["cert"] { + t.Errorf("Expected cert to be set to %v but got %v\n", testData.authParams["cert"], metadata.Cert) } - if metadata.key != testData.authParams["key"] { - t.Errorf("Expected key to be set to %v but got %v\n", testData.authParams["key"], metadata.key) + if metadata.Key != testData.authParams["key"] { + t.Errorf("Expected key to be set to %v but got %v\n", testData.authParams["key"], metadata.Key) } - if metadata.keyPassword != testData.authParams["keyPassword"] { - t.Errorf("Expected key to be set to %v but got %v\n", testData.authParams["keyPassword"], metadata.key) + if metadata.KeyPassword != testData.authParams["keyPassword"] { + t.Errorf("Expected key to be set to %v but got %v\n", testData.authParams["keyPassword"], metadata.Key) } } if metadata != nil && metadata.workloadIdentityClientID != "" && !testData.workloadIdentity { - t.Errorf("Expected workloadIdentity to be disabled but got %v as client ID and %v as resource\n", metadata.workloadIdentityClientID, metadata.workloadIdentityResource) + t.Errorf("Expected workloadIdentity to be disabled but got %v as client ID and %v as resource\n", metadata.workloadIdentityClientID, metadata.WorkloadIdentityResource) } if metadata != nil && metadata.workloadIdentityClientID == "" && testData.workloadIdentity { t.Error("Expected workloadIdentity to be enabled but was not\n") @@ -248,8 +248,8 @@ func TestParseDefaultQueueLength(t *testing.T) { t.Error("Expected success but got error", err) case testData.isError && err == nil: t.Error("Expected error but got success") - case metadata.value != defaultRabbitMQQueueLength: - t.Error("Expected default queueLength =", defaultRabbitMQQueueLength, "but got", metadata.value) + case metadata.Value != defaultRabbitMQQueueLength: + t.Error("Expected default queueLength =", defaultRabbitMQQueueLength, "but got", metadata.Value) } } } From 1880cdbb5591b1948930d1c20a357ad172ff141c Mon Sep 17 00:00:00 2001 From: Zhenghan Zhou Date: Sun, 3 Nov 2024 19:42:34 +0800 Subject: [PATCH 7/7] Add generateEmbeddedObjectMeta flag when generating crd (#5939) * Update Signed-off-by: SpiritZhou * Update CHANGLOG Signed-off-by: SpiritZhou * Update Signed-off-by: SpiritZhou --------- Signed-off-by: SpiritZhou --- CHANGELOG.md | 1 + Makefile | 2 +- config/crd/bases/keda.sh_scaledjobs.yaml | 34 ++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 32f6c393aae..9ca27681340 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,7 @@ To learn more about active deprecations, we recommend checking [GitHub Discussio ### New +- **General**: Add the generateEmbeddedObjectMeta flag to generate meta properties of JobTargetRef in ScaledJob ([#5908](https://github.com/kedacore/keda/issues/5908)) - **General**: Cache miss fallback in validating webhook for ScaledObjects with direct kubernetes client ([#5973](https://github.com/kedacore/keda/issues/5973)) - **CloudEventSource**: Introduce ClusterCloudEventSource ([#3533](https://github.com/kedacore/keda/issues/3533)) - **CloudEventSource**: Provide ClusterCloudEventSource around the management of ScaledJobs resources ([#3523](https://github.com/kedacore/keda/issues/3523)) diff --git a/Makefile b/Makefile index 4ee8b59f2df..22c8c707608 100644 --- a/Makefile +++ b/Makefile @@ -131,7 +131,7 @@ smoke-test: ## Run e2e tests against Kubernetes cluster configured in ~/.kube/co ##@ Development manifests: controller-gen ## Generate ClusterRole and CustomResourceDefinition objects. - $(CONTROLLER_GEN) crd:crdVersions=v1 rbac:roleName=keda-operator paths="./..." output:crd:artifacts:config=config/crd/bases + $(CONTROLLER_GEN) crd:crdVersions=v1,generateEmbeddedObjectMeta=true rbac:roleName=keda-operator paths="./..." output:crd:artifacts:config=config/crd/bases # withTriggers is only used for duck typing so we only need the deepcopy methods # However operator-sdk generate doesn't appear to have an option for that # until this issue is fixed: https://github.com/kubernetes-sigs/controller-tools/issues/398 diff --git a/config/crd/bases/keda.sh_scaledjobs.yaml b/config/crd/bases/keda.sh_scaledjobs.yaml index 5ccf72f2d46..47e3a079da3 100644 --- a/config/crd/bases/keda.sh_scaledjobs.yaml +++ b/config/crd/bases/keda.sh_scaledjobs.yaml @@ -380,6 +380,23 @@ spec: description: |- Standard object's metadata. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#metadata + properties: + annotations: + additionalProperties: + type: string + type: object + finalizers: + items: + type: string + type: array + labels: + additionalProperties: + type: string + type: object + name: + type: string + namespace: + type: string type: object spec: description: |- @@ -6684,6 +6701,23 @@ spec: May contain labels and annotations that will be copied into the PVC when creating it. No other fields are allowed and will be rejected during validation. + properties: + annotations: + additionalProperties: + type: string + type: object + finalizers: + items: + type: string + type: array + labels: + additionalProperties: + type: string + type: object + name: + type: string + namespace: + type: string type: object spec: description: |-