diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 45e26c9e3..293528b2c 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -162,11 +162,10 @@ func pullImage(ctx context.Context, client docker.APIClient, taskParams DockerIm } func createContainer(ctx context.Context, client docker.APIClient, dockerParams DockerParameters, taskParams DockerImageConfig) (string, error) { - runtime, err := getRuntime(ctx, client) + gpuRequest, err := requestGpuIfAvailable(ctx, client) if err != nil { return "", tracerr.Wrap(err) } - mounts, err := dockerParams.DockerMounts() if err != nil { return "", tracerr.Wrap(err) @@ -183,8 +182,10 @@ func createContainer(ctx context.Context, client docker.APIClient, dockerParams PortBindings: bindPorts(dockerParams.DockerPorts()...), PublishAllPorts: true, Sysctls: map[string]string{}, - Runtime: runtime, - Mounts: mounts, + Resources: container.Resources{ + DeviceRequests: gpuRequest, + }, + Mounts: mounts, } resp, err := client.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, "") if err != nil { @@ -260,17 +261,21 @@ func getNetworkMode() container.NetworkMode { return "default" } -func getRuntime(ctx context.Context, client docker.APIClient) (string, error) { +func requestGpuIfAvailable(ctx context.Context, client docker.APIClient) ([]container.DeviceRequest, error) { info, err := client.Info(ctx) if err != nil { - return "", tracerr.Wrap(err) + return nil, tracerr.Wrap(err) } - for name := range info.Runtimes { - if name == consts.NVIDIA_RUNTIME { - return name, nil + + for runtime := range info.Runtimes { + if runtime == consts.NVIDIA_RUNTIME { + return []container.DeviceRequest{ + {Capabilities: [][]string{{"gpu"}}, Count: -1}, // --gpus=all + }, nil } } - return info.DefaultRuntime, nil + + return nil, nil } /* DockerParameters interface implementation for CLIArgs */