From f79d10c9a2d80c3076f4be22694a3102fb9e0d76 Mon Sep 17 00:00:00 2001 From: qzhu Date: Fri, 4 Oct 2024 12:43:13 -0500 Subject: [PATCH] [YUNIKORN-2892] Log correct termination type when releasing task in shim (#917) Closes: #917 Signed-off-by: Craig Condit --- pkg/cache/task.go | 6 ++++-- pkg/common/si_helper.go | 4 ++-- pkg/common/si_helper_test.go | 29 +++++++++++++++++++++++++++-- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/pkg/cache/task.go b/pkg/cache/task.go index 5c0be9dec..8b758b6f6 100644 --- a/pkg/cache/task.go +++ b/pkg/cache/task.go @@ -494,6 +494,8 @@ func (task *Task) beforeTaskCompleted() { // releaseAllocation sends the release request for the Allocation to the core. func (task *Task) releaseAllocation() { + terminationType := common.GetTerminationTypeFromString(task.terminationType) + // scheduler api might be nil in some tests if task.context.apiProvider.GetAPIs().SchedulerAPI != nil { log.Log(log.ShimCacheTask).Debug("prepare to send release request", @@ -502,7 +504,7 @@ func (task *Task) releaseAllocation() { zap.String("taskAlias", task.alias), zap.String("allocationKey", task.allocationKey), zap.String("task", task.GetTaskState()), - zap.String("terminationType", task.terminationType)) + zap.String("terminationType", string(terminationType))) // send an AllocationReleaseRequest var releaseRequest *si.AllocationRequest @@ -526,7 +528,7 @@ func (task *Task) releaseAllocation() { task.applicationID, task.taskID, task.application.partition, - task.terminationType, + terminationType, ) if releaseRequest.Releases != nil { diff --git a/pkg/common/si_helper.go b/pkg/common/si_helper.go index 82a22530c..d94a708cf 100644 --- a/pkg/common/si_helper.go +++ b/pkg/common/si_helper.go @@ -121,13 +121,13 @@ func GetTerminationTypeFromString(terminationTypeStr string) si.TerminationType return si.TerminationType_STOPPED_BY_RM } -func CreateReleaseRequestForTask(appID, taskID, partition, terminationType string) *si.AllocationRequest { +func CreateReleaseRequestForTask(appID, taskID, partition string, terminationType si.TerminationType) *si.AllocationRequest { allocToRelease := make([]*si.AllocationRelease, 1) allocToRelease[0] = &si.AllocationRelease{ ApplicationID: appID, AllocationKey: taskID, PartitionName: partition, - TerminationType: GetTerminationTypeFromString(terminationType), + TerminationType: terminationType, Message: "task completed", } diff --git a/pkg/common/si_helper_test.go b/pkg/common/si_helper_test.go index b9464bd8c..9ccf619a9 100644 --- a/pkg/common/si_helper_test.go +++ b/pkg/common/si_helper_test.go @@ -32,7 +32,7 @@ const nodeID = "node-01" func TestCreateReleaseRequestForTask(t *testing.T) { // with allocationKey - request := CreateReleaseRequestForTask("app01", "task01", "default", "STOPPED_BY_RM") + request := CreateReleaseRequestForTask("app01", "task01", "default", si.TerminationType_STOPPED_BY_RM) assert.Assert(t, request.Releases != nil) assert.Assert(t, request.Releases.AllocationsToRelease != nil) assert.Equal(t, len(request.Releases.AllocationsToRelease), 1) @@ -41,7 +41,7 @@ func TestCreateReleaseRequestForTask(t *testing.T) { assert.Equal(t, request.Releases.AllocationsToRelease[0].PartitionName, "default") assert.Equal(t, request.Releases.AllocationsToRelease[0].TerminationType, si.TerminationType_STOPPED_BY_RM) - request = CreateReleaseRequestForTask("app01", "task01", "default", "UNKNOWN_TERMINATION_TYPE") + request = CreateReleaseRequestForTask("app01", "task01", "default", si.TerminationType_UNKNOWN_TERMINATION_TYPE) assert.Assert(t, request.Releases != nil) assert.Assert(t, request.Releases.AllocationsToRelease != nil) assert.Equal(t, len(request.Releases.AllocationsToRelease), 1) @@ -390,3 +390,28 @@ func TestCreateAllocationForTask(t *testing.T) { assert.Equal(t, tags[common.DomainK8s+common.GroupMeta+"podName"], podName1) assert.Equal(t, alloc1.Priority, int32(100)) } + +// TestGetTerminationTypeFromString tests the GetTerminationTypeFromString function. +func TestGetTerminationTypeFromString(t *testing.T) { + tests := []struct { + input string + expected si.TerminationType + }{ + {"UNKNOWN_TERMINATION_TYPE", si.TerminationType_UNKNOWN_TERMINATION_TYPE}, + {"STOPPED_BY_RM", si.TerminationType_STOPPED_BY_RM}, + {"TIMEOUT", si.TerminationType_TIMEOUT}, + {"PREEMPTED_BY_SCHEDULER", si.TerminationType_PREEMPTED_BY_SCHEDULER}, + {"PLACEHOLDER_REPLACED", si.TerminationType_PLACEHOLDER_REPLACED}, + {"INVALID_TYPE", si.TerminationType_STOPPED_BY_RM}, + {"", si.TerminationType_STOPPED_BY_RM}, + } + + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + result := GetTerminationTypeFromString(test.input) + if result != test.expected { + t.Errorf("For input '%s', expected %v, got %v", test.input, test.expected, result) + } + }) + } +}