
Commit

temporary fix test
tjtanaa committed Dec 6, 2024
1 parent 07a3a62 commit 96e9fe1
Showing 2 changed files with 3 additions and 9 deletions.
8 changes: 1 addition & 7 deletions .github/workflows/amd-ci.yml
```diff
@@ -58,16 +58,10 @@ jobs:
 
       - name: Setup Dependencies
         run: |
-          python3 -m pip uninstall -y torch torchvision
-          python3 -m pip install --pre \
-            torch==2.6.0.dev20241113+rocm6.2 \
-            'setuptools-scm>=8' \
-            torchvision==0.20.0.dev20241113+rocm6.2 \
-            --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
           python3 -m pip install triton==3.1.0 transformers==4.46.3
           python3 -m pip install -e .[dev]
-      - name: Run Tests
+      - name: Run Unit Tests
         run: |
           make test
           make test-convergence
 
```
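Read at face value, the change drops the torch/torchvision nightly reinstall and leans on whatever ROCm build of PyTorch the runner image already provides. Below is a minimal sketch of how the affected steps would read after this commit, under that assumption; the trigger, job name, runner label, and checkout step are placeholders, not taken from the real amd-ci.yml.

```yaml
# Sketch only: trigger, job name, and runner label are placeholders.
name: amd-ci
on: [push]
jobs:
  unit-tests:
    runs-on: [self-hosted]
    steps:
      - uses: actions/checkout@v4
      - name: Setup Dependencies
        run: |
          # torch/torchvision are assumed to come preinstalled in the
          # ROCm runner image, so only the extra pins are installed here.
          python3 -m pip install triton==3.1.0 transformers==4.46.3
          python3 -m pip install -e .[dev]
      - name: Run Unit Tests
        run: |
          make test
          make test-convergence
```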
4 changes: 2 additions & 2 deletions test/transformers/test_cross_entropy.py
```diff
@@ -763,7 +763,7 @@ def test_float32_internal():
         RETURN_Z_LOSS=0,  # False
         HAS_SOFTCAPPING=False,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32,
+        num_warps=32 if device=='cuda' else 16,
     )
 
     # Run kernel for float32
@@ -787,7 +787,7 @@ def test_float32_internal():
         RETURN_Z_LOSS=0,  # False
         HAS_SOFTCAPPING=False,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32,
+        num_warps=32 if device=='cuda' else 16,
     )
 
     torch.allclose(X_bf16, X_fp32.bfloat16())
```
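The only functional change here is halving num_warps off CUDA. The arithmetic behind it: Triton launches each block with num_warps times the hardware warp size, a CUDA warp is 32 threads, and an AMD wavefront is 64 lanes, so num_warps=32 on ROCm would request 2048 threads per block, past the usual 1024-thread cap, while 16 wavefronts lands back on exactly 1024. A hedged sketch of that reasoning, with pick_num_warps as a hypothetical helper that is not part of the commit:

```python
# Hypothetical helper (not in the commit) showing why num_warps drops from
# 32 to 16 when the device is not CUDA: Triton sizes a block as
# num_warps * warp_size, and AMD wavefronts are 64 lanes wide, twice the
# 32-thread CUDA warp.
def pick_num_warps(device: str, max_threads_per_block: int = 1024) -> int:
    warp_size = 32 if device == "cuda" else 64  # CUDA warp vs. AMD wavefront
    num_warps = 32 if device == "cuda" else 16
    # Both choices fill exactly one maximal thread block:
    # 32 * 32 = 1024 on CUDA, 16 * 64 = 1024 on ROCm.
    assert num_warps * warp_size <= max_threads_per_block
    return num_warps

print(pick_num_warps("cuda"))  # 32
print(pick_num_warps("hip"))   # 16
```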
