add fsdp2 support (#967)
Summary: Pull Request resolved: #967

Reviewed By: anshulverma

Differential Revision: D68735961

fbshipit-source-id: 69bdd1bd700dd58f4c92ed6ba8bc4ae0b4432dc0
JKSenthil authored and facebook-github-bot committed Jan 30, 2025
1 parent 7272dbd commit d71a41b
Showing 2 changed files with 291 additions and 1 deletion.
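For context, a minimal sketch of how the new FSDP2 path can be used (illustrative only, not part of this commit; it assumes a CUDA device and an already-initialized NCCL process group, for example one launched via torchtnt.utils.distributed.spawn_multi_process as the new tests below do):

import torch

from torchtnt.utils.prepare_module import FSDP2Strategy, prepare_fsdp2

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)

# Shard only Linear submodules (pass "all" to shard every submodule instead),
# casting params/reductions/outputs to bfloat16 via the dtype shorthand.
strategy = FSDP2Strategy(modules_to_shard=(torch.nn.Linear,), mp_policy=torch.bfloat16)
model = prepare_fsdp2(model, torch.device("cuda"), strategy)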
133 changes: 133 additions & 0 deletions tests/utils/test_prepare_module_gpu.py
@@ -7,19 +7,32 @@

# pyre-strict
import unittest
from typing import Any

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

try:
from torch.distributed.fsdp import fully_shard
except ImportError:

def noop(*args: Any, **kwargs: Any) -> None:
pass

fully_shard = noop

from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.nn.parallel import DistributedDataParallel as DDP
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.prepare_module import (
_is_fsdp_module,
DDPStrategy,
FSDP2Strategy,
FSDPStrategy,
prepare_ddp,
prepare_fsdp,
prepare_fsdp2,
prepare_module,
)
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
@@ -111,6 +124,11 @@ def _test_is_fsdp_module() -> None:
model = FSDP(model)
assert _is_fsdp_module(model)

# test fsdp2
model = torch.nn.Linear(1, 1, device=device)
fully_shard(model)
assert _is_fsdp_module(model)

@skip_if_not_distributed
@skip_if_not_gpu
def test_fdsp_precision(self) -> None:
@@ -276,3 +294,118 @@ def _test_prepare_module_fsdp_string_wrapped_in_fsdp() -> None:
tc = unittest.TestCase()

tc.assertTrue(isinstance(fsdp_module, FSDP))

@skip_if_not_distributed
@skip_if_not_gpu
def test_prepare_fsdp2(self) -> None:
"""
Launch tests of FSDP2 strategy
"""

spawn_multi_process(
1,
"nccl",
self._test_prepare_fsdp2_none_sharded_raises,
)

spawn_multi_process(
1,
"nccl",
self._test_prepare_fsdp2_shard_all,
)

spawn_multi_process(
1,
"nccl",
self._test_prepare_fsdp2_submodule,
)

spawn_multi_process(
1,
"nccl",
self._test_prepare_fsdp2_meta_device,
)

@staticmethod
def _test_prepare_fsdp2_none_sharded_raises() -> None:
"""
Test that a strategy that does not shard any modules raises an error
"""
tc = unittest.TestCase()

module = SimpleModule()
device = torch.device("cuda")
strategy = FSDP2Strategy(modules_to_shard=[])
with tc.assertRaises(ValueError):
prepare_fsdp2(module, device, strategy)

@staticmethod
def _test_prepare_fsdp2_shard_all() -> None:
"""
Test with a strategy that shards all modules
"""
tc = unittest.TestCase()

module = SimpleModule()
device = torch.device("cuda")
strategy = FSDP2Strategy(modules_to_shard="all")
prepare_fsdp2(module, device, strategy)

for submodule in module.modules():
tc.assertTrue(_is_fsdp_module(submodule))

@staticmethod
def _test_prepare_fsdp2_submodule() -> None:
"""
Test with a strategy that specifies modules to shard either by name (str) or by module type
"""
tc = unittest.TestCase()

for t in (torch.nn.Linear, "Linear"):
module = SimpleModule()
device = torch.device("cuda")
strategy = FSDP2Strategy(modules_to_shard=(t,))
prepare_fsdp2(module, device, strategy)

for submodule in module.modules():
if isinstance(submodule, torch.nn.Conv2d):
tc.assertFalse(_is_fsdp_module(submodule))
else:
# linear and SimpleModule are fsdp modules
tc.assertTrue(_is_fsdp_module(submodule))

@staticmethod
def _test_prepare_fsdp2_meta_device() -> None:
"""
Test with a strategy that shards specific modules on meta device
"""
tc = unittest.TestCase()

module = SimpleModule(meta_device=True)
device = torch.device("cuda")
strategy = FSDP2Strategy(modules_to_shard=(torch.nn.Linear,))
prepare_fsdp2(module, device, strategy)

for submodule in module.modules():
if isinstance(submodule, torch.nn.Conv2d):
tc.assertFalse(_is_fsdp_module(submodule))
else:
# linear and SimpleModule are fsdp modules
tc.assertTrue(_is_fsdp_module(submodule))


class SimpleModule(torch.nn.Module):
def __init__(self, meta_device: bool = False) -> None:
super(SimpleModule, self).__init__()
self.linear = torch.nn.Linear(10, 10, device="meta" if meta_device else None)
self.conv1 = torch.nn.Conv2d(
in_channels=3,
out_channels=16,
kernel_size=3,
stride=1,
padding=1,
device="meta" if meta_device else None,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
159 changes: 158 additions & 1 deletion torchtnt/utils/prepare_module.py
@@ -15,8 +15,11 @@
ContextManager,
Dict,
Iterable,
Literal,
Optional,
Set,
Tuple,
Type,
Union,
)

@@ -30,6 +33,29 @@
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.device_mesh import init_device_mesh

try:
from torch.distributed.fsdp import (
CPUOffloadPolicy,
fully_shard,
MixedPrecisionPolicy,
)
from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
except ImportError:

def noop(*args: Any, **kwargs: Any) -> None:
pass

class NOOP:
def __init__(self, *args: Any, **kwargs: Any) -> None:
pass

fully_shard = noop
MixedPrecisionPolicy = NOOP
CPUOffloadPolicy = NOOP
FSDPState = NOOP

from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
StateDictType as _StateDictType,
@@ -146,6 +172,52 @@ def __post_init__(self) -> None:
self.mixed_precision = self.mixed_precision.to_native_mixed_precision()


@dataclass
class FSDP2Strategy(Strategy):
"""
Dataclass representing the `FSDP2 <https://pytorch.org/docs/2.6/distributed.fsdp.fully_shard.html>`_ strategy.
For more details on the args, see the link.
Args:
modules_to_shard: Which modules to shard across devices. Either the literal 'all' to shard every submodule, or an iterable of module class names (str) and/or module types.
reshard_after_forward: If True, reshards parameters after the forward pass to optimize memory usage.
mp_policy: Controls the mixed precision policy. If only a dtype is provided, it is used to cast all relevant parts of the model (parameters, reductions, and outputs). If None, no mixed precision is used.
cpu_offload: If True, enables CPU offloading of model parameters to reduce GPU memory usage.
Note:
It is recommended to shard only specific modules rather than all submodules, since sharding every submodule adds communication overhead.
Example:
>>> model
TransformerDecoder(
(tok_embeddings): Embedding(128256, 4096)
(layers): ModuleList(
(0-31): 32 x TransformerSelfAttentionLayer(
(attn): MultiHeadAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear(in_features=4096, out_features=1024, bias=False)
(output_proj): Linear(in_features=4096, out_features=4096, bias=False)
(pos_embeddings): RotaryPositionalEmbeddings()
)
...
)
(output): Linear(in_features=4096, out_features=128256, bias=False)
)
>>> # You can either specify the module to shard as a name ("Linear") or the module type (torch.nn.Linear)
>>> strategy = FSDP2Strategy(modules_to_shard=["TransformerSelfAttentionLayer", "Linear"])
"""

modules_to_shard: Union[
Literal["all"],
Iterable[Union[str, Type[torch.nn.Module]]],
] = "all"
reshard_after_forward: Union[bool, int] = True
mp_policy: Optional[Union[torch.dtype, MixedPrecisionPolicy]] = None
cpu_offload: bool = False
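# Illustrative note (not part of this diff): mp_policy accepts either a plain torch.dtype,
# which prepare_fsdp2 expands to a MixedPrecisionPolicy covering params, reductions, and
# outputs, or an explicit MixedPrecisionPolicy instance, e.g.
#   FSDP2Strategy(mp_policy=torch.bfloat16)
#   FSDP2Strategy(mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16))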


@dataclass
class TorchCompileParams:
"""
@@ -272,6 +344,89 @@ def prepare_fsdp(
return module


def prepare_fsdp2(
module: torch.nn.Module,
device: torch.device,
strategy: Optional[FSDP2Strategy] = None,
process_group: Optional[ProcessGroup] = None,
) -> torch.nn.Module:
"""
Utility to move a module to device and wrap in `FSDP2 <https://pytorch.org/docs/2.6/distributed.fsdp.fully_shard.html>`_
Args:
module: module to be wrapped in FSDP2
device: device to which module will be moved
strategy: an instance of :class:`~torchtnt.utils.prepare_module.FSDP2Strategy` which defines the settings of the FSDP2 APIs
process_group: optional process group over which to shard the module; defaults to the global default process group
"""
strategy = strategy or FSDP2Strategy()

# prepare kwargs for fully_shard api
pg = process_group or dist.distributed_c10d._get_default_group()
mesh = init_device_mesh(device.type, mesh_shape=(pg.size(),))
fsdp_kwargs: Dict[str, Any] = {
"mesh": mesh, # TODO we only configure 1D mesh for now, look into supporting HSDP
"reshard_after_forward": strategy.reshard_after_forward,
}
if strategy.cpu_offload:
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
if (mp_policy := strategy.mp_policy) is not None:
if isinstance(mp_policy, MixedPrecisionPolicy):
fsdp_kwargs["mixed_precision"] = mp_policy
else:
fsdp_kwargs["mixed_precision"] = MixedPrecisionPolicy(
param_dtype=mp_policy,
reduce_dtype=mp_policy,
output_dtype=mp_policy,
cast_forward_inputs=True,
)

# parse out the modules_to_shard argument
modules_to_shard = strategy.modules_to_shard

shard_all = modules_to_shard == "all"
shard_module_names: Set[str] = set()
shard_module_types: Tuple[Type[torch.nn.Module], ...] = ()
if not shard_all:
assert (
type(modules_to_shard) is not str
), f"modules_to_shard must be an iterable of modules or 'all', got {shard_all}"

for item in modules_to_shard:
if isinstance(item, str):
shard_module_names.add(item)
else:
shard_module_types = shard_module_types + (item,)

# apply fully_shard bottom-up so that child modules are sharded before their parents
num_layers_sharded = 0
for _, m in reversed(list(module.named_modules())):
if shard_all:
# fully_shard does not support containers that do not implement forward
if not isinstance(m, (torch.nn.ModuleList, torch.nn.ModuleDict)):
fully_shard(m, **fsdp_kwargs)
num_layers_sharded += 1
elif (
isinstance(m, shard_module_types) or type(m).__name__ in shard_module_names
):
# shard m if its type or class name was requested via modules_to_shard
fully_shard(m, **fsdp_kwargs)
num_layers_sharded += 1

if num_layers_sharded == 0:
raise ValueError(
"No layer modules were sharded with fsdp2. Please check if shard conditions are working as expected."
)

# shard the top-level module so that all params are moved off CPU to the GPU
if not _is_fsdp_module(module):
fully_shard(module, **fsdp_kwargs)

# materialize sharded meta-device weights onto the target device
materialize_meta_params(module, device)

return module
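# Illustrative usage with meta-device initialization (not part of this diff): parameters can
# be created on the meta device and are materialized onto `device` after sharding, e.g.
#   with torch.device("meta"):
#       model = MyModel()  # hypothetical module class
#   model = prepare_fsdp2(model, torch.device("cuda"), FSDP2Strategy())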


class FSDPOptimizerWrapper:
"""
Wrapper for FSDP optimizer to call specific FSDP optimizer state checkpointing APIs.
@@ -301,7 +456,7 @@ def _is_fsdp_module(module: torch.nn.Module) -> bool:
# Also check for composable FSDP API
maybe_composable_state = _get_module_state(module)
if maybe_composable_state is not None:
return isinstance(maybe_composable_state, _FSDPState)
return isinstance(maybe_composable_state, (_FSDPState, FSDPState))

return False

@@ -366,6 +521,8 @@ def prepare_module(
"Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters"
)
module = prepare_fsdp(module, device, strategy)
elif isinstance(strategy, FSDP2Strategy):
module = prepare_fsdp2(module, device, strategy)
else:
module = module.to(device)
