Skip to content

Commit

Permalink
Update on "[Not for land] Added changes for GPT-2 perf"
Browse files Browse the repository at this point in the history
Credit: felipemello1 for most of the work here (especially around chunked cross entropy)

Running on 4xH100s:
Without these changes (`torch.compile`), the max local batch size is 5:
```
[rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-19 11:10:33,811 - root - INFO - step:  1  loss: 12.2365  memory: 81.67GiB(85.93%)  wps: 5,380  mfu: 1.09%
[rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10  loss: 12.1951  memory: 81.67GiB(85.93%)  wps: 111,770  mfu: 22.68%
[rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20  loss: 11.9455  memory: 81.67GiB(85.93%)  wps: 111,714  mfu: 22.67%
[rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30  loss: 11.0407  memory: 81.67GiB(85.93%)  wps: 112,194  mfu: 22.76%
[rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40  loss:  9.9520  memory: 81.67GiB(85.93%)  wps: 112,109  mfu: 22.75%
[rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50  loss:  9.3392  memory: 81.67GiB(85.93%)  wps: 112,218  mfu: 22.77%
[rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60  loss:  8.7255  memory: 81.67GiB(85.93%)  wps: 112,198  mfu: 22.77%
[rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70  loss:  8.1659  memory: 81.67GiB(85.93%)  wps: 112,234  mfu: 22.77%
[rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80  loss:  7.8037  memory: 81.67GiB(85.93%)  wps: 111,802  mfu: 22.68%
[rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90  loss:  7.5327  memory: 81.67GiB(85.93%)  wps: 111,937  mfu: 22.71%
[rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100  loss:  7.3730  memory: 81.67GiB(85.93%)  wps: 111,803  mfu: 22.69%
```

<details>

<summary> Without these changes, no compile </summary>

Without these changes (no `torch.compile`), local batch size 5:
```
[rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 14:24:38,558 - root - INFO - step:  1  loss: 12.2581  memory: 86.47GiB(90.99%)  wps: 6,393  mfu: 1.30%
[rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10  loss: 12.2099  memory: 86.48GiB(90.99%)  wps: 98,305  mfu: 19.95%
[rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20  loss: 11.9421  memory: 86.48GiB(90.99%)  wps: 98,230  mfu: 19.93%
[rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30  loss: 11.0090  memory: 86.48GiB(90.99%)  wps: 98,435  mfu: 19.97%
[rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40  loss:  9.9780  memory: 86.48GiB(90.99%)  wps: 99,064  mfu: 20.10%
[rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50  loss:  9.3572  memory: 86.48GiB(90.99%)  wps: 98,813  mfu: 20.05%
[rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60  loss:  8.7479  memory: 86.48GiB(90.99%)  wps: 96,567  mfu: 19.59%
[rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70  loss:  8.1769  memory: 86.48GiB(90.99%)  wps: 98,604  mfu: 20.01%
[rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80  loss:  7.8070  memory: 86.48GiB(90.99%)  wps: 98,579  mfu: 20.00%
[rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90  loss:  7.5329  memory: 86.48GiB(90.99%)  wps: 98,743  mfu: 20.04%
[rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100  loss:  7.3700  memory: 86.48GiB(90.99%)  wps: 98,818  mfu: 20.05%
```
</details>

With these changes (`torch.compile`), local batch size 32:
```
[rank0]:2024-09-06 19:48:58,342 - root - INFO - Training starts at step 1, with local batch size 32, global batch size 128, sequence length 8192, total steps 50 (warmup 200)
[rank0]:2024-09-06 19:49:08,904 - root - INFO - step:  1  loss: 12.2442  memory: 79.40GiB(83.54%)  wps: 24,819  mfu: 5.04%
[rank0]:2024-09-06 19:49:08,904 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-09-06 19:49:23,127 - root - INFO - step: 10  loss: 12.1998  memory: 80.81GiB(85.03%)  wps: 165,880  mfu: 33.66%
[rank0]:2024-09-06 19:49:38,946 - root - INFO - step: 20  loss: 11.9284  memory: 80.81GiB(85.03%)  wps: 165,732  mfu: 33.63%
[rank0]:2024-09-06 19:49:54,764 - root - INFO - step: 30  loss: 10.9587  memory: 80.81GiB(85.03%)  wps: 165,733  mfu: 33.63%
[rank0]:2024-09-06 19:50:10,566 - root - INFO - step: 40  loss:  9.8493  memory: 80.81GiB(85.03%)  wps: 165,904  mfu: 33.66%
[rank0]:2024-09-06 19:50:26,973 - root - INFO - step: 50  loss:  9.2317  memory: 80.81GiB(85.03%)  wps: 159,786  mfu: 32.42%
```


<details>
<summary> Old Results </summary>

With these changes, we can use local batch size 16:
```
[rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-19 11:16:15,523 - root - INFO - step:  1  loss: 12.2386  memory: 72.29GiB(76.06%)  wps: 21,887  mfu: 4.44%
[rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10  loss: 12.1966  memory: 72.30GiB(76.07%)  wps: 168,174  mfu: 34.12%
[rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20  loss: 11.9229  memory: 72.30GiB(76.07%)  wps: 168,196  mfu: 34.13%
[rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30  loss: 10.9399  memory: 72.30GiB(76.07%)  wps: 168,144  mfu: 34.12%
[rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40  loss:  9.8742  memory: 72.30GiB(76.07%)  wps: 167,898  mfu: 34.07%
[rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50  loss:  9.2517  memory: 72.30GiB(76.07%)  wps: 168,130  mfu: 34.11%
[rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60  loss:  8.6441  memory: 72.30GiB(76.07%)  wps: 168,435  mfu: 34.18%
[rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70  loss:  8.0827  memory: 72.30GiB(76.07%)  wps: 168,927  mfu: 34.28%
[rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80  loss:  7.7330  memory: 72.30GiB(76.07%)  wps: 168,772  mfu: 34.24%
[rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90  loss:  7.4835  memory: 72.30GiB(76.07%)  wps: 162,008  mfu: 32.87%
[rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100  loss:  7.3274  memory: 72.30GiB(76.07%)  wps: 167,963  mfu: 34.08%
```

22.7% MFU -> 34.1% MFU

</details>

[ghstack-poisoned]
  • Loading branch information
awgu committed Sep 7, 2024
2 parents b6a84e4 + f66c10d commit 8b27669
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchtitan/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def accumulate_chunk(input_chunk, target_chunk):

input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
target_chunks = torch.chunk(target, chunks=chunks, dim=0)
for input_chunk, target_chunk in zip(input_chunks, target_chunks):
for input_chunk, target_chunk in zip(input_chunks, target_chunks, strict=True):
grad_inputs.append(accumulate_chunk(input_chunk, target_chunk))

ctx.save_for_backward(
Expand Down

0 comments on commit 8b27669

Please sign in to comment.