
Commit

Merge branch 'main' into tyler/compile-liger
tyler-romero authored Aug 30, 2024
2 parents e4ce94f + cbc4f85 commit b4fc333
Showing 8 changed files with 342 additions and 27 deletions.
113 changes: 106 additions & 7 deletions test/convergence/test_mini_models.py
@@ -6,6 +6,7 @@
assert_verbose_allclose,
set_seed,
simple_collate_fn,
supports_bfloat16,
)

import pytest
@@ -349,23 +350,121 @@ def run_mini_model(
[
        # Gemma 1.1 and 2 have more tolerance because the kernel is currently not a perfect match (casts are not done the same way)
("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 6e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_gemma1",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1.1", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_gemma1.1",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_gemma2",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_llama3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_llama3",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
        # TODO: torch 2.5.0 nightly breaks the mixtral test, but torch 2.3.0 works fine
        # TODO: Mixtral's MoE structure makes convergence flaky, so the test is disabled for now. It needs a high tolerance to pass.
# ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 8e-3, 1e-5),
# ("mini_mixtral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 2.0, 1e-5, 1e-2, 1e-5),
("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_mistral",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_qwen2",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_phi3",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
],
)
def test_mini_model(
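The change above applies one pattern throughout this file: each plain bfloat16 parametrize tuple becomes a pytest.param(...) entry so that a skipif mark can be attached to that single case rather than to the whole test function, while the float32 cases keep running everywhere. A minimal sketch of the pattern in isolation (the test name, parameters, and tolerances below are illustrative, not taken from the repository):

import pytest
import torch

from test.utils import supports_bfloat16


@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-8, 1e-6),  # always collected and run
        pytest.param(
            torch.bfloat16,
            1e-8,
            5e-2,
            # this mark applies only to this one parametrization
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
    ],
)
def test_dtype_case(dtype, atol, rtol):
    assert dtype in (torch.float32, torch.bfloat16)

On hardware without bfloat16 support, pytest then reports the bfloat16 cases as skipped instead of failing them.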
113 changes: 106 additions & 7 deletions test/convergence/test_mini_models_no_logits.py
@@ -4,6 +4,7 @@
assert_verbose_allclose,
set_seed,
simple_collate_fn,
supports_bfloat16,
)

import pytest
@@ -291,20 +292,118 @@ def run_mini_model(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
[
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
("mini_llama3", 32, 1e-4, torch.bfloat16, 5e-3, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_llama3",
32,
1e-4,
torch.bfloat16,
5e-3,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_qwen2",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_phi3",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_mistral",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
        # Gemma 1.1 and 2 have more tolerance because the kernel is currently not a perfect match (casts are not done the same way)
("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_gemma1",
32,
1e-4,
torch.bfloat16,
1e-2,
1e-4,
2e-1,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1.1", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_gemma1.1",
32,
1e-4,
torch.bfloat16,
1e-2,
1e-4,
2e-1,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma2", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_gemma2",
32,
1e-4,
torch.bfloat16,
1e-2,
1e-4,
2e-1,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
],
)
def test_mini_model(
84 changes: 77 additions & 7 deletions test/transformers/test_cross_entropy.py
@@ -1,3 +1,5 @@
from test.utils import supports_bfloat16

import pytest
import torch
from torch.nn import CrossEntropyLoss
@@ -99,14 +101,42 @@ def _test_correctness_not_last_layer_once(
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
(0.1, torch.bfloat16, 1e-8, 5e-2),
(1.0, torch.bfloat16, 1e-8, 5e-2),
(10.0, torch.bfloat16, 1e-7, 5e-2),
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-7,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness(B, T, V, scalar, dtype, atol, rtol):
liger_ce = LigerCrossEntropyLoss()
_test_correctness_once(liger_ce, B, T, V, scalar, dtype, atol, rtol)
@@ -125,14 +155,42 @@ def test_correctness(B, T, V, scalar, dtype, atol, rtol):
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
(0.1, torch.bfloat16, 1e-8, 5e-2),
(1.0, torch.bfloat16, 1e-8, 5e-2),
(10.0, torch.bfloat16, 1e-8, 5e-2),
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_with_ignore_index(
B, T, V, ignore_index, scalar, dtype, atol, rtol
):
@@ -155,10 +213,22 @@ def test_correctness_with_ignore_index(
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
(1.0, torch.bfloat16, 1e-8, 5e-2),
pytest.param(
1.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(1.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_not_last_layer(B, T, V, scalar, dtype, atol, rtol):
liger_ce = LigerCrossEntropyLoss()
_test_correctness_not_last_layer_once(liger_ce, B, T, V, scalar, dtype, atol, rtol)
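All three test files now import supports_bfloat16 from test.utils; its definition lives in one of the changed files not shown here. A plausible implementation, sketched below as an assumption rather than the repository's actual code, only needs to ask whether the current CUDA device can do bfloat16 math (roughly, an Ampere-or-newer GPU):

import torch


def supports_bfloat16() -> bool:
    # Hypothetical sketch: bfloat16 CUDA kernels generally require
    # compute capability >= 8.0 (Ampere or newer).
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability()[0] >= 8

With such a guard, the skipif marks are evaluated at collection time, so on pre-Ampere GPUs (e.g., V100 or T4) the bfloat16 cases show up as skipped rather than as failures.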
