Skip to content

Commit

Permalink
Replace hard-coded `device="cuda"` with the `device` variable to support XPU
Browse files Browse the repository at this point in the history
  • Loading branch information
mgrabban committed Nov 22, 2024
1 parent 90d7113 commit 1748f9c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions test/transformers/test_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,12 +736,12 @@ def test_float32_internal():
reduction = "mean"

# Initialize input tensors
X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device="cuda")
Y = torch.randint(0, n_cols, (batch_size,), device="cuda")
X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device=device)
Y = torch.randint(0, n_cols, (batch_size,), device=device)

# Run kernel for bfloat16
X_bf16 = X_init.clone()
loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device="cuda")
loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device=device)
liger_cross_entropy_kernel[(batch_size,)](
X_ptr=X_bf16,
X_stride=X_bf16.stride(-2),
Expand All @@ -765,7 +765,7 @@ def test_float32_internal():

# Run kernel for float32
X_fp32 = X_init.float()
loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device="cuda")
loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device=device)
liger_cross_entropy_kernel[(batch_size,)](
X_ptr=X_fp32,
X_stride=X_fp32.stride(-2),
Expand Down

0 comments on commit 1748f9c

Please sign in to comment.