From 1748f9c9e11f05f543dc835370ee7b7cb9c3a7e4 Mon Sep 17 00:00:00 2001 From: "Rabbani, Golam" Date: Fri, 22 Nov 2024 20:02:11 +0000 Subject: [PATCH] replace cuda with device for xpu support --- test/transformers/test_cross_entropy.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index eb052ab5f..6e1dc277b 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -736,12 +736,12 @@ def test_float32_internal(): reduction = "mean" # Initialize input tensors - X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device="cuda") - Y = torch.randint(0, n_cols, (batch_size,), device="cuda") + X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device=device) + Y = torch.randint(0, n_cols, (batch_size,), device=device) # Run kernel for bfloat16 X_bf16 = X_init.clone() - loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device="cuda") + loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device=device) liger_cross_entropy_kernel[(batch_size,)]( X_ptr=X_bf16, X_stride=X_bf16.stride(-2), @@ -765,7 +765,7 @@ def test_float32_internal(): # Run kernel for float32 X_fp32 = X_init.float() - loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device="cuda") + loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device=device) liger_cross_entropy_kernel[(batch_size,)]( X_ptr=X_fp32, X_stride=X_fp32.stride(-2),