Skip to content

Commit

Permalink
test: update test
Browse files Browse the repository at this point in the history
Signed-off-by: machichima <[email protected]>
  • Loading branch information
machichima committed Jan 19, 2025
1 parent 65f81fa commit 5edc331
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
9 changes: 1 addition & 8 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,7 @@ func BuildRawPod(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v
return nil, nil, "", err
}

// If primaryContainerName is set in taskTemplate config, use it instead
// of c.Name
if val, ok := taskTemplate.GetConfig()[PrimaryContainerKey]; ok {
primaryContainerName = val
c.Name = primaryContainerName
} else {
primaryContainerName = c.Name
}
primaryContainerName = c.Name
podSpec = &v1.PodSpec{
Containers: []v1.Container{
*c,
Expand Down
27 changes: 18 additions & 9 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,20 +351,26 @@ func dummySparkTaskTemplateContainer(id string, sparkConf map[string]string) *co
}
}

func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string, driverPod *core.K8SPod, executorPod *core.K8SPod) *core.TaskTemplate {
func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string, driverPod *core.K8SPod, executorPod *core.K8SPod, basePod *corev1.PodSpec) *core.TaskTemplate {

sparkJob := dummySparkCustomObjDriverExecutor(sparkConf, driverPod, executorPod)

structObj, err := utils.MarshalObjToStruct(sparkJob)
if err != nil {
panic(err)
}

basePodPb, err := utils.MarshalObjToStruct(basePod)
if err != nil {
panic(err)
}

return &core.TaskTemplate{
Id: &core.Identifier{Name: id},
Type: "container",
Target: &core.TaskTemplate_Container{
Container: &core.Container{
Image: testImage,
Type: "k8s_pod",
Target: &core.TaskTemplate_K8SPod{
K8SPod: &core.K8SPod{
PodSpec: basePodPb,
},
},
Config: map[string]string{
Expand Down Expand Up @@ -974,6 +980,9 @@ func TestBuildResourceCustomK8SPod(t *testing.T) {
defaultConfig := defaultPluginConfig()
assert.NoError(t, config.SetK8sPluginConfig(defaultConfig))

basePodSpec := dummyPodSpec()
basePodSpec.NodeSelector = map[string]string{"x/custom": "foo"}

// add extraDriverToleration and extraExecutorToleration
driverExtraToleration := corev1.Toleration{
Key: "x/flyte-driver",
Expand Down Expand Up @@ -1008,7 +1017,7 @@ func TestBuildResourceCustomK8SPod(t *testing.T) {
Labels: map[string]string{"label-executor": "val-executor"},
},
}
taskTemplate := dummySparkTaskTemplateDriverExecutor("blah-1", dummySparkConf, driverK8SPod, executorK8SPod)
taskTemplate := dummySparkTaskTemplateDriverExecutor("blah-1", dummySparkConf, driverK8SPod, executorK8SPod, basePodSpec)
sparkResourceHandler := sparkResourceHandler{}

taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{})
Expand All @@ -1029,7 +1038,7 @@ func TestBuildResourceCustomK8SPod(t *testing.T) {
assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.ServiceAccount)
assert.Equal(t, sparkOp.PythonApplicationType, sparkApp.Spec.Type)
assert.Equal(t, testImage, *sparkApp.Spec.Image)
assert.Equal(t, testArgs, sparkApp.Spec.Arguments)
assert.Equal(t, append(testArgs, testArgs...), sparkApp.Spec.Arguments)
assert.Equal(t, sparkOp.RestartPolicy{
Type: sparkOp.OnFailure,
OnSubmissionFailureRetries: intPtr(int32(14)),
Expand All @@ -1052,7 +1061,7 @@ func TestBuildResourceCustomK8SPod(t *testing.T) {
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value)
assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value)
assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Driver.Env, "SECRET"))
assert.Equal(t, 9, len(sparkApp.Spec.Driver.Env))
assert.Equal(t, 11, len(sparkApp.Spec.Driver.Env))
assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image)
assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount)
assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt)
Expand Down Expand Up @@ -1095,7 +1104,7 @@ func TestBuildResourceCustomK8SPod(t *testing.T) {
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value)
assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value)
assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET"))
assert.Equal(t, 9, len(sparkApp.Spec.Executor.Env))
assert.Equal(t, 11, len(sparkApp.Spec.Executor.Env))
assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image)
assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt)
assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig)
Expand Down

0 comments on commit 5edc331

Please sign in to comment.