Skip to content

Commit

Permalink
feat: new pulling mechanism for job statuses
Browse files Browse the repository at this point in the history
Update the driver to keep running even after the user program
finishes. The driver provides two APIs:
- GetStatus(): retrieve job status
- Shutdown(): properly tear down the driver

On the controller side, the `pod/exec` resource is used
to run the driver command, which invokes the driver APIs
to retrieve the job status and to shut down the driver
when the job is done.

Signed-off-by: Yihong Wang <[email protected]>
  • Loading branch information
yhwang committed Oct 10, 2024
1 parent b2bec12 commit 741224e
Show file tree
Hide file tree
Showing 20 changed files with 590 additions and 1,749 deletions.
107 changes: 79 additions & 28 deletions cmd/lmes_driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ package main

import (
"context"
"encoding/json"
"flag"
"fmt"
"io"
"os"
"strings"
"time"

ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/log"
Expand All @@ -50,17 +50,14 @@ func (t *strArrayArg) String() string {
}

var (
taskRecipes strArrayArg
customCards strArrayArg
copy = flag.String("copy", "", "copy this binary to specified destination path")
jobNameSpace = flag.String("job-namespace", "", "Job's namespace ")
jobName = flag.String("job-name", "", "Job's name")
grpcService = flag.String("grpc-service", "", "grpc service name")
grpcPort = flag.Int("grpc-port", 8082, "grpc port")
outputPath = flag.String("output-path", OutputPath, "output path")
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
reportInterval = flag.Duration("report-interval", time.Second*10, "specify the druation interval to report the progress")
driverLog = ctrl.Log.WithName("driver")
taskRecipes strArrayArg
customCards strArrayArg
copy = flag.String("copy", "", "copy this binary to specified destination path")
getStatus = flag.Bool("get-status", false, "Get current status")
shutdown = flag.Bool("shutdown", false, "Shutdown the driver")
outputPath = flag.String("output-path", OutputPath, "output path")
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
driverLog = ctrl.Log.WithName("driver")
)

func init() {
Expand All @@ -83,7 +80,7 @@ func main() {

if *copy != "" {
// copy exec to destination
if err := CopyExec(*copy); err != nil {
if err := copyExec(*copy); err != nil {
driverLog.Error(err, "failed to copy binary")
os.Exit(1)
return
Expand All @@ -92,29 +89,34 @@ func main() {
return
}

if *getStatus {
getStatusOrDie(ctx)
return
}

if *shutdown {
shutdownOrDie(ctx)
return
}

if len(args) == 0 {
driverLog.Error(fmt.Errorf("no user program"), "empty args")
os.Exit(1)
}

driverOpt := driver.DriverOption{
Context: ctx,
JobNamespace: *jobNameSpace,
JobName: *jobName,
OutputPath: *outputPath,
GrpcService: *grpcService,
GrpcPort: *grpcPort,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
CustomCards: customCards,
Args: args,
ReportInterval: *reportInterval,
Context: ctx,
OutputPath: *outputPath,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
CustomCards: customCards,
Args: args,
}

driver, err := driver.NewDriver(&driverOpt)
if err != nil {
driverLog.Error(err, "Driver.Run failed")
driverLog.Error(err, "Driver.NewDriver failed")
os.Exit(1)
}

Expand All @@ -123,11 +125,10 @@ func main() {
driverLog.Error(err, "Driver.Run failed")
exitCode = 1
}
driver.Cleanup()
os.Exit(exitCode)
}

func CopyExec(destination string) (err error) {
func copyExec(destination string) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("copy this binary to %s: %w", destination, err)
Expand Down Expand Up @@ -161,3 +162,53 @@ func findThisBinary() (string, error) {
}
return bin, nil
}

// getStatusOrDie constructs a driver from the output-path and detect-device
// flags, asks it for the current status, and prints that status to stdout as
// JSON (no trailing newline).
//
// The function never returns: it exits the process with code 0 on success, or
// logs the error and exits with code 1 if driver creation, the status query,
// or JSON serialization fails.
func getStatusOrDie(ctx context.Context) {
	opt := driver.DriverOption{
		Context:      ctx,
		OutputPath:   *outputPath,
		DetectDevice: *detectDevice,
		Logger:       driverLog,
	}

	// Use a short local name so the imported driver package is not shadowed.
	d, err := driver.NewDriver(&opt)
	if err != nil {
		driverLog.Error(err, "failed to initialize the driver")
		os.Exit(1)
	}

	status, statusErr := d.GetStatus()
	if statusErr != nil {
		driverLog.Error(statusErr, "failed to get status", "error", statusErr.Error())
		os.Exit(1)
	}

	payload, marshalErr := json.Marshal(status)
	if marshalErr != nil {
		driverLog.Error(marshalErr, "json serialization failed", "error", marshalErr.Error())
		os.Exit(1)
	}

	// The serialized status is written to stdout before exiting.
	fmt.Print(string(payload))
	os.Exit(0)
}

// shutdownOrDie constructs a driver from the output-path and detect-device
// flags and calls its Shutdown method.
//
// The function never returns: it exits the process with code 0 when Shutdown
// succeeds, or logs the error and exits with code 1 if driver creation or the
// shutdown call fails.
func shutdownOrDie(ctx context.Context) {
	opt := driver.DriverOption{
		Context:      ctx,
		OutputPath:   *outputPath,
		DetectDevice: *detectDevice,
		Logger:       driverLog,
	}

	// Use a short local name so the imported driver package is not shadowed.
	d, err := driver.NewDriver(&opt)
	if err != nil {
		driverLog.Error(err, "failed to initialize the driver")
		os.Exit(1)
	}

	if stopErr := d.Shutdown(); stopErr != nil {
		driverLog.Error(stopErr, "failed to shutdown", "error", stopErr.Error())
		os.Exit(1)
	}
	os.Exit(0)
}
29 changes: 6 additions & 23 deletions cmd/lmes_driver/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"flag"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/driver"
Expand All @@ -31,13 +30,8 @@ import (
func Test_ArgParsing(t *testing.T) {
os.Args = []string{
"/opt/app-root/src/bin/driver",
"--job-namespace", "default",
"--job-name", "test",
"--grpc-service", "grpc-service.test.svc",
"--grpc-port", "8088",
"--output-path", "/opt/app-root/src/output",
"--detect-device",
"--report-interval", "10s",
"--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
"--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
"--",
Expand All @@ -53,30 +47,19 @@ func Test_ArgParsing(t *testing.T) {

args := flag.Args()

assert.Equal(t, "default", *jobNameSpace)
assert.Equal(t, "test", *jobName)
assert.Equal(t, "grpc-service.test.svc", *grpcService)
assert.Equal(t, 8088, *grpcPort)
assert.Equal(t, "/opt/app-root/src/output", *outputPath)
assert.Equal(t, true, *detectDevice)
assert.Equal(t, time.Second*10, *reportInterval)
assert.Equal(t, strArrayArg{
"card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
"card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
}, taskRecipes)

dOption := driver.DriverOption{
Context: context.Background(),
JobNamespace: *jobNameSpace,
JobName: *jobName,
OutputPath: *outputPath,
GrpcService: *grpcService,
GrpcPort: *grpcPort,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
Args: args,
ReportInterval: *reportInterval,
Context: context.Background(),
OutputPath: *outputPath,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
Args: args,
}

assert.Equal(t, []string{
Expand Down
3 changes: 1 addition & 2 deletions config/base/params.env
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ lmes-driver-image=quay.io/trustyai/ta-lmes-driver:latest
lmes-pod-image=quay.io/trustyai/ta-lmes-job:latest
lmes-pod-checking-interval=10s
lmes-image-pull-policy=Always
lmes-grpc-service=lmes-grpc
lmes-grpc-port=8082
lmes-max-batch-size=24
lmes-default-batch-size=8
lmes-detect-device=true
1 change: 0 additions & 1 deletion config/manager/kustomization.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
resources:
- manager.yaml
- service.yaml
apiVersion: kustomize.config.k8s.io/v1beta1
kind: Kustomization
14 changes: 0 additions & 14 deletions config/manager/service.yaml

This file was deleted.

14 changes: 0 additions & 14 deletions config/overlays/lmes/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,5 @@ kind: Kustomization
resources:
- ../../base

replacements:
- source:
kind: Service
version: v1
name: lmes-grpc
fieldPath: .metadata.name
targets:
- select:
kind: ConfigMap
version: v1
name: config
fieldPaths:
- .data.lmes-grpc-service

patchesStrategicMerge:
- lmes-only-patch.yaml
14 changes: 0 additions & 14 deletions config/overlays/odh/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,3 @@ apiVersion: kustomize.config.k8s.io/v1beta1
kind: Kustomization
resources:
- ../../base

replacements:
- source:
kind: Service
version: v1
name: lmes-grpc
fieldPath: .metadata.name
targets:
- select:
kind: ConfigMap
version: v1
name: config
fieldPaths:
- .data.lmes-grpc-service
10 changes: 10 additions & 0 deletions config/rbac/role.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ rules:
- patch
- update
- watch
- apiGroups:
- ""
resources:
- pods/exec
verbs:
- create
- delete
- get
- list
- watch
- apiGroups:
- ""
resources:
Expand Down
Loading

0 comments on commit 741224e

Please sign in to comment.