From 06efe5bae5f1422eb5eed10a30fd61c45355e0ac Mon Sep 17 00:00:00 2001 From: imtzer Date: Sat, 29 Jun 2024 14:28:24 +0800 Subject: [PATCH] prevent driver pod from being deleted before its status is processed by the operator (#2054) Signed-off-by: imtzer --- pkg/controller/sparkapplication/controller.go | 38 ++++++ .../sparkapplication/controller_test.go | 117 ++++++++++++++++++ pkg/webhook/patch.go | 15 ++- pkg/webhook/webhook_test.go | 2 +- 4 files changed, 170 insertions(+), 2 deletions(-) diff --git a/pkg/controller/sparkapplication/controller.go b/pkg/controller/sparkapplication/controller.go index 3e9b373b8..2413f6e4c 100644 --- a/pkg/controller/sparkapplication/controller.go +++ b/pkg/controller/sparkapplication/controller.go @@ -51,6 +51,7 @@ import ( crdlisters "github.com/kubeflow/spark-operator/pkg/client/listers/sparkoperator.k8s.io/v1beta2" "github.com/kubeflow/spark-operator/pkg/config" "github.com/kubeflow/spark-operator/pkg/util" + "github.com/kubeflow/spark-operator/pkg/webhook" ) const ( @@ -613,6 +614,9 @@ func (c *Controller) syncSparkApplication(key string) error { return err } case v1beta2.CompletedState, v1beta2.FailedState: + if err := c.removeDriverPodFinalizer(app); err != nil { + return err + } if c.hasApplicationExpired(app) { glog.Infof("Garbage collecting expired SparkApplication %s/%s", app.Namespace, app.Name) err := c.crdClient.SparkoperatorV1beta2().SparkApplications(app.Namespace).Delete(context.TODO(), app.Name, metav1.DeleteOptions{GracePeriodSeconds: int64ptr(0)}) @@ -893,6 +897,10 @@ func (c *Controller) deleteSparkResources(app *v1beta2.SparkApplication) error { driverPodName = getDriverPodName(app) } + if err := c.removeDriverPodFinalizer(app); err != nil { + return fmt.Errorf("delete spark resource, %w", err) + } + glog.V(2).Infof("Deleting pod %s in namespace %s", driverPodName, app.Namespace) err := c.kubeClient.CoreV1().Pods(app.Namespace).Delete(context.TODO(), driverPodName, metav1.DeleteOptions{}) if err != nil && !errors.IsNotFound(err) { @@ -1125,6 +1133,36 @@ func (c *Controller) cleanUpOnTermination(oldApp, newApp *v1beta2.SparkApplicati return nil } +func (c *Controller) removeDriverPodFinalizer(app *v1beta2.SparkApplication) error { + driverPodName := app.Status.DriverInfo.PodName + if driverPodName == "" { + driverPodName = getDriverPodName(app) + } + pod, err := c.kubeClient.CoreV1().Pods(app.Namespace).Get(context.TODO(), driverPodName, metav1.GetOptions{}) + if errors.IsNotFound(err) { + return nil + } + if err != nil { + return fmt.Errorf("get driver pod %s failed, %w", driverPodName, err) + } + oldFinalizer := pod.Finalizers + var newFinalizer []string + for _, finalizer := range oldFinalizer { + if finalizer != webhook.DriverFinalize { + newFinalizer = append(newFinalizer, finalizer) + } + } + if len(oldFinalizer) != len(newFinalizer) { + pod.Finalizers = newFinalizer + _, err := c.kubeClient.CoreV1().Pods(app.Namespace).Update(context.TODO(), pod, metav1.UpdateOptions{}) + if err != nil { + return fmt.Errorf("remove driver pod finalizer failed, %w", err) + } + } + + return nil +} + func int64ptr(n int64) *int64 { return &n } diff --git a/pkg/controller/sparkapplication/controller_test.go b/pkg/controller/sparkapplication/controller_test.go index 44f9003db..1573f36d0 100644 --- a/pkg/controller/sparkapplication/controller_test.go +++ b/pkg/controller/sparkapplication/controller_test.go @@ -19,6 +19,7 @@ package sparkapplication import ( "context" "fmt" + "log" "os" "os/exec" "strings" @@ -922,6 +923,122 @@ func TestSyncSparkApplication_SubmissionSuccess(t *testing.T) { for _, test := range testcases { testFn(test, t) } + + // Test remove driver finalizer + testFn2 := func(test testcase, t *testing.T) { + ctrl, _ := newFakeController(test.app) + _, err := ctrl.crdClient.SparkoperatorV1beta2().SparkApplications(test.app.Namespace).Create(context.TODO(), test.app, metav1.CreateOptions{}) + if err != nil { + t.Fatal(err) + } + pod := &apiv1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Finalizers: []string{}, + Name: getDriverPodName(test.app), + }, + } + _, err = ctrl.kubeClient.CoreV1().Pods(test.app.Namespace).Create(context.TODO(), pod, metav1.CreateOptions{}) + if err != nil { + t.Fatal(err) + } + + execCommand = func(command string, args ...string) *exec.Cmd { + cs := []string{"-test.run=TestHelperProcessSuccess", "--", command} + cs = append(cs, args...) + cmd := exec.Command(os.Args[0], cs...) + cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"} + return cmd + } + + err = ctrl.syncSparkApplication(fmt.Sprintf("%s/%s", test.app.Namespace, test.app.Name)) + assert.Nil(t, err) + updatedApp, err := ctrl.crdClient.SparkoperatorV1beta2().SparkApplications(test.app.Namespace).Get(context.TODO(), test.app.Name, metav1.GetOptions{}) + assert.Nil(t, err) + assert.Equal(t, test.expectedState, updatedApp.Status.AppState.State) + _, err = ctrl.kubeClient.CoreV1().Pods(test.app.Namespace).Get(context.TODO(), getDriverPodName(test.app), metav1.GetOptions{}) + if !errors.IsNotFound(err) { + log.Fatal(err) + } + } + + testcases = []testcase{ + { + app: &v1beta2.SparkApplication{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Status: v1beta2.SparkApplicationStatus{ + AppState: v1beta2.ApplicationState{ + State: v1beta2.SucceedingState, + }, + }, + Spec: v1beta2.SparkApplicationSpec{ + RestartPolicy: restartPolicyAlways, + }, + }, + expectedState: v1beta2.PendingRerunState, + }, + { + app: &v1beta2.SparkApplication{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Status: v1beta2.SparkApplicationStatus{ + AppState: v1beta2.ApplicationState{ + State: v1beta2.FailingState, + }, + SubmissionAttempts: 1, + ExecutionAttempts: 1, + TerminationTime: metav1.Time{Time: metav1.Now().Add(-2000 * time.Second)}, + }, + Spec: v1beta2.SparkApplicationSpec{ + RestartPolicy: restartPolicyAlways, + }, + }, + expectedState: v1beta2.PendingRerunState, + }, + { + app: &v1beta2.SparkApplication{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Status: v1beta2.SparkApplicationStatus{ + AppState: v1beta2.ApplicationState{ + State: v1beta2.FailedSubmissionState, + }, + SubmissionAttempts: 1, + LastSubmissionAttemptTime: metav1.Time{Time: metav1.Now().Add(-2000 * time.Second)}, + }, + Spec: v1beta2.SparkApplicationSpec{ + RestartPolicy: restartPolicyAlways, + }, + }, + expectedState: v1beta2.FailedSubmissionState, + }, + { + app: &v1beta2.SparkApplication{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + Status: v1beta2.SparkApplicationStatus{ + AppState: v1beta2.ApplicationState{ + State: v1beta2.InvalidatingState, + }, + }, + Spec: v1beta2.SparkApplicationSpec{ + RestartPolicy: restartPolicyOnFailure, + }, + }, + expectedState: v1beta2.PendingRerunState, + }, + } + for _, test := range testcases { + testFn2(test, t) + } } func TestSyncSparkApplication_ExecutingState(t *testing.T) { diff --git a/pkg/webhook/patch.go b/pkg/webhook/patch.go index a7c20a816..084eb1d1c 100644 --- a/pkg/webhook/patch.go +++ b/pkg/webhook/patch.go @@ -32,7 +32,9 @@ import ( ) const ( - maxNameLength = 63 + maxNameLength = 63 + addOperation = "add" + DriverFinalize = "kubeflow/spark-operator" ) // patchOperation represents a RFC6902 JSON patch operation. @@ -47,6 +49,7 @@ func patchSparkPod(pod *corev1.Pod, app *v1beta2.SparkApplication) []patchOperat if util.IsDriverPod(pod) { patchOps = append(patchOps, addOwnerReference(pod, app)) + patchOps = append(patchOps, addFinalizer(pod, app)) } patchOps = append(patchOps, addVolumes(pod, app)...) @@ -854,3 +857,13 @@ func addShareProcessNamespace(pod *corev1.Pod, app *v1beta2.SparkApplication) *p } return &patchOperation{Op: "add", Path: "/spec/shareProcessNamespace", Value: *shareProcessNamespace} } + +func addFinalizer(pod *corev1.Pod, app *v1beta2.SparkApplication) patchOperation { + var value []string + if len(pod.ObjectMeta.Finalizers) == 0 { + value = []string{DriverFinalize} + } else { + value = append(pod.Finalizers, DriverFinalize) + } + return patchOperation{Op: addOperation, Path: "/metadata/finalizers", Value: value} +} diff --git a/pkg/webhook/webhook_test.go b/pkg/webhook/webhook_test.go index 6f2e2f088..eb532be5d 100644 --- a/pkg/webhook/webhook_test.go +++ b/pkg/webhook/webhook_test.go @@ -179,7 +179,7 @@ func TestMutatePod(t *testing.T) { assert.True(t, len(response.Patch) > 0) var patchOps []*patchOperation json.Unmarshal(response.Patch, &patchOps) - assert.Equal(t, 6, len(patchOps)) + assert.Equal(t, 7, len(patchOps)) } func serializePod(pod *corev1.Pod) ([]byte, error) {