Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adapt CI tests to use compiled_rmsnorm #451

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 20 additions & 23 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
"--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious what is the reason that compiled RMSNorm is incompatible with PP

],
],
"PP 1D test 1f1b",
Expand All @@ -85,7 +85,7 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
"--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP
],
],
"PP 1D test gpipe",
Expand All @@ -101,7 +101,7 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
"--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP
],
],
"PP+DP 1f1b 2D test",
Expand All @@ -116,7 +116,7 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
"--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP
],
],
"PP+DP gpipe 2D test",
Expand All @@ -130,7 +130,6 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
],
],
"PP+TP 2D test",
Expand All @@ -144,7 +143,6 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_split_mode tracer",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer
],
],
"PP tracer frontend test",
Expand All @@ -162,7 +160,16 @@ def build_test_list():
OverrideDefinitions(
[
[
"--training.compile --model.norm_type=rmsnorm",
"--training.tensor_parallel_degree 2",
],
],
"2D eager",
"2d_eager",
),
OverrideDefinitions(
[
[
"--training.compile",
],
],
"1D compile",
Expand All @@ -182,29 +189,20 @@ def build_test_list():
OverrideDefinitions(
[
[
"--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
"--training.compile --training.tensor_parallel_degree 2",
],
],
"2D compile",
"2d_compile",
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
],
],
"Eager mode 2DParallel with rmsnorm",
"eager_2d_rmsnorm",
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 2 --model.norm_type=fused_rmsnorm",
],
],
"Eager mode 2DParallel with fused_rmsnorm",
"eager_2d_fused_rmsnorm",
"2D eager with fused_rmsnorm",
"2d_eager_fused_rmsnorm",
),
OverrideDefinitions(
[
Expand Down Expand Up @@ -248,7 +246,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
],
[
"--training.steps 20",
Expand All @@ -257,7 +254,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
],
],
"PP+DP+TP 3D test with save/load resume ckpt",
Expand All @@ -272,7 +268,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule interleaved_1f1b",
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
"--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP
],
],
"PP looped 1f1b test",
Expand All @@ -292,7 +288,8 @@ def build_test_list():
OverrideDefinitions(
[
[
"--memory_estimation.enabled --model.norm_type rmsnorm",
"--memory_estimation.enabled",
"--model.norm_type rmsnorm", # estimation mode does not support compiled_rmsnorm yet
]
],
"FSDP2 Memory Tracking and Estimation",
Expand Down
Loading