[FSDP2] Allowed List[nn.Module] as arg (#127786)
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.
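
A minimal usage sketch of the two call forms (the model and submodule indices are placeholders, the FSDP2 import path of this era is assumed, and an initialized process group/device mesh is required as in the tests):
```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16), nn.Linear(16, 16))

# Previously: one module per call, i.e. one FSDP parameter group per module.
fully_shard(model[0])

# With this PR: several modules can share a single parameter group, so their
# parameters are all-gathered and reduce-scattered together.
fully_shard([model[1], model[2]])
fully_shard(model)  # root
```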

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, each FSDP parameter group had to correspond to a single module node in this tree. With this PR, we can select a group of module nodes that together act as a single super node representing one parameter group.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics (see the sketch below):
- If a module is the first in its group to run the pre-hook, the given pre-hook actually runs; otherwise, the pre-hook is a no-op.
- If a module is the last in its group to run the post-hook, the given post-hook actually runs; otherwise, the post-hook is a no-op.
- "First" and "last" are determined by scoreboarding against the set of modules in the group.
- The scoreboard set must be cleared at the end of backward to handle the case where one or more modules in the list never run forward; we still want the forward hooks to run in the next forward following that backward.

Beyond these new forward hooks, everything else is a straightforward generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.
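
For illustration, here is a minimal sketch of that scoreboarding, assuming hypothetical names (`_GroupForwardHooks`, `pre_hook`, `post_hook`); it is not the actual FSDP2 implementation, which differs in hook arguments and in where the clearing happens:
```python
from typing import Callable, Set, Tuple

import torch.nn as nn


class _GroupForwardHooks:
    """Sketch of the scoreboarding semantics for a group of modules."""

    def __init__(
        self,
        modules: Tuple[nn.Module, ...],
        pre_hook: Callable[[], None],
        post_hook: Callable[[], None],
    ):
        self._modules = set(modules)
        self._pre_seen: Set[nn.Module] = set()
        self._post_seen: Set[nn.Module] = set()
        self._pre_hook = pre_hook
        self._post_hook = post_hook

    def forward_pre_hook(self, module: nn.Module, args) -> None:
        # Only the first module of the group to enter forward runs the real
        # pre-hook (e.g. unshard/all-gather); for the rest it is a no-op.
        if not self._pre_seen:
            self._pre_hook()
        self._pre_seen.add(module)

    def forward_post_hook(self, module: nn.Module, args, output) -> None:
        # Only the last module of the group to finish forward runs the real
        # post-hook (e.g. optional reshard); for the rest it is a no-op.
        self._post_seen.add(module)
        if self._post_seen == self._modules:
            self._post_hook()
            self._pre_seen.clear()
            self._post_seen.clear()

    def clear(self) -> None:
        # Called at the end of backward so that, if some module in the group
        # never ran forward this iteration, the hooks still fire next forward.
        self._pre_seen.clear()
        self._post_seen.clear()
```
In this sketch, each module in the group would register `forward_pre_hook` via `nn.Module.register_forward_pre_hook` and `forward_post_hook` via `nn.Module.register_forward_hook`.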

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.
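
For instance, a hedged sketch of that grouping, where `model.norm` and `model.output` stand for a Transformer's final norm and output projection (as in the test added below):
```python
# One all-gather serves both the final norm and the output projection; skip
# resharding after forward since backward uses them immediately after the loss.
fully_shard([model.norm, model.output], reshard_after_forward=False)
```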

If at least one of the modules in the list does not run forward before backward, then a warning like the following is emitted:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```
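
A minimal sketch of a setup that triggers this warning, mirroring the `ModuleWithUnusedLinear` test added below (shapes are illustrative, and an initialized process group/device mesh is assumed):
```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

class ModuleWithUnusedLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.unused_lin = nn.Linear(1, 1)   # has parameters but forward never calls it
        self.lin = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.relu(self.lin(x))

model = ModuleWithUnusedLinear()
# Both submodules share one parameter group, but only `lin` runs forward, so the
# group's post-forward/pre-backward logic cannot run for `unused_lin`.
fully_shard([model.unused_lin, model.lin])
```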

Pull Request resolved: #127786
Approved by: https://github.com/yf225, https://github.com/weifengpy
ghstack dependencies: #127773
ghstack-source-id: c70870546bed6108b163062b88415cbba3d37925
awgu committed Jul 17, 2024
1 parent 010b4f8 commit 9d953fb
Showing 8 changed files with 402 additions and 72 deletions.
142 changes: 140 additions & 2 deletions test/distributed/_composable/fsdp/test_fully_shard_comm.py
@@ -104,7 +104,7 @@ def _init_fsdp_param_group(
)
fsdp_param_group = FSDPParamGroup(
list(module.parameters()),
module,
(module,),
mesh_info,
post_forward_mesh_info,
self.device,
@@ -176,7 +176,7 @@ def check_all_gathered_params(
orig_params, reshard_after_forward
)
fsdp_params = fsdp_param_group.fsdp_params
module = fsdp_param_group.module
module = fsdp_param_group.modules[0]

# Sanity check that the parameter sharding is as expected
for orig_param, param in zip(orig_params, module.parameters()):
@@ -772,6 +772,144 @@ def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
self.assertEqual(events, expected_backward_events)
events.clear()

@skip_if_lt_x_gpu(2)
def test_fully_shard_multi_module_backward_prefetch(self):
n_layers = 5
model_args = ModelArgs(n_layers=n_layers, checkpoint_activations=True)
model = Transformer(model_args)
for i in range(n_layers):
if i == 0:
fully_shard(model.layers[i])
elif i % 2 == 1:
fully_shard([model.layers[i], model.layers[i + 1]])
fully_shard([model.tok_embeddings, model.pos_embeddings])
fully_shard([model.norm, model.output], reshard_after_forward=False)
fully_shard(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

events: List[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
post_backward_with_record = self._get_post_backward_with_record(
FSDPParamGroup.post_backward, events
)
inp = torch.randint(
0, model_args.vocab_size, (2, model_args.max_seq_len), device="cuda"
)
with patch_unshard(unshard_with_record), patch_post_backward(
post_backward_with_record
):
for iter_idx in range(3):
loss = model(inp)
expected_events = [
(
"unshard",
"tok_embeddings, pos_embeddings",
TrainingState.FORWARD,
),
("unshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.1, layers.2", TrainingState.FORWARD),
("unshard", "layers.3, layers.4", TrainingState.FORWARD),
("unshard", "norm, output", TrainingState.FORWARD),
]
self.assertEqual(events, expected_events)
events.clear()
loss.sum().backward()
expected_events = [
# (norm, output) does not reshard after forward, so there is
# no unshard to begin backward
("unshard", "layers.3, layers.4", TrainingState.PRE_BACKWARD),
("post_backward", "norm, output", TrainingState.POST_BACKWARD),
("unshard", "layers.1, layers.2", TrainingState.PRE_BACKWARD),
(
"post_backward",
"layers.3, layers.4",
TrainingState.POST_BACKWARD,
),
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
(
"post_backward",
"layers.1, layers.2",
TrainingState.POST_BACKWARD,
),
(
"unshard",
"tok_embeddings, pos_embeddings",
TrainingState.PRE_BACKWARD,
),
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
(
"post_backward",
"tok_embeddings, pos_embeddings",
TrainingState.POST_BACKWARD,
),
]
self.assertEqual(events, expected_events)
events.clear()
optim.step()
optim.zero_grad()

@skip_if_lt_x_gpu(2)
def test_fully_shard_multi_module_unused_module(self):
class ModuleWithUnusedLinear(nn.Module):
def __init__(self):
super().__init__()
self.unused_lin = nn.Linear(1, 1)
self.lin = nn.Linear(16, 16)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return nn.functional.relu(self.lin(x))

model = nn.Sequential(
ModuleWithUnusedLinear(), ModuleWithUnusedLinear(), nn.Linear(16, 16)
)
fully_shard([model[0].unused_lin, model[0].lin], reshard_after_forward=True)
fully_shard([model[1].unused_lin, model[1].lin], reshard_after_forward=True)
fully_shard(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

events: List[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
post_backward_with_record = self._get_post_backward_with_record(
FSDPParamGroup.post_backward, events
)
inp = torch.randn((2, 16), device="cuda")
with patch_unshard(unshard_with_record), patch_post_backward(
post_backward_with_record
):
for iter_idx in range(3):
loss = model(inp)
expected_events = [
("unshard", "", TrainingState.FORWARD),
("unshard", "0.unused_lin, 0.lin", TrainingState.FORWARD),
("unshard", "1.unused_lin, 1.lin", TrainingState.FORWARD),
]
self.assertEqual(events, expected_events)
events.clear()
loss.sum().backward()
expected_events = [
# Since both `model[0]` and `model[1]` have unused modules
# that never ran forward, they do not reshard after forward
# despite setting it to `True`. Check that there are no
# unshards in backward.
(
"post_backward",
"1.unused_lin, 1.lin",
TrainingState.POST_BACKWARD,
),
(
"post_backward",
"0.unused_lin, 0.lin",
TrainingState.POST_BACKWARD,
),
("post_backward", "", TrainingState.POST_BACKWARD),
]
self.assertEqual(events, expected_events)
events.clear()
optim.step()
optim.zero_grad()

def _init_transformer(
self,
n_layers: int,
@@ -84,7 +84,7 @@ def test_dynamo_trace_use_training_state(self):
# Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager.
param_group = FSDPParamGroup(
[], # params: List[nn.Parameter],
torch.nn.Linear(1, 1), # module: nn.Module,
(torch.nn.Linear(1, 1),), # module: Tuple[nn.Module, ...],
None, # mesh_info: FSDPMeshInfo,
None, # post_forward_mesh_info: Optional[FSDPMeshInfo],
None, # device: torch.device,
76 changes: 69 additions & 7 deletions test/distributed/_composable/fsdp/test_fully_shard_init.py
@@ -150,7 +150,7 @@ def world_size(self) -> int:
def test_managed_modules_single(self):
model = MLP(8)
# Assume calling `fully_shard` on `model`
managed_modules = _get_managed_modules(model)
managed_modules = _get_managed_modules((model,))
expected_managed_modules = list(model.modules())
self._check_managed_modules(managed_modules, expected_managed_modules)

@@ -159,7 +159,7 @@ def test_managed_modules_nested(self):
model = nn.Sequential(*[MLP(8) for _ in range(2)])
fully_shard(model[0])
# Assume calling `fully_shard` on `model`
managed_modules = _get_managed_modules(model)
managed_modules = _get_managed_modules((model,))
expected_managed_modules = list(model[1].modules()) + [model]
self._check_managed_modules(managed_modules, expected_managed_modules)

@@ -169,7 +169,7 @@ def test_managed_modules_nested_fully_shard_and_replicate(self):
replicate(model[0])
fully_shard(model[2])
# Assume calling `fully_shard` on `model`
managed_modules = _get_managed_modules(model)
managed_modules = _get_managed_modules((model,))
expected_managed_modules = list(model[1].modules()) + [model]
self._check_managed_modules(managed_modules, expected_managed_modules)

@@ -178,11 +178,26 @@ def test_managed_modules_duplicate(self):
mlp = MLP(8)
model = nn.Sequential(mlp, mlp) # duplicate MLP
# Assume calling `fully_shard` on `model`
managed_modules = _get_managed_modules(model)
managed_modules = _get_managed_modules((model,))
# Check that the duplicate module is only counted once
expected_managed_modules = list(mlp.modules()) + [model]
self._check_managed_modules(managed_modules, expected_managed_modules)

@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_managed_modules_list_of_mlps(self):
model = nn.Sequential(*[MLP(8) for _ in range(5)])
# Assume calling `fully_shard` on `[model[0], model[1], model[2]]`
managed_modules = _get_managed_modules((model[0], model[1], model[2]))
expected_managed_modules = (
list(model[0].modules())
+ list(model[1].modules())
+ list(model[2].modules())
)
self._check_managed_modules(managed_modules, expected_managed_modules)
# Assume calling `fully_shard` on `[model[1], model[3]]`
managed_modules = _get_managed_modules((model[1], model[3]))
expected_managed_modules = list(model[1].modules()) + list(model[3].modules())
self._check_managed_modules(managed_modules, expected_managed_modules)

def _check_managed_modules(
self,
managed_modules: List[nn.Module],
@@ -199,7 +214,7 @@ def test_managed_states_shared_params_and_buffers(self):
model[2].in_proj.weight = model[1].in_proj.weight
model[1].buffer = model[2].buffer
# Assume calling `fully_shard` on `model`
managed_modules = _get_managed_modules(model)
managed_modules = _get_managed_modules((model,))
params, buffers = _get_managed_states(managed_modules)
expected_params = list(model.parameters()) # de-dups shared
expected_buffers = list(model.buffers()) # de-dups shared
@@ -210,12 +225,30 @@ def test_managed_states_nested_fully_shard(self):
model = nn.Sequential(*[MLP(8, with_buffer=True) for _ in range(2)])
fully_shard(model[0])
# Assume calling `fully_shard` on `model`
managed_modules = _get_managed_modules(model)
managed_modules = _get_managed_modules((model,))
params, buffers = _get_managed_states(managed_modules)
expected_params = list(model[1].parameters())
expected_buffers = list(model[1].buffers())
self._check_managed_states(params, buffers, expected_params, expected_buffers)

@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_managed_states_list_of_mlps(self):
model = nn.Sequential(*[MLP(8, with_buffer=True) for _ in range(5)])
# Assume calling `fully_shard` on `[model[0], model[1], model[2]]`
managed_modules = _get_managed_modules((model[0], model[1], model[2]))
params, buffers = _get_managed_states(managed_modules)
expected_params = (
list(model[0].parameters())
+ list(model[1].parameters())
+ list(model[2].parameters())
)
expected_buffers = (
list(model[0].buffers())
+ list(model[1].buffers())
+ list(model[2].buffers())
)
self._check_managed_states(params, buffers, expected_params, expected_buffers)

def _check_managed_states(
self,
managed_params: List[nn.Parameter],
@@ -238,7 +271,7 @@ def world_size(self) -> int:
def test_get_param_module_infos_shared_params(self):
model = nn.Sequential(*[MLP(8) for _ in range(2)])
model[0].in_proj.weight = model[1].in_proj.weight
managed_modules = _get_managed_modules(model)
managed_modules = _get_managed_modules((model,))
params, _ = _get_managed_states(managed_modules)
param_module_infos = _get_param_module_infos(params, model)
self.assertEqual(len(param_module_infos), len(params))
@@ -283,6 +316,26 @@ def test_get_param_module_infos_duplicates(self):
ParamModuleInfo(mlp.out_proj, "bias", [], []),
]

@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_get_param_module_infos_list_of_mlps(self):
model = nn.Sequential(*[MLP(8) for _ in range(2)])
managed_modules = _get_managed_modules((model[0], model[1]))
params, _ = _get_managed_states(managed_modules)
param_module_infos = _get_param_module_infos(params, model)
self.assertEqual(len(param_module_infos), len(params))
expected_param_module_infos = [
ParamModuleInfo(model[0].in_proj, "weight", [], []),
ParamModuleInfo(model[0].in_proj, "bias", [], []),
ParamModuleInfo(model[0].out_proj, "weight", [], []),
ParamModuleInfo(model[0].out_proj, "bias", [], []),
ParamModuleInfo(model[1].in_proj, "weight", [], []),
ParamModuleInfo(model[1].in_proj, "bias", [], []),
ParamModuleInfo(model[1].out_proj, "weight", [], []),
ParamModuleInfo(model[1].out_proj, "bias", [], []),
]
self.assertEqual(len(param_module_infos), len(expected_param_module_infos))
self.assertEqual(param_module_infos, expected_param_module_infos)


class TestFullyShardShardedParameterTensor(FSDPTestMultiThread):
@property
@@ -468,6 +521,15 @@ def test_fully_shard_double_lazy_init(self):
with self.assertRaisesRegex(RuntimeError, regex):
root_state._lazy_init()

@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_fully_shard_multi_module_root(self):
model = nn.Sequential(MLP(8), MLP(8))
fully_shard([model[0], model[1]])
root_state = fully_shard.state(model[0])
regex = "FSDP requires a single root module but got "
with self.assertRaisesRegex(RuntimeError, regex):
root_state._lazy_init()


class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
@property