Skip to content

Commit

Permalink
feat: add device detection in lmes driver
Browse files Browse the repository at this point in the history
Added a new feature in LMES driver to detect the available
devices by using the PyTorch API. This feature can be disabled
by passing the `--detect-device=false` option.

Signed-off-by: Yihong Wang <[email protected]>
  • Loading branch information
yhwang committed Sep 16, 2024
1 parent db7ae08 commit 3ed6af0
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 5 deletions.
2 changes: 2 additions & 0 deletions cmd/lmes_driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ var (
grpcService = flag.String("grpc-service", "", "grpc service name")
grpcPort = flag.Int("grpc-port", 8082, "grpc port")
outputPath = flag.String("output-path", OutputPath, "output path")
detectDevice = flag.Bool("detect-device", true, "detect available device(s), CUDA or CPU")
reportInterval = flag.Duration("report-interval", time.Second*10, "specify the druation interval to report the progress")
driverLog = ctrl.Log.WithName("driver")
)
Expand Down Expand Up @@ -83,6 +84,7 @@ func main() {
OutputPath: *outputPath,
GrpcService: *grpcService,
GrpcPort: *grpcPort,
DetectDevice: *detectDevice,
Logger: driverLog,
Args: args,
ReportInterval: *reportInterval,
Expand Down
2 changes: 2 additions & 0 deletions controllers/lmes/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const (
GrpcClientSecretKey = "lmes-grpc-client-secret"
MaxBatchSizeKey = "lmes-max-batch-size"
DefaultBatchSizeKey = "lmes-default-batch-size"
DetectDeviceKey = "lmes-detect-device"
DriverReportIntervalKey = "driver-report-interval"
GrpcServerCertEnv = "GRPC_SERVER_CERT"
GrpcServerKeyEnv = "GRPC_SERVER_KEY"
Expand All @@ -51,5 +52,6 @@ const (
DefaultGrpcClientSecret = "grpc-client-cert"
DefaultMaxBatchSize = 24
DefaultBatchSize = 8
DefaultDetectDevice = true
ServiceName = "LMES"
)
53 changes: 53 additions & 0 deletions controllers/lmes/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type DriverOption struct {
GrpcService string
GrpcPort int
OutputPath string
DetectDevice bool
Logger logr.Logger
Args []string
ReportInterval time.Duration
Expand Down Expand Up @@ -200,7 +201,59 @@ func getGRPCClientConn(option *DriverOption) (clientConn *grpc.ClientConn, err e
return
}

// detectDevice probes for CUDA support through the PyTorch Python API and
// patches the lm_eval command held in d.Option.Args with a matching
// `--device` option (see patchDevice).
//
// It is a no-op returning nil when the receiver is nil or when
// d.Option.DetectDevice is false. It returns an error when the python probe
// cannot be run or exits non-zero, or when its output does not contain the
// expected `==<True|False>:<count>==` marker line.
func (d *driverImpl) detectDevice() error {
	if d == nil || !d.Option.DetectDevice {
		return nil
	}

	// assuming python and torch python package are available.
	// use torch python API to detect CUDA's availability
	out, err := exec.Command(
		"python",
		"-c",
		"import torch; print('=={}:{}=='.format(torch.cuda.is_available(), torch.cuda.device_count()));",
	).Output()
	if err != nil {
		// Wrap with %w so the underlying exec error stays reachable through
		// errors.Is / errors.As; the rendered message is unchanged.
		return fmt.Errorf("failed to detect available device(s): %w", err)
	}

	// The marker is matched per line (?m) so unrelated lines printed to
	// stdout do not break the detection. The device count is captured but
	// currently unused; only the availability flag drives the patching.
	re := regexp.MustCompile(`(?m)^==(True|False):(\d+?)==$`)
	matches := re.FindStringSubmatch(string(out))
	if matches == nil {
		return fmt.Errorf("failed to find the matched output")
	}

	patchDevice(d.Option.Args, matches[1] == "True")

	return nil
}

// patchDevice appends a `--device cuda` (hasCuda) or `--device cpu` option to
// the lm_eval command found in args. The slice is modified in place.
//
// Only the first element starting with `python -m lm_eval` is considered
// (usually the last element of the driver arguments). If that element
// already contains `--device`, it is left untouched.
func patchDevice(args []string, hasCuda bool) {
	target := "cpu"
	if hasCuda {
		target = "cuda"
	}

	for i := range args {
		cmd := args[i]
		if !strings.HasPrefix(cmd, "python -m lm_eval") {
			continue
		}
		// first lm_eval command found: patch it unless a device is already
		// specified, then stop looking
		if strings.Contains(cmd, "--device") {
			return
		}
		args[i] = fmt.Sprintf("%s --device %s", cmd, target)
		return
	}
}

func (d *driverImpl) exec() error {

// Detect available devices if needed
if err := d.detectDevice(); err != nil {
return err
}

fmt.Printf("%q\n", d.Option.Args)

// Run user program.
var args []string
if len(d.Option.Args) > 1 {
Expand Down
73 changes: 73 additions & 0 deletions controllers/lmes/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,76 @@ func Test_ProgressUpdate(t *testing.T) {
assert.Nil(t, os.Remove("./stderr.log"))
assert.Nil(t, os.Remove("./stdout.log"))
}

func Test_DetectDeviceError(t *testing.T) {
server := grpc.NewServer()
progresssServer := ProgressUpdateServer{}
v1beta1.RegisterLMEvalJobUpdateServiceServer(server, &progresssServer)
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", 8082))
assert.Nil(t, err)
go server.Serve(lis)

driver, err := NewDriver(&DriverOption{
Context: context.Background(),
JobNamespace: "fms-lm-eval-service-system",
JobName: "evaljob-sample",
GrpcService: "localhost",
GrpcPort: 8082,
OutputPath: ".",
DetectDevice: true,
Logger: driverLog,
Args: []string{"sh", "-ec", "python -m lm_eval --output_path ./output --model test --model_args arg1=value1 --tasks task1,task2"},
ReportInterval: time.Second * 5,
})
assert.Nil(t, err)

assert.Nil(t, driver.Run())
assert.Equal(t, []string{
"update status from the driver: running",
"failed to detect available device(s): exit status 1",
}, progresssServer.progressMsgs)

server.Stop()

// the following files don't exist for this case
assert.NotNil(t, os.Remove("./stderr.log"))
assert.NotNil(t, os.Remove("./stdout.log"))
}

// Test_PatchDevice checks that patchDevice appends `--device cuda` or
// `--device cpu` to the lm_eval command, and leaves the command untouched
// when a `--device` option is already present.
func Test_PatchDevice(t *testing.T) {
	const (
		plainCmd = "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2"
		cpuCmd   = "python -m lm_eval --device cpu --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2"
	)

	driverOpt := DriverOption{
		Context:        context.Background(),
		JobNamespace:   "fms-lm-eval-service-system",
		JobName:        "evaljob-sample",
		GrpcService:    "localhost",
		GrpcPort:       8082,
		OutputPath:     ".",
		DetectDevice:   true,
		Logger:         driverLog,
		ReportInterval: time.Second * 5,
	}

	scenarios := []struct {
		name    string
		cmd     string
		hasCuda bool
		want    string
	}{
		{name: "append `--device cuda`", cmd: plainCmd, hasCuda: true, want: plainCmd + " --device cuda"},
		{name: "append `--device cpu`", cmd: plainCmd, hasCuda: false, want: plainCmd + " --device cpu"},
		{name: "no change because `--device cpu` exists", cmd: cpuCmd, hasCuda: true, want: cpuCmd},
	}

	for _, sc := range scenarios {
		// fresh argument slice per scenario: patchDevice modifies it in place
		driverOpt.Args = []string{"sh", "-ec", sc.cmd}
		patchDevice(driverOpt.Args, sc.hasCuda)
		assert.Equal(t, sc.want, driverOpt.Args[2], sc.name)
	}
}
7 changes: 5 additions & 2 deletions controllers/lmes/lmevaljob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ var (
"DriverReportInterval": DriverReportIntervalKey,
"DefaultBatchSize": DefaultBatchSizeKey,
"MaxBatchSize": MaxBatchSizeKey,
"DetectDevice": DetectDeviceKey,
}
)

Expand Down Expand Up @@ -101,6 +102,7 @@ type ServiceOptions struct {
GrpcClientSecret string
MaxBatchSize int
DefaultBatchSize int
DetectDevice bool
grpcTLSMode TLSMode
}

Expand Down Expand Up @@ -303,6 +305,7 @@ func (r *LMEvalJobReconciler) constructOptionsFromConfigMap(
GrpcServerSecret: DefaultGrpcServerSecret,
GrpcClientSecret: DefaultGrpcClientSecret,
MaxBatchSize: DefaultMaxBatchSize,
DetectDevice: DefaultDetectDevice,
DefaultBatchSize: DefaultBatchSize,
}

Expand Down Expand Up @@ -679,8 +682,7 @@ func (r *LMEvalJobReconciler) generateArgs(job *lmesv1alpha1.LMEvalJob, log logr
}

cmds := make([]string, 0, 10)
// FIXME: use CPU for now
cmds = append(cmds, "python", "-m", "lm_eval", "--output_path", "/opt/app-root/src/output", "--device", "cpu")
cmds = append(cmds, "python", "-m", "lm_eval", "--output_path", "/opt/app-root/src/output")
// --model
cmds = append(cmds, "--model", job.Spec.Model)
// --model_args
Expand Down Expand Up @@ -732,6 +734,7 @@ func (r *LMEvalJobReconciler) generateCmd(job *lmesv1alpha1.LMEvalJob) []string
"--grpc-service", fmt.Sprintf("%s.%s.svc", r.options.GrpcService, r.Namespace),
"--grpc-port", strconv.Itoa(r.options.GrpcPort),
"--output-path", "/opt/app-root/src/output",
"--detect-device", fmt.Sprintf("%t", r.options.DetectDevice),
"--report-interval", r.options.DriverReportInterval.String(),
"--",
}
Expand Down
6 changes: 3 additions & 3 deletions controllers/lmes/lmevaljob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,22 +760,22 @@ func Test_GenerateArgBatchSize(t *testing.T) {
// no batchSize in the job, use default batchSize
assert.Equal(t, []string{
"sh", "-ec",
"python -m lm_eval --output_path /opt/app-root/src/output --device cpu --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 8",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 8",
}, lmevalRec.generateArgs(job, log))

// exceed the max-batch-size, use max-batch-size
var biggerBatchSize = 30
job.Spec.BatchSize = &biggerBatchSize
assert.Equal(t, []string{
"sh", "-ec",
"python -m lm_eval --output_path /opt/app-root/src/output --device cpu --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 24",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 24",
}, lmevalRec.generateArgs(job, log))

// normal batchSize
var normalBatchSize = 16
job.Spec.BatchSize = &normalBatchSize
assert.Equal(t, []string{
"sh", "-ec",
"python -m lm_eval --output_path /opt/app-root/src/output --device cpu --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 16",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 16",
}, lmevalRec.generateArgs(job, log))
}

0 comments on commit 3ed6af0

Please sign in to comment.