Skip to content

Commit

Permalink
add unit tests for terminatesubtasks
Browse files Browse the repository at this point in the history
  • Loading branch information
pvditt committed Dec 4, 2023
1 parent 8eeb643 commit 324cfa7
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 16 deletions.
26 changes: 10 additions & 16 deletions flyteplugins/go/tasks/plugins/array/k8s/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,29 +351,23 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube
return err
}

// return immediately if subtask has completed or not yet started
if existingPhase.IsTerminal() || existingPhase == core.PhaseUndefined {
// still write subtask to buffer to persist to admin
externalResources = append(externalResources, &core.ExternalResource{
ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(),
Index: uint32(originalIdx),
RetryAttempt: uint32(retryAttempt),
Phase: existingPhase,
})
continue
}

err = terminateFunction(ctx, stCtx, config, kubeClient)
if err != nil {
messageCollector.Collect(childIdx, err.Error())
isAbortedSubtask := false
if !existingPhase.IsTerminal() && existingPhase != core.PhaseUndefined {
// only terminate the subtask if it has started and has not yet completed
err = terminateFunction(ctx, stCtx, config, kubeClient)
if err != nil {
messageCollector.Collect(childIdx, err.Error())
} else {
isAbortedSubtask = true
}
}

externalResources = append(externalResources, &core.ExternalResource{
ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(),
Index: uint32(originalIdx),
RetryAttempt: uint32(retryAttempt),
Phase: existingPhase,
IsAbortedSubtask: true,
IsAbortedSubtask: isAbortedSubtask,
})
}

Expand Down
93 changes: 93 additions & 0 deletions flyteplugins/go/tasks/plugins/array/k8s/management_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,3 +538,96 @@ func TestCheckSubTasksState(t *testing.T) {
}
})
}

// TestTerminateSubtasks exercises TerminateSubTasks against an array of three
// subtasks in the phases PermanentFailure, Undefined and Running. It checks that
// the terminate function is invoked exactly once (for the running subtask only)
// and that IsAbortedSubtask is reported as true only when that termination
// succeeded.
func TestTerminateSubtasks(t *testing.T) {
	ctx := context.Background()
	subtaskCount := 3
	config := Config{
		MaxArrayJobSize: int64(subtaskCount * 10),
		ResourceConfig: ResourceConfig{
			PrimaryLabel: "p",
			Limit:        subtaskCount,
		},
	}
	kubeClient := mocks.KubeClient{}
	kubeClient.OnGetClient().Return(mocks.NewFakeKubeClient())
	kubeClient.OnGetCache().Return(mocks.NewFakeKubeCache())

	// One subtask per interesting case: already terminal, never started, and in flight.
	compactArray := arrayCore.NewPhasesCompactArray(uint(subtaskCount))
	compactArray.SetItem(0, bitarray.Item(core.PhasePermanentFailure))
	compactArray.SetItem(1, bitarray.Item(core.PhaseUndefined))
	compactArray.SetItem(2, bitarray.Item(core.PhaseRunning))

	currentState := &arrayCore.State{
		CurrentPhase:         arrayCore.PhaseCheckingSubTaskExecutions,
		ExecutionArraySize:   subtaskCount,
		OriginalArraySize:    int64(subtaskCount),
		OriginalMinSuccesses: int64(subtaskCount),
		ArrayStatus: arraystatus.ArrayStatus{
			Detailed: compactArray,
		},
		IndexesToCache: arrayCore.InvertBitSet(bitarray.NewBitSet(uint(subtaskCount)), uint(subtaskCount)), // set all tasks to be cached
	}

	tests := []struct {
		name            string
		terminateErr    error  // error returned by the mocked terminate function
		expectErr       bool   // whether TerminateSubTasks itself should return an error
		expectedAborted []bool // expected IsAbortedSubtask per external resource, by position
	}{
		{
			name:            "TerminateSubtasks",
			terminateErr:    nil,
			expectErr:       false,
			expectedAborted: []bool{false, false, true},
		},
		{
			name:            "TerminateSubtasksWithFailure",
			terminateErr:    fmt.Errorf("error"),
			expectErr:       true,
			expectedAborted: []bool{false, false, false},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			resourceManager := mocks.ResourceManager{}
			resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(core.AllocationStatusGranted, nil)
			eventRecorder := mocks.EventsRecorder{}
			eventRecorder.OnRecordRawMatch(mock.Anything, mock.Anything).Return(nil)
			tCtx := getMockTaskExecutionContext(ctx, 0)
			tCtx.OnResourceManager().Return(&resourceManager)
			tCtx.OnEventsRecorder().Return(&eventRecorder)

			terminateCounter := 0
			mockTerminateFunction := func(ctx context.Context, subTaskCtx SubTaskExecutionContext, cfg *Config, kubeClient core.KubeClient) error {
				terminateCounter++
				return test.terminateErr
			}

			err := TerminateSubTasks(ctx, tCtx, &kubeClient, &config, mockTerminateFunction, currentState)
			if test.expectErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
			// Only the running subtask is eligible for termination.
			assert.Equal(t, 1, terminateCounter)

			// Stop early instead of panicking on an out-of-range index when no event was recorded.
			if !assert.NotEmpty(t, eventRecorder.Calls) {
				return
			}
			args := eventRecorder.Calls[0].Arguments
			phaseInfo, ok := args.Get(1).(core.PhaseInfo)
			// A failed type assertion leaves phaseInfo as a zero value; do not dereference its Info().
			if !assert.True(t, ok) {
				return
			}

			externalResources := phaseInfo.Info().ExternalResources
			// Guard the per-index checks below against a short slice.
			if !assert.Len(t, externalResources, subtaskCount) {
				return
			}

			for i, want := range test.expectedAborted {
				assert.Equalf(t, want, externalResources[i].IsAbortedSubtask, "external resource %d", i)
			}
		})
	}
}
1 change: 1 addition & 0 deletions flytepropeller/pkg/controller/nodes/task/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ func (t Handler) Abort(ctx context.Context, nCtx interfaces.NodeExecutionContext
logger.Errorf(ctx, "Abort failed when calling plugin abort.")
return err
}

evRecorder := nCtx.EventsRecorder()
logger.Debugf(ctx, "Sending buffered Task events.")
for _, ev := range tCtx.ber.GetAll(ctx) {
Expand Down

0 comments on commit 324cfa7

Please sign in to comment.