From f66c10dbc2e5c0a6ce8c1417301e4c582f8a9153 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Fri, 6 Sep 2024 19:53:09 -0700 Subject: [PATCH] Update base for Update on "[Not for land] Added changes for GPT-2 perf" Credit: felipemello1 for most of the work here (especially around chunked cross entropy) Running on 4xH100s: Without these changes (`torch.compile`), the max local batch size is 5: ``` [rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200) [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager. [rank0]: warnings.warn( [rank0]:2024-08-19 11:10:33,811 - root - INFO - step: 1 loss: 12.2365 memory: 81.67GiB(85.93%) wps: 5,380 mfu: 1.09% [rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10 loss: 12.1951 memory: 81.67GiB(85.93%) wps: 111,770 mfu: 22.68% [rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20 loss: 11.9455 memory: 81.67GiB(85.93%) wps: 111,714 mfu: 22.67% [rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30 loss: 11.0407 memory: 81.67GiB(85.93%) wps: 112,194 mfu: 22.76% [rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40 loss: 9.9520 memory: 81.67GiB(85.93%) wps: 112,109 mfu: 22.75% [rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50 loss: 9.3392 memory: 81.67GiB(85.93%) wps: 112,218 mfu: 22.77% [rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60 loss: 8.7255 memory: 81.67GiB(85.93%) wps: 112,198 mfu: 22.77% [rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70 loss: 8.1659 memory: 81.67GiB(85.93%) wps: 112,234 mfu: 22.77% [rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80 loss: 7.8037 memory: 81.67GiB(85.93%) wps: 111,802 mfu: 22.68% [rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90 loss: 7.5327 memory: 81.67GiB(85.93%) wps: 111,937 mfu: 22.71% [rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100 loss: 7.3730 memory: 81.67GiB(85.93%) wps: 111,803 mfu: 22.69% ```
Without these changes, no compile Without these changes (no `torch.compile`), local batch size 5: ``` [rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200) [rank0]:2024-08-19 14:24:38,558 - root - INFO - step: 1 loss: 12.2581 memory: 86.47GiB(90.99%) wps: 6,393 mfu: 1.30% [rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10 loss: 12.2099 memory: 86.48GiB(90.99%) wps: 98,305 mfu: 19.95% [rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20 loss: 11.9421 memory: 86.48GiB(90.99%) wps: 98,230 mfu: 19.93% [rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30 loss: 11.0090 memory: 86.48GiB(90.99%) wps: 98,435 mfu: 19.97% [rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40 loss: 9.9780 memory: 86.48GiB(90.99%) wps: 99,064 mfu: 20.10% [rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50 loss: 9.3572 memory: 86.48GiB(90.99%) wps: 98,813 mfu: 20.05% [rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60 loss: 8.7479 memory: 86.48GiB(90.99%) wps: 96,567 mfu: 19.59% [rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70 loss: 8.1769 memory: 86.48GiB(90.99%) wps: 98,604 mfu: 20.01% [rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80 loss: 7.8070 memory: 86.48GiB(90.99%) wps: 98,579 mfu: 20.00% [rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90 loss: 7.5329 memory: 86.48GiB(90.99%) wps: 98,743 mfu: 20.04% [rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100 loss: 7.3700 memory: 86.48GiB(90.99%) wps: 98,818 mfu: 20.05% ```
With these changes (`torch.compile`), local batch size 32: ``` [rank0]:2024-09-06 19:48:58,342 - root - INFO - Training starts at step 1, with local batch size 32, global batch size 128, sequence length 8192, total steps 50 (warmup 200) [rank0]:2024-09-06 19:49:08,904 - root - INFO - step: 1 loss: 12.2442 memory: 79.40GiB(83.54%) wps: 24,819 mfu: 5.04% [rank0]:2024-09-06 19:49:08,904 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-09-06 19:49:23,127 - root - INFO - step: 10 loss: 12.1998 memory: 80.81GiB(85.03%) wps: 165,880 mfu: 33.66% [rank0]:2024-09-06 19:49:38,946 - root - INFO - step: 20 loss: 11.9284 memory: 80.81GiB(85.03%) wps: 165,732 mfu: 33.63% [rank0]:2024-09-06 19:49:54,764 - root - INFO - step: 30 loss: 10.9587 memory: 80.81GiB(85.03%) wps: 165,733 mfu: 33.63% [rank0]:2024-09-06 19:50:10,566 - root - INFO - step: 40 loss: 9.8493 memory: 80.81GiB(85.03%) wps: 165,904 mfu: 33.66% [rank0]:2024-09-06 19:50:26,973 - root - INFO - step: 50 loss: 9.2317 memory: 80.81GiB(85.03%) wps: 159,786 mfu: 32.42% ```
Old Results With these changes, we can use local batch size 16: ``` [rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200) [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager. [rank0]: warnings.warn( [rank0]:2024-08-19 11:16:15,523 - root - INFO - step: 1 loss: 12.2386 memory: 72.29GiB(76.06%) wps: 21,887 mfu: 4.44% [rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10 loss: 12.1966 memory: 72.30GiB(76.07%) wps: 168,174 mfu: 34.12% [rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20 loss: 11.9229 memory: 72.30GiB(76.07%) wps: 168,196 mfu: 34.13% [rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30 loss: 10.9399 memory: 72.30GiB(76.07%) wps: 168,144 mfu: 34.12% [rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40 loss: 9.8742 memory: 72.30GiB(76.07%) wps: 167,898 mfu: 34.07% [rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50 loss: 9.2517 memory: 72.30GiB(76.07%) wps: 168,130 mfu: 34.11% [rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60 loss: 8.6441 memory: 72.30GiB(76.07%) wps: 168,435 mfu: 34.18% [rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70 loss: 8.0827 memory: 72.30GiB(76.07%) wps: 168,927 mfu: 34.28% [rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80 loss: 7.7330 memory: 72.30GiB(76.07%) wps: 168,772 mfu: 34.24% [rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90 loss: 7.4835 memory: 72.30GiB(76.07%) wps: 162,008 mfu: 32.87% [rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100 loss: 7.3274 memory: 72.30GiB(76.07%) wps: 167,963 mfu: 34.08% ``` 22.7% MFU -> 34.1% MFU
[ghstack-poisoned]