add fsdp2 support #967

Closed
wants to merge 1 commit into from
133 changes: 133 additions & 0 deletions tests/utils/test_prepare_module_gpu.py
@@ -7,19 +7,32 @@

# pyre-strict
import unittest
from typing import Any

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

try:
    from torch.distributed.fsdp import fully_shard
except ImportError:

    def noop(*args: Any, **kwargs: Any) -> None:
        pass

    fully_shard = noop

from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.nn.parallel import DistributedDataParallel as DDP
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.prepare_module import (
    _is_fsdp_module,
    DDPStrategy,
    FSDP2Strategy,
    FSDPStrategy,
    prepare_ddp,
    prepare_fsdp,
    prepare_fsdp2,
    prepare_module,
)
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
@@ -111,6 +124,11 @@ def _test_is_fsdp_module() -> None:
        model = FSDP(model)
        assert _is_fsdp_module(model)

        # test fsdp2
        model = torch.nn.Linear(1, 1, device=device)
        fully_shard(model)
        assert _is_fsdp_module(model)

    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_fdsp_precision(self) -> None:
@@ -276,3 +294,118 @@ def _test_prepare_module_fsdp_string_wrapped_in_fsdp() -> None:
        tc = unittest.TestCase()

        tc.assertTrue(isinstance(fsdp_module, FSDP))

    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_prepare_fsdp2(self) -> None:
        """
        Launch tests of the FSDP2 strategy
        """

        spawn_multi_process(
            1,
            "nccl",
            self._test_prepare_fsdp2_none_sharded_raises,
        )

        spawn_multi_process(
            1,
            "nccl",
            self._test_prepare_fsdp2_shard_all,
        )

        spawn_multi_process(
            1,
            "nccl",
            self._test_prepare_fsdp2_submodule,
        )

        spawn_multi_process(
            1,
            "nccl",
            self._test_prepare_fsdp2_meta_device,
        )

    @staticmethod
    def _test_prepare_fsdp2_none_sharded_raises() -> None:
        """
        Test with a strategy that does not shard any modules; should raise an error
        """
        tc = unittest.TestCase()

        module = SimpleModule()
        device = torch.device("cuda")
        strategy = FSDP2Strategy(modules_to_shard=[])
        with tc.assertRaises(ValueError):
            prepare_fsdp2(module, device, strategy)

    @staticmethod
    def _test_prepare_fsdp2_shard_all() -> None:
        """
        Test with a strategy that shards all modules
        """
        tc = unittest.TestCase()

        module = SimpleModule()
        device = torch.device("cuda")
        strategy = FSDP2Strategy(modules_to_shard="all")
        prepare_fsdp2(module, device, strategy)

        for submodule in module.modules():
            tc.assertTrue(_is_fsdp_module(submodule))

    @staticmethod
    def _test_prepare_fsdp2_submodule() -> None:
        """
        Test with a strategy that shards specific modules (given either as a str or a module type)
        """
        tc = unittest.TestCase()

        for t in (torch.nn.Linear, "Linear"):
            module = SimpleModule()
            device = torch.device("cuda")
            strategy = FSDP2Strategy(modules_to_shard=(t,))
            prepare_fsdp2(module, device, strategy)

            for submodule in module.modules():
                if isinstance(submodule, torch.nn.Conv2d):
                    tc.assertFalse(_is_fsdp_module(submodule))
                else:
                    # Linear and SimpleModule are fsdp modules
                    tc.assertTrue(_is_fsdp_module(submodule))

    @staticmethod
    def _test_prepare_fsdp2_meta_device() -> None:
        """
        Test with a strategy that shards specific modules initialized on the meta device
        """
        tc = unittest.TestCase()

        module = SimpleModule(meta_device=True)
        device = torch.device("cuda")
        strategy = FSDP2Strategy(modules_to_shard=(torch.nn.Linear,))
        prepare_fsdp2(module, device, strategy)

        for submodule in module.modules():
            if isinstance(submodule, torch.nn.Conv2d):
                tc.assertFalse(_is_fsdp_module(submodule))
            else:
                # Linear and SimpleModule are fsdp modules
                tc.assertTrue(_is_fsdp_module(submodule))


class SimpleModule(torch.nn.Module):
    def __init__(self, meta_device: bool = False) -> None:
        super(SimpleModule, self).__init__()
        self.linear = torch.nn.Linear(10, 10, device="meta" if meta_device else None)
        self.conv1 = torch.nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
            device="meta" if meta_device else None,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)
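For reference, the per-rank flow these tests exercise can be reduced to a small standalone sketch. This is not part of the PR; it assumes a CUDA host with NCCL available, and the helper `_shard_linear_only` is a hypothetical condensation of the test bodies above.

# Condensed sketch of the per-rank test flow above (assumes 1 GPU + NCCL).
import torch

from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.prepare_module import _is_fsdp_module, FSDP2Strategy, prepare_fsdp2


def _shard_linear_only() -> None:
    # Runs on each spawned rank; spawn_multi_process sets up the process group.
    module = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    device = torch.device("cuda")
    strategy = FSDP2Strategy(modules_to_shard=(torch.nn.Linear,))
    prepare_fsdp2(module, device, strategy)

    # The Linear submodule is sharded by type; the top-level container is
    # also sharded so that all parameters are moved onto the device.
    assert _is_fsdp_module(module)
    assert _is_fsdp_module(module[0])


if __name__ == "__main__":
    spawn_multi_process(1, "nccl", _shard_linear_only)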
159 changes: 158 additions & 1 deletion torchtnt/utils/prepare_module.py
@@ -15,8 +15,11 @@
    ContextManager,
    Dict,
    Iterable,
    Literal,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

@@ -30,6 +33,29 @@
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.device_mesh import init_device_mesh

try:
    from torch.distributed.fsdp import (
        CPUOffloadPolicy,
        fully_shard,
        MixedPrecisionPolicy,
    )
    from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
except ImportError:

    def noop(*args: Any, **kwargs: Any) -> None:
        pass

    class NOOP:
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            pass

    fully_shard = noop
    MixedPrecisionPolicy = NOOP
    CPUOffloadPolicy = NOOP
    FSDPState = NOOP

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType as _StateDictType,
@@ -146,6 +172,52 @@ def __post_init__(self) -> None:
        self.mixed_precision = self.mixed_precision.to_native_mixed_precision()


@dataclass
class FSDP2Strategy(Strategy):
"""
Dataclass representing the `FSDP2 <https://pytorch.org/docs/2.6/distributed.fsdp.fully_shard.html>`_ strategy.
For more details on the args, see the link.

Args:
modules_to_shard: A list of modules that should be sharded across devices. Options are 'all' to shard all submodules, or a list of module names/module types.
reshard_after_forward: If True, reshards parameters after the forward pass to optimize memory usage.
mp_policy: Controls mixed precision policy. If only dtype is provided, it will be used to cast all relevant parts of model. If None, no mixed precision is used
cpu_offload: If True, enables CPU offloading of model parameters to reduce GPU memory usage.

Note:
It is recommended to specify specific modules to shard to avoid unnecessary sharding of all submodules, which has
communication overhead.

Example:
>>> model
TransformerDecoder(
(tok_embeddings): Embedding(128256, 4096)
(layers): ModuleList(
(0-31): 32 x TransformerSelfAttentionLayer(
(attn): MultiHeadAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear(in_features=4096, out_features=1024, bias=False)
(output_proj): Linear(in_features=4096, out_features=4096, bias=False)
(pos_embeddings): RotaryPositionalEmbeddings()
)
...
)
(output): Linear(in_features=4096, out_features=128256, bias=False)
)
>>> # You can either specify the module to shard as a name ("Linear") or the module type (torch.nn.Linear)
>>> strategy = FSDP2Strategy(modules_to_shard=["TransformerSelfAttentionLayer", "Linear"])
"""

modules_to_shard: Union[
Literal["all"],
Iterable[Union[str, Type[torch.nn.Module]]],
] = "all"
reshard_after_forward: Union[bool, int] = True
mp_policy: Optional[Union[torch.dtype, MixedPrecisionPolicy]] = None
cpu_offload: bool = False


@dataclass
class TorchCompileParams:
"""
Expand Down Expand Up @@ -272,6 +344,89 @@ def prepare_fsdp(
return module


def prepare_fsdp2(
    module: torch.nn.Module,
    device: torch.device,
    strategy: Optional[FSDP2Strategy] = None,
    process_group: Optional[ProcessGroup] = None,
) -> torch.nn.Module:
    """
    Utility to move a module to device and wrap it in `FSDP2 <https://pytorch.org/docs/2.6/distributed.fsdp.fully_shard.html>`_.

    Args:
        module: module to be wrapped in FSDP2
        device: device to which the module will be moved
        strategy: an instance of :class:`~torchtnt.utils.prepare_module.FSDP2Strategy` which defines the settings of the FSDP2 APIs
        process_group: optional process group over which to shard the module; defaults to the default process group
    """
    strategy = strategy or FSDP2Strategy()

    # prepare kwargs for the fully_shard api
    pg = process_group or dist.distributed_c10d._get_default_group()
    mesh = init_device_mesh(device.type, mesh_shape=(pg.size(),))
    fsdp_kwargs: Dict[str, Any] = {
        "mesh": mesh,  # TODO we only configure a 1D mesh for now, look into supporting HSDP
        "reshard_after_forward": strategy.reshard_after_forward,
    }
    if strategy.cpu_offload:
        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
    if (mp_policy := strategy.mp_policy) is not None:
        if isinstance(mp_policy, MixedPrecisionPolicy):
            fsdp_kwargs["mp_policy"] = mp_policy
        else:
            fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
                param_dtype=mp_policy,
                reduce_dtype=mp_policy,
                output_dtype=mp_policy,
                cast_forward_inputs=True,
            )

    # parse out the modules_to_shard argument
    modules_to_shard = strategy.modules_to_shard

    shard_all = modules_to_shard == "all"
    shard_module_names: Set[str] = set()
    shard_module_types: Tuple[Type[torch.nn.Module], ...] = ()
    if not shard_all:
        assert (
            type(modules_to_shard) is not str
        ), f"modules_to_shard must be an iterable of modules or 'all', got {modules_to_shard}"

        for item in modules_to_shard:
            if isinstance(item, str):
                shard_module_names.add(item)
            else:
                shard_module_types = shard_module_types + (item,)

    # apply the fsdp2 sharding bottom-up
    num_layers_sharded = 0
    for _, m in reversed(list(module.named_modules())):
        if shard_all:
            # fully_shard does not support containers that do not implement forward
            if not isinstance(m, (torch.nn.ModuleList, torch.nn.ModuleDict)):
                fully_shard(m, **fsdp_kwargs)
                num_layers_sharded += 1
        elif (
            isinstance(m, shard_module_types) or type(m).__name__ in shard_module_names
        ):
            # if m matches one of the requested module types or names, shard it
            fully_shard(m, **fsdp_kwargs)
            num_layers_sharded += 1

    if num_layers_sharded == 0:
        raise ValueError(
            "No layer modules were sharded with fsdp2. Please check if the shard conditions are working as expected."
        )

    # shard the top-level module, so that all params are moved off cpu to gpu
    if not _is_fsdp_module(module):
        fully_shard(module, **fsdp_kwargs)

    # materialize sharded meta-device weights on the target device
    materialize_meta_params(module, device)

    return module


class FSDPOptimizerWrapper:
"""
Wrapper for FSDP optimizer to call specific FSDP optimizer state checkpointing APIs.
Expand Down Expand Up @@ -301,7 +456,7 @@ def _is_fsdp_module(module: torch.nn.Module) -> bool:
# Also check for composable FSDP API
maybe_composable_state = _get_module_state(module)
if maybe_composable_state is not None:
return isinstance(maybe_composable_state, _FSDPState)
return isinstance(maybe_composable_state, (_FSDPState, FSDPState))

return False

@@ -366,6 +521,8 @@ def prepare_module(
                "Torch compile requires FSDPStrategy's use_orig_params to be True, since AOTAutograd needs to be aware of the original parameters"
            )
            module = prepare_fsdp(module, device, strategy)
        elif isinstance(strategy, FSDP2Strategy):
            module = prepare_fsdp2(module, device, strategy)
        else:
            module = module.to(device)