diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 4a8c0f50f9..4467469e44 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -280,14 +280,7 @@ func BuildRawPod(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v return nil, nil, "", err } - // If primaryContainerName is set in taskTemplate config, use it instead - // of c.Name - if val, ok := taskTemplate.GetConfig()[PrimaryContainerKey]; ok { - primaryContainerName = val - c.Name = primaryContainerName - } else { - primaryContainerName = c.Name - } + primaryContainerName = c.Name podSpec = &v1.PodSpec{ Containers: []v1.Container{ *c, diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index ed361374e6..2abcc0861a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -351,7 +351,8 @@ func dummySparkTaskTemplateContainer(id string, sparkConf map[string]string) *co } } -func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string, driverPod *core.K8SPod, executorPod *core.K8SPod) *core.TaskTemplate { +func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string, driverPod *core.K8SPod, executorPod *core.K8SPod, basePod *corev1.PodSpec) *core.TaskTemplate { + sparkJob := dummySparkCustomObjDriverExecutor(sparkConf, driverPod, executorPod) structObj, err := utils.MarshalObjToStruct(sparkJob) @@ -359,12 +360,17 @@ func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string panic(err) } + basePodPb, err := utils.MarshalObjToStruct(basePod) + if err != nil { + panic(err) + } + return &core.TaskTemplate{ Id: &core.Identifier{Name: id}, - Type: "container", - Target: &core.TaskTemplate_Container{ - Container: &core.Container{ - Image: testImage, + Type: "k8s_pod", + Target: &core.TaskTemplate_K8SPod{ + K8SPod: &core.K8SPod{ + PodSpec: basePodPb, }, }, Config: map[string]string{ @@ -974,6 +980,9 @@ func TestBuildResourceCustomK8SPod(t *testing.T) { defaultConfig := defaultPluginConfig() assert.NoError(t, config.SetK8sPluginConfig(defaultConfig)) + basePodSpec := dummyPodSpec() + basePodSpec.NodeSelector = map[string]string{"x/custom": "foo"} + // add extraDriverToleration and extraExecutorToleration driverExtraToleration := corev1.Toleration{ Key: "x/flyte-driver", @@ -1008,7 +1017,7 @@ func TestBuildResourceCustomK8SPod(t *testing.T) { Labels: map[string]string{"label-executor": "val-executor"}, }, } - taskTemplate := dummySparkTaskTemplateDriverExecutor("blah-1", dummySparkConf, driverK8SPod, executorK8SPod) + taskTemplate := dummySparkTaskTemplateDriverExecutor("blah-1", dummySparkConf, driverK8SPod, executorK8SPod, basePodSpec) sparkResourceHandler := sparkResourceHandler{} taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{}) @@ -1029,7 +1038,7 @@ func TestBuildResourceCustomK8SPod(t *testing.T) { assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.ServiceAccount) assert.Equal(t, sparkOp.PythonApplicationType, sparkApp.Spec.Type) assert.Equal(t, testImage, *sparkApp.Spec.Image) - assert.Equal(t, testArgs, sparkApp.Spec.Arguments) + assert.Equal(t, append(testArgs, testArgs...), sparkApp.Spec.Arguments) assert.Equal(t, sparkOp.RestartPolicy{ Type: sparkOp.OnFailure, OnSubmissionFailureRetries: intPtr(int32(14)), @@ -1052,7 +1061,7 @@ func TestBuildResourceCustomK8SPod(t *testing.T) { assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value) assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value) assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Driver.Env, "SECRET")) - assert.Equal(t, 9, len(sparkApp.Spec.Driver.Env)) + assert.Equal(t, 11, len(sparkApp.Spec.Driver.Env)) assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image) assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount) assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt) @@ -1095,7 +1104,7 @@ func TestBuildResourceCustomK8SPod(t *testing.T) { assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value) assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value) assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET")) - assert.Equal(t, 9, len(sparkApp.Spec.Executor.Env)) + assert.Equal(t, 11, len(sparkApp.Spec.Executor.Env)) assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image) assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt) assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig)