diff --git a/pkg/operations/custom.go b/pkg/operations/custom.go index ffc1e19af83..e8aab5e30a0 100644 --- a/pkg/operations/custom.go +++ b/pkg/operations/custom.go @@ -292,6 +292,21 @@ func resolveParameterValue(ctx context.Context, return "", nil } +func validateAndGetCompSpec(cluster *appsv1.Cluster, opsDef *opsv1alpha1.OpsDefinition, componentName string) (*appsv1.ClusterComponentSpec, error) { + compSpec := cluster.Spec.GetComponentByName(componentName) + if compSpec != nil { + return compSpec, nil + } + shardingSpec := cluster.Spec.GetShardingByName(componentName) + if shardingSpec == nil { + return nil, intctrlutil.NewFatalError(fmt.Sprintf(`cannot found the component "%s" in cluster "%s"`, componentName, cluster.Name)) + } + if len(opsDef.Spec.PodInfoExtractors) == 0 { + return nil, intctrlutil.NewFatalError(fmt.Sprintf(`podInfoExtractors cannot be empty in opsDef "%s" when the component "%s" is a shard component`, opsDef.Name, componentName)) + } + return &shardingSpec.Template, nil +} + // initOpsDefAndValidate inits the opsDefinition to OpsResource and validates if the opsRequest is valid. func initOpsDefAndValidate(reqCtx intctrlutil.RequestCtx, cli client.Client, @@ -305,31 +320,33 @@ func initOpsDefAndValidate(reqCtx intctrlutil.RequestCtx, return err } opsRes.OpsDef = opsDef - // 1. validate OpenApV3Schema + parametersSchema := opsDef.Spec.ParametersSchema - if parametersSchema == nil { - return nil - } for _, v := range customSpec.CustomOpsComponents { - paramsMap, err := covertParametersToMap(reqCtx.Ctx, cli, v.Parameters, opsRes.OpsRequest.Namespace) - if err != nil { - return err - } - // covert to type map[string]interface{} - params, err := common.CoverStringToInterfaceBySchemaType(parametersSchema.OpenAPIV3Schema, paramsMap) - if err != nil { - return intctrlutil.NewFatalError(err.Error()) - } - if parametersSchema != nil && parametersSchema.OpenAPIV3Schema != nil { - if err = common.ValidateDataWithSchema(parametersSchema.OpenAPIV3Schema, params); err != nil { + // 1. validate OpenApV3Schema + if parametersSchema != nil { + paramsMap, err := covertParametersToMap(reqCtx.Ctx, cli, v.Parameters, opsRes.OpsRequest.Namespace) + if err != nil { + return err + } + // covert to type map[string]interface{} + params, err := common.CoverStringToInterfaceBySchemaType(parametersSchema.OpenAPIV3Schema, paramsMap) + if err != nil { return intctrlutil.NewFatalError(err.Error()) } + if parametersSchema != nil && parametersSchema.OpenAPIV3Schema != nil { + if err = common.ValidateDataWithSchema(parametersSchema.OpenAPIV3Schema, params); err != nil { + return intctrlutil.NewFatalError(err.Error()) + } + } } - // 2. validate component and componentDef + compSpec, err := validateAndGetCompSpec(opsRes.Cluster, opsDef, v.ComponentName) + if err != nil { + return err + } if len(opsRes.OpsDef.Spec.ComponentInfos) > 0 { // get component definition - compSpec := getComponentSpecOrShardingTemplate(opsRes.Cluster, v.ComponentName) compDef, err := component.GetCompDefByName(reqCtx.Ctx, cli, compSpec.ComponentDef) if err != nil { return err diff --git a/pkg/operations/custom/utils.go b/pkg/operations/custom/utils.go index 22e4fadcca3..8b7e1fcdb32 100644 --- a/pkg/operations/custom/utils.go +++ b/pkg/operations/custom/utils.go @@ -58,15 +58,16 @@ func buildComponentEnvs(reqCtx intctrlutil.RequestCtx, cluster *appsv1.Cluster, opsDef *opsv1alpha1.OpsDefinition, env *[]corev1.EnvVar, - comp *appsv1.ClusterComponentSpec) error { + comp *appsv1.ClusterComponentSpec, + componentName string) error { // inject built-in component env - fullCompName := constant.GenerateClusterComponentName(cluster.Name, comp.Name) + fullCompName := constant.GenerateClusterComponentName(cluster.Name, componentName) *env = append(*env, []corev1.EnvVar{ {Name: constant.KBEnvClusterName, Value: cluster.Name}, - {Name: constant.KBEnvCompName, Value: comp.Name}, + {Name: constant.KBEnvCompName, Value: componentName}, {Name: constant.KBEnvClusterCompName, Value: fullCompName}, {Name: constant.KBEnvCompReplicas, Value: strconv.Itoa(int(comp.Replicas))}, - {Name: kbEnvCompHeadlessSVCName, Value: constant.GenerateDefaultComponentHeadlessServiceName(cluster.Name, comp.Name)}, + {Name: kbEnvCompHeadlessSVCName, Value: constant.GenerateDefaultComponentHeadlessServiceName(cluster.Name, componentName)}, }...) if len(opsDef.Spec.ComponentInfos) == 0 { return nil @@ -95,7 +96,7 @@ func buildComponentEnvs(reqCtx intctrlutil.RequestCtx, } // inject connect envs if componentInfo.AccountName != "" { - accountSecretName := constant.GenerateAccountSecretName(cluster.Name, comp.Name, componentInfo.AccountName) + accountSecretName := constant.GenerateAccountSecretName(cluster.Name, componentName, componentInfo.AccountName) *env = append(*env, corev1.EnvVar{Name: kbEnvAccountUserName, Value: componentInfo.AccountName}) *env = append(*env, corev1.EnvVar{Name: kbEnvAccountPassword, ValueFrom: buildSecretKeyRef(accountSecretName, constant.AccountPasswdForSecret)}) } @@ -106,7 +107,7 @@ func buildComponentEnvs(reqCtx intctrlutil.RequestCtx, if v.Name != componentInfo.ServiceName { continue } - *env = append(*env, corev1.EnvVar{Name: kbEnvCompSVCName, Value: constant.GenerateComponentServiceName(cluster.Name, comp.Name, v.ServiceName)}) + *env = append(*env, corev1.EnvVar{Name: kbEnvCompSVCName, Value: constant.GenerateComponentServiceName(cluster.Name, componentName, v.ServiceName)}) for _, port := range v.Spec.Ports { portName := strings.ReplaceAll(port.Name, "-", "_") *env = append(*env, corev1.EnvVar{Name: kbEnvCompSVCPortPrefix + strings.ToUpper(portName), Value: strconv.Itoa(int(port.Port))}) @@ -252,8 +253,16 @@ func buildActionPodEnv(reqCtx intctrlutil.RequestCtx, Value: ops.Namespace, }, } + + componentName := compCustomItem.ComponentName + if targetPod != nil { + // "component" might be a shard component and the name is logical. + // we should get the real component name from the pod labels. + componentName = targetPod.Labels[constant.KBAppComponentLabelKey] + } + // inject component and componentDef envs - if err := buildComponentEnvs(reqCtx, cli, cluster, opsDef, &env, comp); err != nil { + if err := buildComponentEnvs(reqCtx, cli, cluster, opsDef, &env, comp, componentName); err != nil { return nil, err } @@ -338,7 +347,7 @@ func getTargetPods( if cluster.Spec.GetShardingByName(compName) != nil { // get pods of the sharding components podList := &corev1.PodList{} - labels := constant.GetClusterLabels(cluster.Namespace) + labels := constant.GetClusterLabels(cluster.Name) labels[constant.KBAppShardingNameLabelKey] = compName if podSelector.Role != "" { labels[constant.RoleLabelKey] = podSelector.Role diff --git a/pkg/operations/custom_test.go b/pkg/operations/custom_test.go index 7a0be8b50cb..5fc6c7f837b 100644 --- a/pkg/operations/custom_test.go +++ b/pkg/operations/custom_test.go @@ -33,6 +33,7 @@ import ( appsv1 "github.com/apecloud/kubeblocks/apis/apps/v1" opsv1alpha1 "github.com/apecloud/kubeblocks/apis/operations/v1alpha1" + "github.com/apecloud/kubeblocks/pkg/common" "github.com/apecloud/kubeblocks/pkg/constant" intctrlutil "github.com/apecloud/kubeblocks/pkg/controllerutil" "github.com/apecloud/kubeblocks/pkg/generics" @@ -268,7 +269,7 @@ var _ = Describe("CustomOps", func() { Expect(ops.Status.Phase).Should(Equal(opsv1alpha1.OpsFailedPhase)) }) - It("Test custom ops when workload job completed ", func() { + testCustomOps := func() { By("create custom Ops") params := []opsv1alpha1.Parameter{ {Name: requiredParam, Value: "select 1"}, @@ -300,7 +301,66 @@ var _ = Describe("CustomOps", func() { _, err = GetOpsManager().Reconcile(reqCtx, k8sClient, opsResource) Expect(err).ShouldNot(HaveOccurred()) Expect(opsResource.OpsRequest.Status.Phase).Should(Equal(opsv1alpha1.OpsSucceedPhase)) + } + + It("Test custom ops when workload job completed ", func() { + testCustomOps() }) + It("Should failed when creating ops with a sharding component ahd the opsDef misses podInfoExtractors", func() { + cluster = testapps.NewClusterFactory(testCtx.DefaultNamespace, "", ""). + WithRandomName().AddSharding(defaultCompName, "", compDefName).Create(&testCtx).GetObject() + + params := []opsv1alpha1.Parameter{ + {Name: "sql", Value: "select 1"}, + } + ops := createCustomOps(defaultCompName, params) + opsResource.Cluster = cluster + By("validate pass for json schema") + _, err := GetOpsManager().Do(reqCtx, k8sClient, opsResource) + Expect(err).ShouldNot(HaveOccurred()) + Expect(ops.Status.Phase).Should(Equal(opsv1alpha1.OpsFailedPhase)) + }) + + It("Test custom ops with sharding cluster", func() { + By("init environment for sharding cluster") + cluster = testapps.NewClusterFactory(testCtx.DefaultNamespace, "", ""). + WithRandomName().AddSharding(defaultCompName, "", compDefName).Create(&testCtx).GetObject() + + opsResource.Cluster = cluster + + Expect(testapps.ChangeObj(&testCtx, opsDef, func(obj *opsv1alpha1.OpsDefinition) { + podExtraInfoName := "running-pod" + obj.Spec.PodInfoExtractors = []opsv1alpha1.PodInfoExtractor{ + { + Name: podExtraInfoName, + PodSelector: opsv1alpha1.PodSelector{ + MultiPodSelectionPolicy: opsv1alpha1.Any, + }, + }, + } + obj.Spec.Actions[0].Workload.PodInfoExtractorName = podExtraInfoName + })).Should(Succeed()) + + // create a sharding component + shardingNamePrefix := constant.GenerateClusterComponentName(cluster.Name, defaultCompName) + shardingCompName := common.SimpleNameGenerator.GenerateName(shardingNamePrefix) + compObj = testapps.NewComponentFactory(testCtx.DefaultNamespace, shardingCompName, compDefName). + AddLabels(constant.AppInstanceLabelKey, cluster.Name). + AddLabels(constant.KBAppClusterUIDKey, string(cluster.UID)). + AddLabels(constant.KBAppShardingNameLabelKey, defaultCompName). + AddLabels(constant.KBAppComponentLabelKey, shardingCompName). + SetReplicas(1). + Create(&testCtx). + GetObject() + + // create a pod which belongs to the sharding component + pod := testapps.MockInstanceSetPod(&testCtx, nil, cluster.Name, defaultCompName, fmt.Sprintf(shardingCompName+"-0"), "", "") + Expect(testapps.ChangeObj(&testCtx, pod, func(obj *corev1.Pod) { + pod.Labels[constant.KBAppShardingNameLabelKey] = defaultCompName + })).Should(Succeed()) + + testCustomOps() + }) }) }) diff --git a/pkg/operations/upgrade.go b/pkg/operations/upgrade.go index e157dba9dac..7fff7d771c0 100644 --- a/pkg/operations/upgrade.go +++ b/pkg/operations/upgrade.go @@ -52,7 +52,7 @@ func init() { // ActionStartedCondition the started condition when handle the upgrade request. func (u upgradeOpsHandler) ActionStartedCondition(reqCtx intctrlutil.RequestCtx, cli client.Client, opsRes *OpsResource) (*metav1.Condition, error) { - return opsv1alpha1.NewHorizontalScalingCondition(opsRes.OpsRequest), nil + return opsv1alpha1.NewUpgradingCondition(opsRes.OpsRequest), nil } func (u upgradeOpsHandler) Action(reqCtx intctrlutil.RequestCtx, cli client.Client, opsRes *OpsResource) error {