From 237e7b17eb3e8b984449579fc90e0c72bb5949aa Mon Sep 17 00:00:00 2001
From: Eitan Turok
Date: Fri, 6 Sep 2024 17:50:38 +0000
Subject: [PATCH 1/2] update test

---
 tests/trainer/test_fsdp_checkpoint.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py
index 5f97c6092c..eb2d0f7940 100644
--- a/tests/trainer/test_fsdp_checkpoint.py
+++ b/tests/trainer/test_fsdp_checkpoint.py
@@ -840,8 +840,8 @@ def test_fsdp_partitioned_state_dict_load(
     s3_ephemeral_prefix,
     request,
 ):
-    if use_tp:
-        pytest.skip('TP on PyTorch 2.3 has sharded state dict issues.')
+    if use_tp and version.parse(torch.__version__) < version.parse('2.4.0'):
+        pytest.skip('TP has full state dict issues before PyTorch 2.4.')
     if weights_only and autoresume:
         pytest.skip('Weights only with autoresume is not supported')
     if (use_tp or use_hsdp) and version.parse(torch.__version__) < version.parse('2.3.0'):

From 6116c21371dd30bbeb29a455b4547ce69645e6d1 Mon Sep 17 00:00:00 2001
From: Eitan Turok
Date: Fri, 6 Sep 2024 18:09:46 +0000
Subject: [PATCH 2/2] only look at relevant tests

---
 tests/trainer/test_fsdp_checkpoint.py | 42 +++++++++++++--------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py
index eb2d0f7940..55f5495c53 100644
--- a/tests/trainer/test_fsdp_checkpoint.py
+++ b/tests/trainer/test_fsdp_checkpoint.py
@@ -796,29 +796,29 @@ def mock_get_checkpoint_validation_function():
 
 
 @pytest.mark.gpu
-@pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
+# @pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
 @pytest.mark.parametrize(
     'weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp,use_hsdp',
     [
-        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(True, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(False, 'adam', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(False, 'adamw', 'amp_fp16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(False, 'adamw', 'amp_bf16', True, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(
-            False,
-            'adamw',
-            'amp_bf16',
-            False,
-            ['rng'],
-            False,
-            False,
-            False,
-            marks=pytest.mark.world_size(2),
-        ),
-        pytest.param(False, 'adamw', 'amp_bf16', False, None, True, False, False, marks=pytest.mark.world_size(2)),
+        # pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        # pytest.param(True, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        # pytest.param(False, 'adam', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        # pytest.param(False, 'adamw', 'amp_fp16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        # pytest.param(False, 'adamw', 'amp_bf16', True, None, False, False, False, marks=pytest.mark.world_size(2)),
+        # pytest.param(
+        #     False,
+        #     'adamw',
+        #     'amp_bf16',
+        #     False,
+        #     ['rng'],
+        #     False,
+        #     False,
+        #     False,
+        #     marks=pytest.mark.world_size(2),
+        # ),
+        # pytest.param(False, 'adamw', 'amp_bf16', False, None, True, False, False, marks=pytest.mark.world_size(2)),
         pytest.param(False, 'adamw', 'amp_bf16', False, None, False, True, False, marks=pytest.mark.world_size(4)),
-        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
+        # pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
     ],
 )
 @pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
@@ -835,13 +835,13 @@ def test_fsdp_partitioned_state_dict_load(
     use_symlink: bool,
     use_tp: bool,
     use_hsdp: bool,
-    use_remote,
     s3_bucket,
     s3_ephemeral_prefix,
     request,
+    use_remote = False,
 ):
     if use_tp and version.parse(torch.__version__) < version.parse('2.4.0'):
-        pytest.skip('TP has full state dict issues before PyTorch 2.4.')
+        pytest.skip('TP has sharded state dict issues before PyTorch 2.4.')
     if weights_only and autoresume:
         pytest.skip('Weights only with autoresume is not supported')
     if (use_tp or use_hsdp) and version.parse(torch.__version__) < version.parse('2.3.0'):
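
Note: both commits lean on the same guard pattern: compare the installed torch
against a cutoff with packaging.version and call pytest.skip when it is too old.
The following is a minimal sketch of that pattern, not code from the patch; the
helper name skip_if_torch_older_than is hypothetical, and it assumes packaging,
pytest, and torch are importable, as they are in the test module above.

    # Hypothetical helper illustrating the version-gated skip used in both commits.
    from packaging import version

    import pytest
    import torch


    def skip_if_torch_older_than(cutoff: str, reason: str) -> None:
        # pytest.skip raises, so the calling test stops here whenever the
        # installed torch predates the cutoff (e.g. '2.4.0').
        if version.parse(torch.__version__) < version.parse(cutoff):
            pytest.skip(reason)

Called at the top of a test body, skip_if_torch_older_than('2.4.0', 'TP has
sharded state dict issues before PyTorch 2.4.') reproduces the guard added in
PATCH 1/2 and reworded in PATCH 2/2.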