From d71a41bdaea42c5cd7e2fb33eeb30ab651769ef7 Mon Sep 17 00:00:00 2001 From: Jason Senthil Date: Thu, 30 Jan 2025 15:21:17 -0800 Subject: [PATCH] add fsdp2 support (#967) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/967 Reviewed By: anshulverma Differential Revision: D68735961 fbshipit-source-id: 69bdd1bd700dd58f4c92ed6ba8bc4ae0b4432dc0 --- tests/utils/test_prepare_module_gpu.py | 133 +++++++++++++++++++++ torchtnt/utils/prepare_module.py | 159 ++++++++++++++++++++++++- 2 files changed, 291 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_prepare_module_gpu.py b/tests/utils/test_prepare_module_gpu.py index 9b47eaa283..ba3583f6e2 100644 --- a/tests/utils/test_prepare_module_gpu.py +++ b/tests/utils/test_prepare_module_gpu.py @@ -7,9 +7,20 @@ # pyre-strict import unittest +from typing import Any import torch from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +try: + from torch.distributed.fsdp import fully_shard +except ImportError: + + def noop(*args: Any, **kwargs: Any) -> None: + pass + + fully_shard = noop + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision from torch.nn.parallel import DistributedDataParallel as DDP from torchtnt.utils.distributed import spawn_multi_process @@ -17,9 +28,11 @@ from torchtnt.utils.prepare_module import ( _is_fsdp_module, DDPStrategy, + FSDP2Strategy, FSDPStrategy, prepare_ddp, prepare_fsdp, + prepare_fsdp2, prepare_module, ) from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu @@ -111,6 +124,11 @@ def _test_is_fsdp_module() -> None: model = FSDP(model) assert _is_fsdp_module(model) + # test fsdp2 + model = torch.nn.Linear(1, 1, device=device) + fully_shard(model) + assert _is_fsdp_module(model) + @skip_if_not_distributed @skip_if_not_gpu def test_fdsp_precision(self) -> None: @@ -276,3 +294,118 @@ def _test_prepare_module_fsdp_string_wrapped_in_fsdp() -> None: tc = unittest.TestCase() tc.assertTrue(isinstance(fsdp_module, FSDP)) + + @skip_if_not_distributed + @skip_if_not_gpu + def test_prepare_fsdp2(self) -> None: + """ + Launch tests of FSDP2 strategy + """ + + spawn_multi_process( + 1, + "nccl", + self._test_prepare_fsdp2_none_sharded_raises, + ) + + spawn_multi_process( + 1, + "nccl", + self._test_prepare_fsdp2_shard_all, + ) + + spawn_multi_process( + 1, + "nccl", + self._test_prepare_fsdp2_submodule, + ) + + spawn_multi_process( + 1, + "nccl", + self._test_prepare_fsdp2_meta_device, + ) + + @staticmethod + def _test_prepare_fsdp2_none_sharded_raises() -> None: + """ + Test with a strategy that does not shard any modules, should raise error + """ + tc = unittest.TestCase() + + module = SimpleModule() + device = torch.device("cuda") + strategy = FSDP2Strategy(modules_to_shard=[]) + with tc.assertRaises(ValueError): + prepare_fsdp2(module, device, strategy) + + @staticmethod + def _test_prepare_fsdp2_shard_all() -> None: + """ + Test with a strategy that shards all modules + """ + tc = unittest.TestCase() + + module = SimpleModule() + device = torch.device("cuda") + strategy = FSDP2Strategy(modules_to_shard="all") + prepare_fsdp2(module, device, strategy) + + for submodule in module.modules(): + tc.assertTrue(_is_fsdp_module(submodule)) + + @staticmethod + def _test_prepare_fsdp2_submodule() -> None: + """ + Test with a strategy that shards modules (either str or module type) + """ + tc = unittest.TestCase() + + for t in (torch.nn.Linear, "Linear"): + module = SimpleModule() + device = torch.device("cuda") + strategy = 
FSDP2Strategy(modules_to_shard=(t,))
+            prepare_fsdp2(module, device, strategy)
+
+            for submodule in module.modules():
+                if isinstance(submodule, torch.nn.Conv2d):
+                    tc.assertFalse(_is_fsdp_module(submodule))
+                else:
+                    # linear and SimpleModule are fsdp modules
+                    tc.assertTrue(_is_fsdp_module(submodule))
+
+    @staticmethod
+    def _test_prepare_fsdp2_meta_device() -> None:
+        """
+        Test with a strategy that shards specific modules on meta device
+        """
+        tc = unittest.TestCase()
+
+        module = SimpleModule(meta_device=True)
+        device = torch.device("cuda")
+        strategy = FSDP2Strategy(modules_to_shard=(torch.nn.Linear,))
+        prepare_fsdp2(module, device, strategy)
+
+        for submodule in module.modules():
+            if isinstance(submodule, torch.nn.Conv2d):
+                tc.assertFalse(_is_fsdp_module(submodule))
+            else:
+                # linear and SimpleModule are fsdp modules
+                tc.assertTrue(_is_fsdp_module(submodule))
+
+
+class SimpleModule(torch.nn.Module):
+    def __init__(self, meta_device: bool = False) -> None:
+        super(SimpleModule, self).__init__()
+        self.linear = torch.nn.Linear(10, 10, device="meta" if meta_device else None)
+        self.conv1 = torch.nn.Conv2d(
+            in_channels=3,
+            out_channels=16,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            device="meta" if meta_device else None,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear(x)
diff --git a/torchtnt/utils/prepare_module.py b/torchtnt/utils/prepare_module.py
index b3b190add2..ca40f87b1b 100644
--- a/torchtnt/utils/prepare_module.py
+++ b/torchtnt/utils/prepare_module.py
@@ -15,8 +15,11 @@
     ContextManager,
     Dict,
     Iterable,
+    Literal,
     Optional,
+    Set,
     Tuple,
+    Type,
     Union,
 )
 
@@ -30,6 +33,29 @@
     checkpoint_wrapper,
     CheckpointImpl,
 )
+from torch.distributed.device_mesh import init_device_mesh
+
+try:
+    from torch.distributed.fsdp import (
+        CPUOffloadPolicy,
+        fully_shard,
+        MixedPrecisionPolicy,
+    )
+    from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
+except ImportError:
+
+    def noop(*args: Any, **kwargs: Any) -> None:
+        pass
+
+    class NOOP:
+        def __init__(self, *args: Any, **kwargs: Any) -> None:
+            pass
+
+    fully_shard = noop
+    MixedPrecisionPolicy = NOOP
+    CPUOffloadPolicy = NOOP
+    FSDPState = NOOP
+
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     StateDictType as _StateDictType,
@@ -146,6 +172,52 @@ def __post_init__(self) -> None:
             self.mixed_precision = self.mixed_precision.to_native_mixed_precision()
 
 
+@dataclass
+class FSDP2Strategy(Strategy):
+    """
+    Dataclass representing the `FSDP2 `_ strategy.
+    For more details on the args, see the link.
+
+    Args:
+        modules_to_shard: Which modules to shard across devices. Options are 'all' to shard every submodule, or an iterable of module names and/or module types.
+        reshard_after_forward: If True, reshards parameters after the forward pass to optimize memory usage.
+        mp_policy: Controls the mixed precision policy. If only a dtype is provided, it is used to cast all relevant parts of the model. If None, no mixed precision is used.
+        cpu_offload: If True, enables CPU offloading of model parameters to reduce GPU memory usage.
+
+    Note:
+        It is recommended to shard only selected modules rather than every submodule, since sharding all submodules adds
+        communication overhead.
+
+    Example:
+        >>> model
+        TransformerDecoder(
+          (tok_embeddings): Embedding(128256, 4096)
+          (layers): ModuleList(
+            (0-31): 32 x TransformerSelfAttentionLayer(
+              (attn): MultiHeadAttention(
+                (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
+                (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
+                (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
+                (output_proj): Linear(in_features=4096, out_features=4096, bias=False)
+                (pos_embeddings): RotaryPositionalEmbeddings()
+              )
+              ...
+            )
+          (output): Linear(in_features=4096, out_features=128256, bias=False)
+        )
+        >>> # You can either specify the module to shard as a name ("Linear") or the module type (torch.nn.Linear)
+        >>> strategy = FSDP2Strategy(modules_to_shard=["TransformerSelfAttentionLayer", "Linear"])
+    """
+
+    modules_to_shard: Union[
+        Literal["all"],
+        Iterable[Union[str, Type[torch.nn.Module]]],
+    ] = "all"
+    reshard_after_forward: Union[bool, int] = True
+    mp_policy: Optional[Union[torch.dtype, MixedPrecisionPolicy]] = None
+    cpu_offload: bool = False
+
+
 @dataclass
 class TorchCompileParams:
     """
@@ -272,6 +344,89 @@ def prepare_fsdp(
     return module
 
 
+def prepare_fsdp2(
+    module: torch.nn.Module,
+    device: torch.device,
+    strategy: Optional[FSDP2Strategy] = None,
+    process_group: Optional[ProcessGroup] = None,
+) -> torch.nn.Module:
+    """
+    Utility to move a module to device and wrap it in `FSDP2 `_
+
+    Args:
+        module: module to be wrapped in FSDP
+        device: device to which module will be moved
+        strategy: an instance of :class:`~torchtnt.utils.prepare_module.FSDP2Strategy` which defines the settings of the FSDP2 API
+    """
+    strategy = strategy or FSDP2Strategy()
+
+    # prepare kwargs for fully_shard api
+    pg = process_group or dist.distributed_c10d._get_default_group()
+    mesh = init_device_mesh(device.type, mesh_shape=(pg.size(),))
+    fsdp_kwargs: Dict[str, Any] = {
+        "mesh": mesh,  # TODO we only configure 1D mesh for now, look into supporting HSDP
+        "reshard_after_forward": strategy.reshard_after_forward,
+    }
+    if strategy.cpu_offload:
+        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
+    if (mp_policy := strategy.mp_policy) is not None:
+        if isinstance(mp_policy, MixedPrecisionPolicy):
+            fsdp_kwargs["mp_policy"] = mp_policy
+        else:
+            fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
+                param_dtype=mp_policy,
+                reduce_dtype=mp_policy,
+                output_dtype=mp_policy,
+                cast_forward_inputs=True,
+            )
+
+    # parse out the modules_to_shard argument
+    modules_to_shard = strategy.modules_to_shard
+
+    shard_all = modules_to_shard == "all"
+    shard_module_names: Set[str] = set()
+    shard_module_types: Tuple[Type[torch.nn.Module], ...] = ()
+    if not shard_all:
+        assert (
+            type(modules_to_shard) is not str
+        ), f"modules_to_shard must be an iterable of module names/types or 'all', got {modules_to_shard}"
+
+        for item in modules_to_shard:
+            if isinstance(item, str):
+                shard_module_names.add(item)
+            else:
+                shard_module_types = shard_module_types + (item,)
+
+    # apply the fsdp2 sharding bottom-up
+    num_layers_sharded = 0
+    for _, m in reversed(list(module.named_modules())):
+        if shard_all:
+            # fully_shard does not support containers that do not implement forward
+            if not isinstance(m, (torch.nn.ModuleList, torch.nn.ModuleDict)):
+                fully_shard(m, **fsdp_kwargs)
+                num_layers_sharded += 1
+        elif (
+            isinstance(m, shard_module_types) or type(m).__name__ in shard_module_names
+        ):
+            # shard m if its type or class name was requested in the strategy
+            fully_shard(m, **fsdp_kwargs)
+            num_layers_sharded += 1
+
+    if num_layers_sharded == 0:
+        raise ValueError(
+            "No layer modules were sharded with fsdp2. Please check if shard conditions are working as expected."
+        )
+
+    # shard the top-level module, so that all params are moved off cpu to gpu
+    if not _is_fsdp_module(module):
+        fully_shard(module, **fsdp_kwargs)
+
+    # materialize sharded meta weights on the target device
+    materialize_meta_params(module, device)
+
+    return module
+
+
 class FSDPOptimizerWrapper:
     """
     Wrapper for FSDP optimizer to call specific FSDP optimizer state checkpointing APIs.
@@ -301,7 +456,7 @@ def _is_fsdp_module(module: torch.nn.Module) -> bool:
     # Also check for composable FSDP API
     maybe_composable_state = _get_module_state(module)
     if maybe_composable_state is not None:
-        return isinstance(maybe_composable_state, _FSDPState)
+        return isinstance(maybe_composable_state, (_FSDPState, FSDPState))
 
     return False
 
@@ -366,6 +521,8 @@ def prepare_module(
             "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters"
         )
         module = prepare_fsdp(module, device, strategy)
+    elif isinstance(strategy, FSDP2Strategy):
+        module = prepare_fsdp2(module, device, strategy)
     else:
         module = module.to(device)
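
For reviewers who want to try the new path end to end, here is a minimal usage sketch (not part of the patch). It assumes a single-node torchrun launch with NCCL and a PyTorch build that ships torch.distributed.fsdp.fully_shard; the Sequential toy model, layer sizes, and bfloat16 mixed precision are illustrative placeholders rather than anything prescribed by torchtnt.

import torch
import torch.distributed as dist

from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_module


def main() -> None:
    # assumes a single-node torchrun launch so that the rank/world-size
    # env vars are already set and each rank sees the node's GPUs
    dist.init_process_group(backend="nccl")
    device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
    torch.cuda.set_device(device)

    # toy model for illustration; modules can also be built on the meta device,
    # as exercised by _test_prepare_fsdp2_meta_device above
    module = torch.nn.Sequential(
        torch.nn.Linear(4096, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 4096),
    )

    strategy = FSDP2Strategy(
        modules_to_shard=(torch.nn.Linear,),  # shard only the Linear submodules
        reshard_after_forward=True,
        mp_policy=torch.bfloat16,  # expanded into a MixedPrecisionPolicy by prepare_fsdp2
    )
    # prepare_module dispatches to prepare_fsdp2 for FSDP2Strategy instances
    module = prepare_module(module, device, strategy=strategy)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()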