Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC] Showed more memory efficient FSDP wrapping #382

Draft
wants to merge 4 commits into
base: gh/awgu/6/base
Choose a base branch
from

Conversation

awgu
Copy link
Contributor

@awgu awgu commented Jun 3, 2024

Stack from ghstack (oldest at bottom):

This requires pytorch/pytorch#127786.

Experiment

  • Llama3-8B on 8xH100, 1D FSDP, local batch size 2, selective op AC, compiled_rmsnorm, torch.compile enabled per transformer block, fused AdamW
    • With this PR (68.09 GiB reserved memory):
    [rank0]:2024-07-11 10:55:21,533 - root - INFO - step:  1  loss: 12.2308  memory: 60.27GiB(63.41%)  wps: 233  mfu: 1.37%
    [rank0]:2024-07-11 10:55:21,534 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 10:55:24,063 - root - INFO - step:  2  loss: 12.0520  memory: 68.09GiB(71.65%)  wps: 6,479  mfu: 37.94%
    [rank0]:2024-07-11 10:55:26,596 - root - INFO - step:  3  loss: 11.7165  memory: 68.09GiB(71.65%)  wps: 6,470  mfu: 37.89%
    [rank0]:2024-07-11 10:55:29,139 - root - INFO - step:  4  loss: 11.3078  memory: 68.09GiB(71.65%)  wps: 6,445  mfu: 37.74%
    [rank0]:2024-07-11 10:55:31,681 - root - INFO - step:  5  loss: 10.8763  memory: 68.09GiB(71.65%)  wps: 6,446  mfu: 37.75%
    
    • Without this PR (69.04 GiB reserved memory):
    [rank0]:2024-07-11 11:03:35,749 - root - INFO - step:  1  loss: 12.2646  memory: 61.21GiB(64.41%)  wps: 305  mfu: 1.79%
    [rank0]:2024-07-11 11:03:35,749 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 11:03:38,284 - root - INFO - step:  2  loss: 12.0713  memory: 69.04GiB(72.65%)  wps: 6,464  mfu: 37.85%
    [rank0]:2024-07-11 11:03:40,821 - root - INFO - step:  3  loss: 11.7398  memory: 69.04GiB(72.65%)  wps: 6,460  mfu: 37.83%
    [rank0]:2024-07-11 11:03:43,356 - root - INFO - step:  4  loss: 11.3238  memory: 69.04GiB(72.65%)  wps: 6,462  mfu: 37.84%
    [rank0]:2024-07-11 11:03:45,898 - root - INFO - step:  5  loss: 10.9178  memory: 69.04GiB(72.65%)  wps: 6,448  mfu: 37.76%
    
  • Llama3-8B on 8xH100, 1D FSDP, local batch size 1, no AC, compiled_rmsnorm, torch.compile enabled per transformer block, fused AdamW
    • With this PR (68.36 GiB reserved memory):
    [rank0]:2024-07-11 12:53:24,747 - root - INFO - step:  1  loss: 12.2439  memory: 58.58GiB(61.63%)  wps: 148  mfu: 0.87%
    [rank0]:2024-07-11 12:53:24,750 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 12:53:26,042 - root - INFO - step:  2  loss: 12.0557  memory: 68.36GiB(71.93%)  wps: 6,342  mfu: 37.14%
    [rank0]:2024-07-11 12:53:27,338 - root - INFO - step:  3  loss: 11.7423  memory: 68.36GiB(71.93%)  wps: 6,324  mfu: 37.03%
    [rank0]:2024-07-11 12:53:28,630 - root - INFO - step:  4  loss: 11.3138  memory: 68.36GiB(71.93%)  wps: 6,343  mfu: 37.15%
    [rank0]:2024-07-11 12:53:29,927 - root - INFO - step:  5  loss: 10.9011  memory: 68.36GiB(71.93%)  wps: 6,319  mfu: 37.00%
    
    • Without this PR (67.50 GiB reserved memory):
    [rank0]:2024-07-11 12:50:09,792 - root - INFO - step:  1  loss: 12.2539  memory: 63.58GiB(66.90%)  wps: 146  mfu: 0.86%
    [rank0]:2024-07-11 12:50:09,792 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 12:50:11,087 - root - INFO - step:  2  loss: 12.0905  memory: 67.50GiB(71.02%)  wps: 6,328  mfu: 37.06%
    [rank0]:2024-07-11 12:50:12,385 - root - INFO - step:  3  loss: 11.7652  memory: 67.50GiB(71.02%)  wps: 6,314  mfu: 36.97%
    [rank0]:2024-07-11 12:50:13,680 - root - INFO - step:  4  loss: 11.2644  memory: 67.50GiB(71.02%)  wps: 6,327  mfu: 37.05%
    [rank0]:2024-07-11 12:50:14,978 - root - INFO - step:  5  loss: 10.8718  memory: 67.50GiB(71.02%)  wps: 6,315  mfu: 36.98%
    

For some reason, without AC, the new wrapping actually uses more memory. This could be due to memory fragmentation or compile reasons and needs more investigation.

awgu added a commit that referenced this pull request Jun 3, 2024
ghstack-source-id: 4c0382b83b7a84d9294c2dd0f15c51d527075cae
Pull Request resolved: #382
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 3, 2024
awgu added a commit that referenced this pull request Jun 3, 2024
ghstack-source-id: 2133694b979250588b0067fa4b19d6c094e06797
Pull Request resolved: #382
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 11, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.


cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 11, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.


cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu added a commit that referenced this pull request Jul 11, 2024
ghstack-source-id: 36a58a9fc9fea8b85783f9b50be96a0684603fdc
Pull Request resolved: #382
This requires pytorch/pytorch#127786.

**Experiment**
- Llama3-8B on 8xH100, 1D FSDP, local batch size 2, selective op AC
    - With this PR (68.09 GiB reserved memory):
    ```
    [rank0]:2024-07-11 10:48:25,949 - root - INFO - step:  1  loss: 12.2554  memory: 60.27GiB(63.41%)  wps: 1,943  mfu: 11.38%
    [rank0]:2024-07-11 10:48:25,949 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 10:48:28,506 - root - INFO - step:  2  loss: 12.0724  memory: 68.09GiB(71.65%)  wps: 6,408  mfu: 37.53%
    [rank0]:2024-07-11 10:48:31,063 - root - INFO - step:  3  loss: 11.7467  memory: 68.09GiB(71.65%)  wps: 6,410  mfu: 37.53%
    [rank0]:2024-07-11 10:48:33,621 - root - INFO - step:  4  loss: 11.3360  memory: 68.09GiB(71.65%)  wps: 6,406  mfu: 37.51%
    [rank0]:2024-07-11 10:48:36,182 - root - INFO - step:  5  loss: 10.8909  memory: 68.09GiB(71.65%)  wps: 6,399  mfu: 37.47%
    ```
    - Without this PR (69.04 GiB reserved memory):
    ```
    [rank0]:2024-07-11 10:47:04,798 - root - INFO - step:  1  loss: 12.2421  memory: 61.21GiB(64.41%)  wps: 1,933  mfu: 11.32%
    [rank0]:2024-07-11 10:47:04,798 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 10:47:07,357 - root - INFO - step:  2  loss: 12.0625  memory: 69.04GiB(72.65%)  wps: 6,405  mfu: 37.50%
    [rank0]:2024-07-11 10:47:09,918 - root - INFO - step:  3  loss: 11.7242  memory: 69.04GiB(72.65%)  wps: 6,397  mfu: 37.46%
    [rank0]:2024-07-11 10:47:12,480 - root - INFO - step:  4  loss: 11.3072  memory: 69.04GiB(72.65%)  wps: 6,398  mfu: 37.47%
    [rank0]:2024-07-11 10:47:15,044 - root - INFO - step:  5  loss: 10.8761  memory: 69.04GiB(72.65%)  wps: 6,390  mfu: 37.42%
    ```


[ghstack-poisoned]
awgu added a commit that referenced this pull request Jul 11, 2024
ghstack-source-id: 9c9e2fef294d9c66ede9f1b8a9ab65510c2ed230
Pull Request resolved: #382
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 15, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

To-do: add a test for shared embedding/output projection passed as list (e.g. change `multi_module` from `bool` to `None`, embedding/norm/output, or norm/output)


cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 15, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

To-do: add a test for shared embedding/output projection passed as list (e.g. change `multi_module` from `bool` to `None`, embedding/norm/output, or norm/output)


cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 15, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```



cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 15, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```



cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 15, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```



cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 15, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```



cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jul 15, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

Pull Request resolved: #127786
Approved by: https://github.com/yf225, https://github.com/weifengpy
ghstack dependencies: #127773
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 17, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

Pull Request resolved: #127786
Approved by: https://github.com/yf225, https://github.com/weifengpy
ghstack dependencies: #127773
ghstack-source-id: c70870546bed6108b163062b88415cbba3d37925
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 17, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```



cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse penguinwu tianyu-l yf225 chauhang

[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 17, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```



cc XilunWu H-Huang kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse penguinwu tianyu-l yf225 chauhang

[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 17, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

[ghstack-poisoned]
awgu added a commit to pytorch/pytorch that referenced this pull request Jul 17, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

ghstack-source-id: c70870546bed6108b163062b88415cbba3d37925
Pull Request resolved: #130949
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jul 17, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

---

**Changes for reland:** none since breakage was from PR below

Pull Request resolved: #130949
Approved by: https://github.com/weifengpy
ghstack dependencies: #130947
DiweiSun pushed a commit to DiweiSun/pytorch that referenced this pull request Jul 22, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

---

**Changes for reland:** none since breakage was from PR below

Pull Request resolved: pytorch#130949
Approved by: https://github.com/weifengpy
ghstack dependencies: pytorch#130947
bigfootjon pushed a commit to pytorch/pytorch that referenced this pull request Jul 24, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

---

**Changes for reland:** none since breakage was from PR below

Pull Request resolved: #130949
Approved by: https://github.com/weifengpy
ghstack dependencies: #130947

(cherry picked from commit 31e3330)
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

Pull Request resolved: pytorch#127786
Approved by: https://github.com/yf225, https://github.com/weifengpy
ghstack dependencies: pytorch#127773
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
This PR allows `fully_shard`'s first argument to be `List[nn.Module]` instead of strictly `nn.Module`. This allows more flexible grouping of modules/parameters for communication, which can lead to memory savings and/or more efficient communication.

**Approach**
At a high level, we can think of a model as a tree of modules. Previously, we could only select specific module nodes in this tree as representing one FSDP parameter group. With this PR, we can select a group of module nodes, effectively becoming a single super node.

To implement the runtime schedule, we define new forward hooks that run based on the following semantics:
- If a module is the first to run the pre-hook, actually run the given pre-hook. Otherwise, the pre-hook is no-op.
- If a module is the last to run the post-hook, actually run the given post-hook. Otherwise, the post-hook is a no-op.
- First and last are determined by scoreboarding against a set of the modules.
- This set must get cleared at the end of backward in the case that >=1 module in the list is never used, in which case we still want the forward hooks to run in the next forward after this backward.

Beyond these new forward hooks, everything else is some simple generalization from `Module` to `List[Module]` or `Tuple[Module, ...]`.

**Examples**
This PR enables wrapping Llama models more efficiently by grouping the final norm and output linear together: pytorch/torchtitan#382.

If at least one of the modules in the list does not run forward before backward, then there will be a warning message like:
```
1 of the 2 modules passed to fully_shard did not run forward before backward, which is error-prone since FSDP post-forward/pre-backward logic will not run for these modules. We recommend passing only modules that run forward together. Modules that did not run forward: [FSDPLinear(in_features=1, out_features=1, bias=True)]
```

---

**Changes for reland:** none since breakage was from PR below

Pull Request resolved: pytorch#130949
Approved by: https://github.com/weifengpy
ghstack dependencies: pytorch#130947
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants