Skip to content

Commit

Permalink
TestLowCpuMemUsage UT get device by device_name (#6397)
Browse files Browse the repository at this point in the history
Co-authored-by: Shaik Raza Sikander <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
3 people authored Aug 29, 2024
1 parent a7ffe54 commit 89c4d9f
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tests/unit/inference/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,12 @@ def verify_injection(module):
verify_injection(model)


# Used to Get Device name
def getDeviceId(local_rank):
device = torch.device(f"{get_accelerator().device_name(local_rank)}")
return device


# Verify that test is valid
def validate_test(model_w_task, dtype, enable_cuda_graph, enable_triton):
model, task = model_w_task
Expand Down Expand Up @@ -484,8 +490,8 @@ def test(
pytest.skip(f"Acceleraor {get_accelerator().device_name()} does not support {dtype}.")

local_rank = int(os.getenv("LOCAL_RANK", "0"))

pipe = pipeline(task, model=model, model_kwargs={"low_cpu_mem_usage": True}, device=local_rank, framework="pt")
device = getDeviceId(local_rank)
pipe = pipeline(task, model=model, model_kwargs={"low_cpu_mem_usage": True}, device=device, framework="pt")
bs_output = pipe(query, **inf_kwargs)
pipe.model = deepspeed.init_inference(pipe.model,
mp_size=self.world_size,
Expand Down

0 comments on commit 89c4d9f

Please sign in to comment.