diff --git a/test_runner.py b/test_runner.py index a7c95ce1..94de464a 100755 --- a/test_runner.py +++ b/test_runner.py @@ -69,7 +69,7 @@ def build_test_list(): "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule 1f1b", "--training.data_parallel_degree 1", - "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP + "--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP ], ], "PP 1D test 1f1b", @@ -85,7 +85,7 @@ def build_test_list(): "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule gpipe", "--training.data_parallel_degree 1", - "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP + "--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP ], ], "PP 1D test gpipe", @@ -101,7 +101,7 @@ def build_test_list(): "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule 1f1b", "--training.data_parallel_degree 2", - "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP + "--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP ], ], "PP+DP 1f1b 2D test", @@ -116,7 +116,7 @@ def build_test_list(): "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule gpipe", "--training.data_parallel_degree 2", - "--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP + "--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP ], ], "PP+DP gpipe 2D test", @@ -130,7 +130,6 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", "--training.tensor_parallel_degree 2", - "--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP ], ], "PP+TP 2D test", @@ -144,7 +143,6 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_split_mode tracer", - "--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer ], ], "PP tracer frontend test", @@ -162,7 +160,16 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile --model.norm_type=rmsnorm", + "--training.tensor_parallel_degree 2", + ], + ], + "2D eager", + "2d_eager", + ), + OverrideDefinitions( + [ + [ + "--training.compile", ], ], "1D compile", @@ -182,29 +189,20 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", + "--training.compile --training.tensor_parallel_degree 2", ], ], "2D compile", "2d_compile", ), - OverrideDefinitions( - [ - [ - "--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", - ], - ], - "Eager mode 2DParallel with rmsnorm", - "eager_2d_rmsnorm", - ), OverrideDefinitions( [ [ "--training.tensor_parallel_degree 2 --model.norm_type=fused_rmsnorm", ], ], - "Eager mode 2DParallel with fused_rmsnorm", - "eager_2d_fused_rmsnorm", + "2D eager with fused_rmsnorm", + "2d_eager_fused_rmsnorm", ), OverrideDefinitions( [ @@ -248,7 +246,6 @@ def build_test_list(): "--experimental.pipeline_parallel_split_points layers.4", "--training.data_parallel_degree 2", "--training.tensor_parallel_degree 2", - "--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP ], [ "--training.steps 20", @@ -257,7 +254,6 @@ def build_test_list(): "--experimental.pipeline_parallel_split_points layers.4", "--training.data_parallel_degree 2", "--training.tensor_parallel_degree 2", - "--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP ], ], "PP+DP+TP 3D test with save/load resume ckpt", @@ -272,7 +268,7 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 4", "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7", "--experimental.pipeline_parallel_schedule interleaved_1f1b", - "--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp + "--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP ], ], "PP looped 1f1b test", @@ -292,7 +288,8 @@ def build_test_list(): OverrideDefinitions( [ [ - "--memory_estimation.enabled --model.norm_type rmsnorm", + "--memory_estimation.enabled", + "--model.norm_type rmsnorm", # estimation mode does not support compiled_rmsnorm yet ] ], "FSDP2 Memory Tracking and Estimation",