Skip to content

Commit

Permalink
feat: new pulling mechanism for job statuses (#314)
Browse files Browse the repository at this point in the history
Update the driver to keep running even after the user program
finishes. The driver provides two APIs:
- GetStatus(): retrieve job status
- Shutdown(): properly tear down the driver

On the controller side, it uses the `pod/exec` resource
to run the driver command, which invokes the driver APIs
to retrieve the job status and shut down the driver
when the job is done.

Signed-off-by: Yihong Wang <[email protected]>
  • Loading branch information
yhwang authored Oct 13, 2024
1 parent b2bec12 commit ab6bc98
Show file tree
Hide file tree
Showing 20 changed files with 590 additions and 1,749 deletions.
107 changes: 79 additions & 28 deletions cmd/lmes_driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ package main

import (
"context"
"encoding/json"
"flag"
"fmt"
"io"
"os"
"strings"
"time"

ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/log"
Expand All @@ -50,17 +50,14 @@ func (t *strArrayArg) String() string {
}

var (
taskRecipes strArrayArg
customCards strArrayArg
copy = flag.String("copy", "", "copy this binary to specified destination path")
jobNameSpace = flag.String("job-namespace", "", "Job's namespace ")
jobName = flag.String("job-name", "", "Job's name")
grpcService = flag.String("grpc-service", "", "grpc service name")
grpcPort = flag.Int("grpc-port", 8082, "grpc port")
outputPath = flag.String("output-path", OutputPath, "output path")
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
reportInterval = flag.Duration("report-interval", time.Second*10, "specify the druation interval to report the progress")
driverLog = ctrl.Log.WithName("driver")
taskRecipes strArrayArg
customCards strArrayArg
copy = flag.String("copy", "", "copy this binary to specified destination path")
getStatus = flag.Bool("get-status", false, "Get current status")
shutdown = flag.Bool("shutdown", false, "Shutdown the driver")
outputPath = flag.String("output-path", OutputPath, "output path")
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
driverLog = ctrl.Log.WithName("driver")
)

func init() {
Expand All @@ -83,7 +80,7 @@ func main() {

if *copy != "" {
// copy exec to destination
if err := CopyExec(*copy); err != nil {
if err := copyExec(*copy); err != nil {
driverLog.Error(err, "failed to copy binary")
os.Exit(1)
return
Expand All @@ -92,29 +89,34 @@ func main() {
return
}

if *getStatus {
getStatusOrDie(ctx)
return
}

if *shutdown {
shutdownOrDie(ctx)
return
}

if len(args) == 0 {
driverLog.Error(fmt.Errorf("no user program"), "empty args")
os.Exit(1)
}

driverOpt := driver.DriverOption{
Context: ctx,
JobNamespace: *jobNameSpace,
JobName: *jobName,
OutputPath: *outputPath,
GrpcService: *grpcService,
GrpcPort: *grpcPort,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
CustomCards: customCards,
Args: args,
ReportInterval: *reportInterval,
Context: ctx,
OutputPath: *outputPath,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
CustomCards: customCards,
Args: args,
}

driver, err := driver.NewDriver(&driverOpt)
if err != nil {
driverLog.Error(err, "Driver.Run failed")
driverLog.Error(err, "Driver.NewDriver failed")
os.Exit(1)
}

Expand All @@ -123,11 +125,10 @@ func main() {
driverLog.Error(err, "Driver.Run failed")
exitCode = 1
}
driver.Cleanup()
os.Exit(exitCode)
}

func CopyExec(destination string) (err error) {
func copyExec(destination string) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("copy this binary to %s: %w", destination, err)
Expand Down Expand Up @@ -161,3 +162,53 @@ func findThisBinary() (string, error) {
}
return bin, nil
}

// getStatusOrDie handles the --get-status flag: it builds a driver from the
// parsed command-line flags, asks it for the current status, and writes that
// status to stdout as JSON.
//
// The function never returns. The process exits with code 0 once the status
// has been written, or with code 1 (after logging the error) if the driver
// cannot be created, the status cannot be retrieved or serialized, or the
// write to stdout fails.
func getStatusOrDie(ctx context.Context) {
	// Named d rather than driver so the imported driver package is not shadowed.
	d, err := driver.NewDriver(&driver.DriverOption{
		Context:      ctx,
		OutputPath:   *outputPath,
		DetectDevice: *detectDevice,
		Logger:       driverLog,
	})
	if err != nil {
		driverLog.Error(err, "failed to initialize the driver")
		os.Exit(1)
	}

	status, err := d.GetStatus()
	if err != nil {
		// The error is already attached by Error's first argument; repeating it
		// as an "error" key/value would duplicate the field in the log entry.
		driverLog.Error(err, "failed to get status")
		os.Exit(1)
	}

	b, err := json.Marshal(status)
	if err != nil {
		driverLog.Error(err, "json serialization failed")
		os.Exit(1)
	}

	// The caller reads the status from stdout, so a failed write must not be
	// reported as success.
	if _, err := os.Stdout.Write(b); err != nil {
		driverLog.Error(err, "failed to write status to stdout")
		os.Exit(1)
	}
	os.Exit(0)
}

// shutdownOrDie handles the --shutdown flag: it builds a driver from the
// parsed command-line flags and calls its Shutdown method.
//
// The function never returns. The process exits with code 0 when Shutdown
// succeeds, or with code 1 (after logging the error) if the driver cannot be
// created or Shutdown reports an error.
func shutdownOrDie(ctx context.Context) {
	// Named d rather than driver so the imported driver package is not shadowed.
	d, err := driver.NewDriver(&driver.DriverOption{
		Context:      ctx,
		OutputPath:   *outputPath,
		DetectDevice: *detectDevice,
		Logger:       driverLog,
	})
	if err != nil {
		driverLog.Error(err, "failed to initialize the driver")
		os.Exit(1)
	}

	if err := d.Shutdown(); err != nil {
		// The error is already attached by Error's first argument; repeating it
		// as an "error" key/value would duplicate the field in the log entry.
		driverLog.Error(err, "failed to shutdown")
		os.Exit(1)
	}
	os.Exit(0)
}
29 changes: 6 additions & 23 deletions cmd/lmes_driver/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"flag"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/driver"
Expand All @@ -31,13 +30,8 @@ import (
func Test_ArgParsing(t *testing.T) {
os.Args = []string{
"/opt/app-root/src/bin/driver",
"--job-namespace", "default",
"--job-name", "test",
"--grpc-service", "grpc-service.test.svc",
"--grpc-port", "8088",
"--output-path", "/opt/app-root/src/output",
"--detect-device",
"--report-interval", "10s",
"--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
"--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
"--",
Expand All @@ -53,30 +47,19 @@ func Test_ArgParsing(t *testing.T) {

args := flag.Args()

assert.Equal(t, "default", *jobNameSpace)
assert.Equal(t, "test", *jobName)
assert.Equal(t, "grpc-service.test.svc", *grpcService)
assert.Equal(t, 8088, *grpcPort)
assert.Equal(t, "/opt/app-root/src/output", *outputPath)
assert.Equal(t, true, *detectDevice)
assert.Equal(t, time.Second*10, *reportInterval)
assert.Equal(t, strArrayArg{
"card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
"card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
}, taskRecipes)

dOption := driver.DriverOption{
Context: context.Background(),
JobNamespace: *jobNameSpace,
JobName: *jobName,
OutputPath: *outputPath,
GrpcService: *grpcService,
GrpcPort: *grpcPort,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
Args: args,
ReportInterval: *reportInterval,
Context: context.Background(),
OutputPath: *outputPath,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
Args: args,
}

assert.Equal(t, []string{
Expand Down
3 changes: 1 addition & 2 deletions config/base/params.env
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ lmes-driver-image=quay.io/trustyai/ta-lmes-driver:latest
lmes-pod-image=quay.io/trustyai/ta-lmes-job:latest
lmes-pod-checking-interval=10s
lmes-image-pull-policy=Always
lmes-grpc-service=lmes-grpc
lmes-grpc-port=8082
lmes-max-batch-size=24
lmes-default-batch-size=8
lmes-detect-device=true
1 change: 0 additions & 1 deletion config/manager/kustomization.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
resources:
- manager.yaml
- service.yaml
apiVersion: kustomize.config.k8s.io/v1beta1
kind: Kustomization
14 changes: 0 additions & 14 deletions config/manager/service.yaml

This file was deleted.

14 changes: 0 additions & 14 deletions config/overlays/lmes/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,5 @@ kind: Kustomization
resources:
- ../../base

replacements:
- source:
kind: Service
version: v1
name: lmes-grpc
fieldPath: .metadata.name
targets:
- select:
kind: ConfigMap
version: v1
name: config
fieldPaths:
- .data.lmes-grpc-service

patchesStrategicMerge:
- lmes-only-patch.yaml
14 changes: 0 additions & 14 deletions config/overlays/odh/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,3 @@ apiVersion: kustomize.config.k8s.io/v1beta1
kind: Kustomization
resources:
- ../../base

replacements:
- source:
kind: Service
version: v1
name: lmes-grpc
fieldPath: .metadata.name
targets:
- select:
kind: ConfigMap
version: v1
name: config
fieldPaths:
- .data.lmes-grpc-service
10 changes: 10 additions & 0 deletions config/rbac/role.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ rules:
- patch
- update
- watch
- apiGroups:
- ""
resources:
- pods/exec
verbs:
- create
- delete
- get
- list
- watch
- apiGroups:
- ""
resources:
Expand Down
Loading

0 comments on commit ab6bc98

Please sign in to comment.