Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for displaying the Ray dashboard when a RayJob is active #4397

Merged
merged 11 commits into from
Nov 10, 2023
12 changes: 7 additions & 5 deletions flyteplugins/go/tasks/plugins/k8s/ray/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
pluginsConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/config"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
pluginmachinery "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/flytestdlib/config"
)

Expand Down Expand Up @@ -78,11 +79,12 @@ type Config struct {
DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"`

// Remote Ray Cluster Config
RemoteClusterConfig pluginmachinery.ClusterConfig `json:"remoteClusterConfig" pflag:"Configuration of remote K8s cluster for ray jobs"`
Logs logs.LogConfig `json:"logs" pflag:"-,Log configuration for ray jobs"`
LogsSidecar *v1.Container `json:"logsSidecar" pflag:"-,Sidecar to inject into head pods for capturing ray job logs"`
Defaults DefaultConfig `json:"defaults" pflag:"-,Default configuration for ray jobs"`
EnableUsageStats bool `json:"enableUsageStats" pflag:",Enable usage stats for ray jobs. These stats are submitted to usage-stats.ray.io per https://docs.ray.io/en/latest/cluster/usage-stats.html"`
RemoteClusterConfig pluginmachinery.ClusterConfig `json:"remoteClusterConfig" pflag:"Configuration of remote K8s cluster for ray jobs"`
Logs logs.LogConfig `json:"logs" pflag:"-,Log configuration for ray jobs"`
LogsSidecar *v1.Container `json:"logsSidecar" pflag:"-,Sidecar to inject into head pods for capturing ray job logs"`
DashboardURLTemplate *tasklog.TemplateLogPlugin `json:"dashboardURLTemplate" pflag:",Template for URL of Ray dashboard running on a head node."`
Defaults DefaultConfig `json:"defaults" pflag:"-,Default configuration for ray jobs"`
EnableUsageStats bool `json:"enableUsageStats" pflag:",Enable usage stats for ray jobs. These stats are submitted to usage-stats.ray.io per https://docs.ray.io/en/latest/cluster/usage-stats.html"`
}

type DefaultConfig struct {
Expand Down
40 changes: 28 additions & 12 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"k8s.io/client-go/kubernetes/scheme"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"
flyteerr "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
Expand Down Expand Up @@ -437,26 +438,35 @@
return nil, fmt.Errorf("failed to initialize log plugins. Error: %w", err)
}

if logPlugin == nil {
return nil, nil
}

// TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs
// RayJob CRD does not include the name of the worker or head pod for now
var taskLogs []*core.TaskLog

taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID()
logOutput, err := logPlugin.GetTaskLogs(tasklog.Input{
input := tasklog.Input{
Namespace: rayJob.Namespace,
TaskExecutionID: taskExecID,
})
}

// TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs
// RayJob CRD does not include the name of the worker or head pod for now
logOutput, err := logPlugin.GetTaskLogs(input)
if err != nil {
return nil, fmt.Errorf("failed to generate task logs. Error: %w", err)
}
taskLogs = append(taskLogs, logOutput.TaskLogs...)

return &pluginsCore.TaskInfo{
Logs: logOutput.TaskLogs,
}, nil
// Handling for Ray Dashboard
dashboardURLTemplate := GetConfig().DashboardURLTemplate
if dashboardURLTemplate != nil &&
rayJob.Status.DashboardURL != "" &&
rayJob.Status.JobStatus == rayv1alpha1.JobStatusRunning {
dashboardURLOutput, err := dashboardURLTemplate.GetTaskLogs(input)
if err != nil {
return nil, fmt.Errorf("failed to generate Ray dashboard link. Error: %w", err)
}

Check warning on line 465 in flyteplugins/go/tasks/plugins/k8s/ray/ray.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/ray/ray.go#L464-L465

Added lines #L464 - L465 were not covered by tests
taskLogs = append(taskLogs, dashboardURLOutput.TaskLogs...)
}

return &pluginsCore.TaskInfo{Logs: taskLogs}, nil
}

func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
Expand Down Expand Up @@ -489,8 +499,14 @@
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
case rayv1alpha1.JobStatusSucceeded:
return pluginsCore.PhaseInfoSuccess(info), nil
case rayv1alpha1.JobStatusPending, rayv1alpha1.JobStatusRunning:
case rayv1alpha1.JobStatusPending:
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
case rayv1alpha1.JobStatusRunning:
phaseInfo := pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info)
if len(info.Logs) > 0 {
phaseInfo = phaseInfo.WithVersion(pluginsCore.DefaultPhaseVersion + 1)
}

Check warning on line 508 in flyteplugins/go/tasks/plugins/k8s/ray/ray.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/ray/ray.go#L507-L508

Added lines #L507 - L508 were not covered by tests
return phaseInfo, nil
case rayv1alpha1.JobStatusStopped:
// There is no current usage of this job status in KubeRay. It's unclear what it represents
fallthrough
Expand Down
84 changes: 73 additions & 11 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
mocks2 "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
)

Expand Down Expand Up @@ -615,6 +616,8 @@ func newPluginContext() k8s.PluginContext {
},
},
})
taskExecID.OnGetUniqueNodeID().Return("unique-node")
taskExecID.OnGetGeneratedName().Return("generated-name")

tskCtx := &mocks.TaskExecutionMetadata{}
tskCtx.OnGetTaskExecutionID().Return(taskExecID)
Expand Down Expand Up @@ -642,17 +645,19 @@ func TestGetTaskPhase(t *testing.T) {
rayJobPhase rayv1alpha1.JobStatus
rayClusterPhase rayv1alpha1.JobDeploymentStatus
expectedCorePhase pluginsCore.Phase
expectedError bool
}{
{"", rayv1alpha1.JobDeploymentStatusInitializing, pluginsCore.PhaseInitializing},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedToGetOrCreateRayCluster, pluginsCore.PhasePermanentFailure},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusWaitForDashboard, pluginsCore.PhaseRunning},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedJobDeploy, pluginsCore.PhasePermanentFailure},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedToGetJobStatus, pluginsCore.PhaseUndefined},
{rayv1alpha1.JobStatusRunning, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning},
{rayv1alpha1.JobStatusFailed, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhasePermanentFailure},
{rayv1alpha1.JobStatusSucceeded, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseSuccess},
{rayv1alpha1.JobStatusSucceeded, rayv1alpha1.JobDeploymentStatusComplete, pluginsCore.PhaseSuccess},
{"", rayv1alpha1.JobDeploymentStatusInitializing, pluginsCore.PhaseInitializing, false},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedToGetOrCreateRayCluster, pluginsCore.PhasePermanentFailure, false},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusWaitForDashboard, pluginsCore.PhaseRunning, false},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedJobDeploy, pluginsCore.PhasePermanentFailure, false},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning, false},
{rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedToGetJobStatus, pluginsCore.PhaseRunning, false},
{rayv1alpha1.JobStatusRunning, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning, false},
{rayv1alpha1.JobStatusFailed, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhasePermanentFailure, false},
{rayv1alpha1.JobStatusSucceeded, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseSuccess, false},
{rayv1alpha1.JobStatusSucceeded, rayv1alpha1.JobDeploymentStatusComplete, pluginsCore.PhaseSuccess, false},
{rayv1alpha1.JobStatusStopped, rayv1alpha1.JobDeploymentStatusComplete, pluginsCore.PhaseUndefined, true},
}

for _, tc := range testCases {
Expand All @@ -663,12 +668,69 @@ func TestGetTaskPhase(t *testing.T) {
startTime := metav1.NewTime(time.Now())
rayObject.Status.StartTime = &startTime
phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject)
assert.Nil(t, err)
if tc.expectedError {
assert.Error(t, err)
} else {
assert.Nil(t, err)
}
assert.Equal(t, tc.expectedCorePhase.String(), phaseInfo.Phase().String())
})
}
}

// TestGetEventInfo_DashboardURL verifies the task logs produced by
// getEventInfoForRayJob for the configured DashboardURLTemplate: a RayJob that
// has a dashboard URL and is in the Running job status yields a link rendered
// from the template, while a Pending RayJob without a dashboard URL yields no
// logs.
func TestGetEventInfo_DashboardURL(t *testing.T) {
	pluginCtx := newPluginContext()
	testCases := []struct {
		name                 string
		rayJob               rayv1alpha1.RayJob
		dashboardURLTemplate tasklog.TemplateLogPlugin
		expectedTaskLogs     []*core.TaskLog
	}{
		{
			name: "dashboard URL displayed",
			rayJob: rayv1alpha1.RayJob{
				Status: rayv1alpha1.RayJobStatus{
					DashboardURL: "exists",
					JobStatus:    rayv1alpha1.JobStatusRunning,
				},
			},
			dashboardURLTemplate: tasklog.TemplateLogPlugin{
				DisplayName:  "Ray Dashboard",
				TemplateURIs: []tasklog.TemplateURI{"http://test/{{.generatedName}}"},
				Scheme:       tasklog.TemplateSchemeTaskExecution,
			},
			expectedTaskLogs: []*core.TaskLog{
				{
					Name: "Ray Dashboard",
					Uri:  "http://test/generated-name",
				},
			},
		},
		{
			name: "dashboard URL is not displayed",
			rayJob: rayv1alpha1.RayJob{
				Status: rayv1alpha1.RayJobStatus{
					JobStatus: rayv1alpha1.JobStatusPending,
				},
			},
			dashboardURLTemplate: tasklog.TemplateLogPlugin{
				DisplayName:  "dummy",
				TemplateURIs: []tasklog.TemplateURI{"http://dummy"},
			},
			expectedTaskLogs: nil,
		},
	}

	for _, tc := range testCases {
		// Copy the range variable: the subtest takes the address of its fields,
		// and before Go 1.22 all iterations share a single tc variable.
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate}))
			ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob)
			// Stop on failure: dereferencing a nil TaskInfo below would panic and
			// hide the assertion message that explains what went wrong.
			if !assert.NoError(t, err) || !assert.NotNil(t, ti) {
				return
			}
			assert.Equal(t, tc.expectedTaskLogs, ti.Logs)
		})
	}
}

func TestGetPropertiesRay(t *testing.T) {
rayJobResourceHandler := rayJobResourceHandler{}
expected := k8s.PluginProperties{}
Expand Down
Loading