From a612c7677aa1073839e1b47a663f41ecc55fa88e Mon Sep 17 00:00:00 2001 From: Daniel Levi-Minzi <51272568+dleviminzi@users.noreply.github.com> Date: Wed, 26 Feb 2025 12:32:35 -0500 Subject: [PATCH] Fix: Correctly format the PCI path when checking device existence (#988) There was some incorrect logic around the parsing of output from `nvidia-smi`. Refer to issue #984 for more context. --- pkg/worker/gpu_info.go | 15 ++++++++------- pkg/worker/gpu_info_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/pkg/worker/gpu_info.go b/pkg/worker/gpu_info.go index 5c6a63668..cf9c7dd28 100644 --- a/pkg/worker/gpu_info.go +++ b/pkg/worker/gpu_info.go @@ -84,15 +84,16 @@ func (c *NvidiaInfoClient) AvailableGPUDevices() ([]int, error) { continue } - // PCI bus_id is shown to be "domain:bus:device.function", but the folder in /proc/driver/nvidia/gpus is just "bus:device.function" - busId := strings.ToLower( - strings.TrimPrefix( - strings.TrimSpace(parts[1]), domain, - ), - ) + smiBusIdParts := strings.Split(parts[1], ":") + if len(smiBusIdParts) != 3 { + return nil, fmt.Errorf("unexpected bus id format from nvidia-smi: %s", line) + } + + // The bus id from nvidia-smi comes as xxxxxxxx:xx:xx.x so convert it to the format xxxx:xx:xx.x + systemBusId := strings.Join([]string{domain, smiBusIdParts[1], smiBusIdParts[2]}, ":") gpuIndex := strings.TrimSpace(parts[2]) - if exists, err := checkGPUExists(busId); err == nil && exists { + if exists, err := checkGPUExists(systemBusId); err == nil && exists { index, err := strconv.Atoi(strings.TrimSpace(gpuIndex)) if err != nil { return nil, err diff --git a/pkg/worker/gpu_info_test.go b/pkg/worker/gpu_info_test.go index cb1520ac0..fa8f25c4e 100644 --- a/pkg/worker/gpu_info_test.go +++ b/pkg/worker/gpu_info_test.go @@ -2,6 +2,7 @@ package worker import ( "os" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -58,6 +59,7 @@ func TestAvailableGPUDevicesAllVisibleDevices(t *testing.T) { } checkGPUExists = func(busId string) (bool, error) { + // check format matches xxxx:xx:xx.x return true, nil } @@ -68,3 +70,34 @@ func TestAvailableGPUDevicesAllVisibleDevices(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []int{0, 1, 2, 3, 4, 5, 6, 7}, devices) } + +func TestAvailableGPUDevicesWithNonZeroPCIDomain(t *testing.T) { + originalQueryDevices := queryDevices + defer func() { queryDevices = originalQueryDevices }() + + originalCheckGPUExists := checkGPUExists + defer func() { checkGPUExists = originalCheckGPUExists }() + + queryDevices = func() ([]byte, error) { + mockOutput := `0x0001, 00000001:23:00.0, 0, GPU-afb8c77a-62ef-a631-48d0-edc9670fef25` + return []byte(mockOutput), nil + } + + checkGPUExists = func(busId string) (bool, error) { + // check format matches xxxx:xx:xx.x + parts := strings.Split(busId, ":") + assert.Equal(t, 3, len(parts)) + assert.Equal(t, 4, len(parts[0])) + assert.Equal(t, 2, len(parts[1])) + assert.Equal(t, 4, len(parts[2])) + assert.Contains(t, parts[2], ".") + return true, nil + } + + client := &NvidiaInfoClient{} + os.Setenv("NVIDIA_VISIBLE_DEVICES", "all") + + devices, err := client.AvailableGPUDevices() + assert.NoError(t, err) + assert.Equal(t, []int{0}, devices) +}