enable TritonFusedRMSNorm with local_map annotation #404
Conversation
…nnotation"

**Test Plan**

Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`):

1. with `norm_type = "rmsnorm"`

```
[rank0]:2024-06-05 11:57:35,505 - root - INFO - step: 1 loss: 12.2703 memory: 24.66GiB(31.15%) wps: 143 mfu: 2.66%
[rank0]:2024-06-05 11:57:35,505 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-05 11:58:11,490 - root - INFO - step: 10 loss: 11.0446 memory: 31.96GiB(40.37%) wps: 512 mfu: 9.51%
[rank0]:2024-06-05 11:58:46,488 - root - INFO - step: 20 loss: 9.2321 memory: 31.96GiB(40.37%) wps: 586 mfu: 10.87%
[rank0]:2024-06-05 11:59:22,462 - root - INFO - step: 30 loss: 8.2184 memory: 31.96GiB(40.37%) wps: 570 mfu: 10.58%
[rank0]:2024-06-05 11:59:57,301 - root - INFO - step: 40 loss: 7.6220 memory: 31.96GiB(40.37%) wps: 589 mfu: 10.93%
[rank0]:2024-06-05 12:00:32,254 - root - INFO - step: 50 loss: 7.5399 memory: 31.96GiB(40.37%) wps: 587 mfu: 10.89%
[rank0]:2024-06-05 12:01:07,155 - root - INFO - step: 60 loss: 7.3179 memory: 31.96GiB(40.37%) wps: 588 mfu: 10.91%
[rank0]:2024-06-05 12:01:41,999 - root - INFO - step: 70 loss: 7.3508 memory: 31.96GiB(40.37%) wps: 589 mfu: 10.92%
[rank0]:2024-06-05 12:02:17,093 - root - INFO - step: 80 loss: 7.2696 memory: 31.96GiB(40.37%) wps: 584 mfu: 10.85%
[rank0]:2024-06-05 12:02:52,009 - root - INFO - step: 90 loss: 7.0481 memory: 31.96GiB(40.37%) wps: 588 mfu: 10.91%
[rank0]:2024-06-05 12:03:27,715 - root - INFO - step: 100 loss: 6.9623 memory: 31.96GiB(40.37%) wps: 575 mfu: 10.67%
```

2. with `norm_type = "fused_rmsnorm"`

```
[rank0]:2024-06-05 12:08:35,004 - root - INFO - step: 1 loss: 12.2422 memory: 24.62GiB(31.10%) wps: 95 mfu: 1.76%
[rank0]:2024-06-05 12:08:35,004 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-05 12:09:12,401 - root - INFO - step: 10 loss: 11.0361 memory: 32.09GiB(40.54%) wps: 493 mfu: 9.15%
[rank0]:2024-06-05 12:09:49,380 - root - INFO - step: 20 loss: 9.2725 memory: 32.09GiB(40.54%) wps: 554 mfu: 10.29%
[rank0]:2024-06-05 12:10:26,645 - root - INFO - step: 30 loss: 8.2091 memory: 32.09GiB(40.54%) wps: 550 mfu: 10.21%
[rank0]:2024-06-05 12:11:03,616 - root - INFO - step: 40 loss: 7.5601 memory: 32.09GiB(40.54%) wps: 555 mfu: 10.30%
[rank0]:2024-06-05 12:11:40,625 - root - INFO - step: 50 loss: 7.5144 memory: 32.09GiB(40.54%) wps: 554 mfu: 10.29%
[rank0]:2024-06-05 12:12:17,768 - root - INFO - step: 60 loss: 7.3869 memory: 32.09GiB(40.54%) wps: 552 mfu: 10.25%
[rank0]:2024-06-05 12:12:54,820 - root - INFO - step: 70 loss: 7.3358 memory: 32.09GiB(40.54%) wps: 553 mfu: 10.27%
[rank0]:2024-06-05 12:13:31,817 - root - INFO - step: 80 loss: 7.2085 memory: 32.09GiB(40.54%) wps: 554 mfu: 10.29%
[rank0]:2024-06-05 12:14:09,156 - root - INFO - step: 90 loss: 7.0140 memory: 32.09GiB(40.54%) wps: 549 mfu: 10.19%
[rank0]:2024-06-05 12:14:48,518 - root - INFO - step: 100 loss: 6.9507 memory: 32.09GiB(40.54%) wps: 521 mfu: 9.67%
```

[ghstack-poisoned]
…nnotation"

**Summary**

This PR enables using TritonFusedRMSNorm with Tensor Parallel, with a 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**

Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`):

1. with `norm_type = "rmsnorm"`

```
[rank2]:2024-06-12 13:55:25,005 - root - INFO - step: 1 loss: 12.2971 memory: 23.68GiB(29.92%) wps: 258 mfu: 4.79%
[rank2]:2024-06-12 13:55:43,082 - root - INFO - step: 5 loss: 11.6237 memory: 30.98GiB(39.14%) wps: 453 mfu: 8.41%
[rank2]:2024-06-12 13:56:00,742 - root - INFO - step: 10 loss: 10.7210 memory: 30.98GiB(39.14%) wps: 580 mfu: 10.77%
[rank2]:2024-06-12 13:56:18,274 - root - INFO - step: 15 loss: 9.4563 memory: 30.98GiB(39.14%) wps: 585 mfu: 10.85%
[rank2]:2024-06-12 13:56:35,888 - root - INFO - step: 20 loss: 8.9246 memory: 30.98GiB(39.14%) wps: 582 mfu: 10.80%
```

2. with `norm_type = "fused_rmsnorm"`

```
[rank2]:2024-06-12 13:52:48,671 - root - INFO - step: 1 loss: 12.2779 memory: 23.64GiB(29.86%) wps: 186 mfu: 3.45%
[rank2]:2024-06-12 13:53:06,983 - root - INFO - step: 5 loss: 11.6073 memory: 31.11GiB(39.31%) wps: 447 mfu: 8.30%
[rank2]:2024-06-12 13:53:23,895 - root - INFO - step: 10 loss: 10.6355 memory: 31.11GiB(39.31%) wps: 606 mfu: 11.25%
[rank2]:2024-06-12 13:53:41,108 - root - INFO - step: 15 loss: 9.5591 memory: 31.11GiB(39.31%) wps: 596 mfu: 11.05%
[rank2]:2024-06-12 13:53:58,045 - root - INFO - step: 20 loss: 9.0287 memory: 31.11GiB(39.31%) wps: 605 mfu: 11.23%
```

[ghstack-poisoned]
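The `local_map` idea can be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's actual code: it substitutes a plain eager RMSNorm for the Triton kernel, and it assumes the experimental `torch.distributed._tensor.experimental.local_map` API, whose exact signature may differ across PyTorch versions (hence the guarded import).

```python
import torch


def rmsnorm_local(x, weight):
    # Plain eager RMSNorm over the last dim; stands in here for the Triton
    # fused kernel, which computes the same math in a single pass.
    eps = 1e-6
    var = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(var + eps) * weight


# local_map lets a function written against plain torch.Tensors accept
# DTensor inputs by annotating how each argument and output is placed on
# the device mesh. With sequence parallelism the activation is sharded on
# the sequence dim while the norm weight is replicated; RMSNorm reduces
# only over the (unsharded) last dim, so each rank can run the local
# kernel independently with no communication.
try:
    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed._tensor.experimental import local_map

    rmsnorm_tp = local_map(
        rmsnorm_local,
        out_placements=[Shard(1)],                      # output sharded on seq dim
        in_placements=([Shard(1)], [Replicate()]),      # (x, weight)
    )
except Exception:
    # Experimental API (assumed here); fall back to the eager function.
    rmsnorm_tp = rmsnorm_local
```

The key property exploited is that the normalization axis (the model dim) is never sharded, so annotating placements is enough to make the single-device kernel correct under TP.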
…nnotation"

**Summary**

This PR enables using TritonFusedRMSNorm with Tensor Parallel, with a 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**

Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:

```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```

1. with `norm_type = "rmsnorm"`

```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step: 1 loss: 12.2262 memory: 57.70GiB(72.89%) wps: 429 mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step: 5 loss: 11.4801 memory: 65.53GiB(82.78%) wps: 529 mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10 loss: 10.2305 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15 loss: 9.3287 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20 loss: 8.7126 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25 loss: 8.2011 memory: 65.53GiB(82.78%) wps: 591 mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30 loss: 7.7424 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35 loss: 7.4964 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40 loss: 7.2799 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45 loss: 7.2280 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55 loss: 6.9967 memory: 65.53GiB(82.78%) wps: 595 mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60 loss: 7.0763 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65 loss: 6.9260 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70 loss: 6.9757 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75 loss: 6.8074 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80 loss: 6.7362 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85 loss: 6.7016 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90 loss: 6.6640 memory: 65.53GiB(82.78%) wps: 596 mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95 loss: 6.7214 memory: 65.53GiB(82.78%) wps: 604 mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`

```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step: 1 loss: 12.2194 memory: 57.31GiB(72.40%) wps: 412 mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step: 5 loss: 11.4519 memory: 65.13GiB(82.29%) wps: 590 mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10 loss: 10.2199 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15 loss: 9.3509 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20 loss: 8.7972 memory: 65.13GiB(82.29%) wps: 629 mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25 loss: 8.2348 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30 loss: 7.7037 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35 loss: 7.4639 memory: 65.13GiB(82.29%) wps: 641 mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40 loss: 7.2406 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45 loss: 7.1822 memory: 65.13GiB(82.29%) wps: 640 mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55 loss: 6.9888 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60 loss: 7.0387 memory: 65.13GiB(82.29%) wps: 638 mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65 loss: 6.9199 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70 loss: 6.9503 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75 loss: 6.7960 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80 loss: 6.6798 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85 loss: 6.6504 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90 loss: 6.6655 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95 loss: 6.7359 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85%
```

[ghstack-poisoned]
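The throughput comparison behind numbers like those above can be reproduced by averaging the `wps` field of the two runs' logs. A small hypothetical helper (not part of the repo), using the log format shown; the exact percentage depends on which steps and runs are averaged:

```python
import re


def wps_values(log_text):
    # Extract the tokens-per-second ("wps") field from each log line.
    return [int(m) for m in re.findall(r"wps:\s*(\d+)", log_text)]


# Two sample steady-state lines from each run above (step 1 is excluded,
# since it includes warmup/compilation overhead).
rmsnorm_log = """
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
"""
fused_log = """
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85%
"""

base = sum(wps_values(rmsnorm_log)) / 2   # 601.0
fused = sum(wps_values(fused_log)) / 2    # 638.0
speedup = fused / base - 1                # ~0.06 on these two sampled steps
```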
…nnotation" **Summary** This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. **Test Plan** Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings: ``` [job] dump_folder = "./outputs" description = "Llama 3 8B training" [profiling] enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 100 [metrics] log_freq = 5 enable_tensorboard = false save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm] tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model" [optimizer] name = "AdamW" lr = 3e-4 [training] batch_size = 4 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 100 data_parallel_degree = -1 tensor_parallel_degree = 4 pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" [checkpoint] enable_checkpoint = false folder = "checkpoint" interval_type = "steps" interval = 500 model_weights_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy ``` 1. 
with `norm_type = "rmsnorm"` ``` [rank2]:2024-06-13 00:47:55,607 - root - INFO - step: 1 loss: 12.2262 memory: 57.70GiB(72.89%) wps: 429 mfu: 7.96% [rank2]:2024-06-13 00:48:57,536 - root - INFO - step: 5 loss: 11.4801 memory: 65.53GiB(82.78%) wps: 529 mfu: 9.82% [rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10 loss: 10.2305 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15% [rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15 loss: 9.3287 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.09% [rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20 loss: 8.7126 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.19% [rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25 loss: 8.2011 memory: 65.53GiB(82.78%) wps: 591 mfu: 10.98% [rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30 loss: 7.7424 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35 loss: 7.4964 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20% [rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40 loss: 7.2799 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.16% [rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45 loss: 7.2280 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.10% [rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17% [rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55 loss: 6.9967 memory: 65.53GiB(82.78%) wps: 595 mfu: 11.04% [rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60 loss: 7.0763 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65 loss: 6.9260 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20% [rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70 loss: 6.9757 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15% [rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75 loss: 6.8074 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80 loss: 6.7362 memory: 65.53GiB(82.78%) wps: 597 mfu: 
11.08% [rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85 loss: 6.7016 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.09% [rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90 loss: 6.6640 memory: 65.53GiB(82.78%) wps: 596 mfu: 11.06% [rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95 loss: 6.7214 memory: 65.53GiB(82.78%) wps: 604 mfu: 11.20% [rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% ``` 2. with `norm_type = "fused_rmsnorm"` ``` [rank2]:2024-06-13 00:19:33,609 - root - INFO - step: 1 loss: 12.2194 memory: 57.31GiB(72.40%) wps: 412 mfu: 7.64% [rank2]:2024-06-13 00:20:29,175 - root - INFO - step: 5 loss: 11.4519 memory: 65.13GiB(82.29%) wps: 590 mfu: 10.95% [rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10 loss: 10.2199 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.79% [rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15 loss: 9.3509 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.92% [rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20 loss: 8.7972 memory: 65.13GiB(82.29%) wps: 629 mfu: 11.68% [rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25 loss: 8.2348 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.91% [rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30 loss: 7.7037 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80% [rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35 loss: 7.4639 memory: 65.13GiB(82.29%) wps: 641 mfu: 11.90% [rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40 loss: 7.2406 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80% [rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45 loss: 7.1822 memory: 65.13GiB(82.29%) wps: 640 mfu: 11.87% [rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55 loss: 6.9888 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60 loss: 7.0387 memory: 
65.13GiB(82.29%) wps: 638 mfu: 11.84% [rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65 loss: 6.9199 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70 loss: 6.9503 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75 loss: 6.7960 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80 loss: 6.6798 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85 loss: 6.6504 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90 loss: 6.6655 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.81% [rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95 loss: 6.7359 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.78% [rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85% ``` [ghstack-poisoned]
…nnotation" **Summary** This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. **Test Plan** Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings: ``` [job] dump_folder = "./outputs" description = "Llama 3 8B training" [profiling] enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 100 [metrics] log_freq = 5 enable_tensorboard = false save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm] tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model" [optimizer] name = "AdamW" lr = 3e-4 [training] batch_size = 4 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 100 data_parallel_degree = -1 tensor_parallel_degree = 4 pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" [checkpoint] enable_checkpoint = false folder = "checkpoint" interval_type = "steps" interval = 500 model_weights_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy ``` 1. 
with `norm_type = "rmsnorm"` ``` [rank2]:2024-06-13 00:47:55,607 - root - INFO - step: 1 loss: 12.2262 memory: 57.70GiB(72.89%) wps: 429 mfu: 7.96% [rank2]:2024-06-13 00:48:57,536 - root - INFO - step: 5 loss: 11.4801 memory: 65.53GiB(82.78%) wps: 529 mfu: 9.82% [rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10 loss: 10.2305 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15% [rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15 loss: 9.3287 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.09% [rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20 loss: 8.7126 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.19% [rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25 loss: 8.2011 memory: 65.53GiB(82.78%) wps: 591 mfu: 10.98% [rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30 loss: 7.7424 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35 loss: 7.4964 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20% [rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40 loss: 7.2799 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.16% [rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45 loss: 7.2280 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.10% [rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17% [rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55 loss: 6.9967 memory: 65.53GiB(82.78%) wps: 595 mfu: 11.04% [rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60 loss: 7.0763 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65 loss: 6.9260 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20% [rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70 loss: 6.9757 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15% [rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75 loss: 6.8074 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80 loss: 6.7362 memory: 65.53GiB(82.78%) wps: 597 mfu: 
11.08% [rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85 loss: 6.7016 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.09% [rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90 loss: 6.6640 memory: 65.53GiB(82.78%) wps: 596 mfu: 11.06% [rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95 loss: 6.7214 memory: 65.53GiB(82.78%) wps: 604 mfu: 11.20% [rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% ``` 2. with `norm_type = "fused_rmsnorm"` ``` [rank2]:2024-06-13 00:19:33,609 - root - INFO - step: 1 loss: 12.2194 memory: 57.31GiB(72.40%) wps: 412 mfu: 7.64% [rank2]:2024-06-13 00:20:29,175 - root - INFO - step: 5 loss: 11.4519 memory: 65.13GiB(82.29%) wps: 590 mfu: 10.95% [rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10 loss: 10.2199 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.79% [rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15 loss: 9.3509 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.92% [rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20 loss: 8.7972 memory: 65.13GiB(82.29%) wps: 629 mfu: 11.68% [rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25 loss: 8.2348 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.91% [rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30 loss: 7.7037 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80% [rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35 loss: 7.4639 memory: 65.13GiB(82.29%) wps: 641 mfu: 11.90% [rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40 loss: 7.2406 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80% [rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45 loss: 7.1822 memory: 65.13GiB(82.29%) wps: 640 mfu: 11.87% [rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55 loss: 6.9888 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60 loss: 7.0387 memory: 
65.13GiB(82.29%) wps: 638 mfu: 11.84% [rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65 loss: 6.9199 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70 loss: 6.9503 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75 loss: 6.7960 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80 loss: 6.6798 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85 loss: 6.6504 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90 loss: 6.6655 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.81% [rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95 loss: 6.7359 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.78% [rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85% ``` [ghstack-poisoned]
…nnotation" **Summary** This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. **Test Plan** Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings: ``` [job] dump_folder = "./outputs" description = "Llama 3 8B training" [profiling] enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 100 [metrics] log_freq = 5 enable_tensorboard = false save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm] tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model" [optimizer] name = "AdamW" lr = 3e-4 [training] batch_size = 4 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 100 data_parallel_degree = -1 tensor_parallel_degree = 4 pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" [checkpoint] enable_checkpoint = false folder = "checkpoint" interval_type = "steps" interval = 500 model_weights_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy ``` 1. 
with `norm_type = "rmsnorm"`

```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step: 1 loss: 12.2262 memory: 57.70GiB(72.89%) wps: 429 mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step: 5 loss: 11.4801 memory: 65.53GiB(82.78%) wps: 529 mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10 loss: 10.2305 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15 loss: 9.3287 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20 loss: 8.7126 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25 loss: 8.2011 memory: 65.53GiB(82.78%) wps: 591 mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30 loss: 7.7424 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35 loss: 7.4964 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40 loss: 7.2799 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45 loss: 7.2280 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55 loss: 6.9967 memory: 65.53GiB(82.78%) wps: 595 mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60 loss: 7.0763 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65 loss: 6.9260 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70 loss: 6.9757 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75 loss: 6.8074 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80 loss: 6.7362 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85 loss: 6.7016 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90 loss: 6.6640 memory: 65.53GiB(82.78%) wps: 596 mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95 loss: 6.7214 memory: 65.53GiB(82.78%) wps: 604 mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`

```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step: 1 loss: 12.2194 memory: 57.31GiB(72.40%) wps: 412 mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step: 5 loss: 11.4519 memory: 65.13GiB(82.29%) wps: 590 mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10 loss: 10.2199 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15 loss: 9.3509 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20 loss: 8.7972 memory: 65.13GiB(82.29%) wps: 629 mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25 loss: 8.2348 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30 loss: 7.7037 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35 loss: 7.4639 memory: 65.13GiB(82.29%) wps: 641 mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40 loss: 7.2406 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45 loss: 7.1822 memory: 65.13GiB(82.29%) wps: 640 mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55 loss: 6.9888 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60 loss: 7.0387 memory: 65.13GiB(82.29%) wps: 638 mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65 loss: 6.9199 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70 loss: 6.9503 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75 loss: 6.7960 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80 loss: 6.6798 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85 loss: 6.6504 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90 loss: 6.6655 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95 loss: 6.7359 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85%
```

[ghstack-poisoned]
…nnotation"

**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, yielding a 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:

```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```

1. with `norm_type = "rmsnorm"`

```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step: 1 loss: 12.2262 memory: 57.70GiB(72.89%) wps: 429 mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step: 5 loss: 11.4801 memory: 65.53GiB(82.78%) wps: 529 mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10 loss: 10.2305 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15 loss: 9.3287 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20 loss: 8.7126 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25 loss: 8.2011 memory: 65.53GiB(82.78%) wps: 591 mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30 loss: 7.7424 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35 loss: 7.4964 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40 loss: 7.2799 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45 loss: 7.2280 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55 loss: 6.9967 memory: 65.53GiB(82.78%) wps: 595 mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60 loss: 7.0763 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65 loss: 6.9260 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70 loss: 6.9757 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75 loss: 6.8074 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80 loss: 6.7362 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85 loss: 6.7016 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90 loss: 6.6640 memory: 65.53GiB(82.78%) wps: 596 mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95 loss: 6.7214 memory: 65.53GiB(82.78%) wps: 604 mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`

```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step: 1 loss: 12.2194 memory: 57.31GiB(72.40%) wps: 412 mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step: 5 loss: 11.4519 memory: 65.13GiB(82.29%) wps: 590 mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10 loss: 10.2199 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15 loss: 9.3509 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20 loss: 8.7972 memory: 65.13GiB(82.29%) wps: 629 mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25 loss: 8.2348 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30 loss: 7.7037 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35 loss: 7.4639 memory: 65.13GiB(82.29%) wps: 641 mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40 loss: 7.2406 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45 loss: 7.1822 memory: 65.13GiB(82.29%) wps: 640 mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55 loss: 6.9888 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60 loss: 7.0387 memory: 65.13GiB(82.29%) wps: 638 mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65 loss: 6.9199 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70 loss: 6.9503 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75 loss: 6.7960 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80 loss: 6.6798 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85 loss: 6.6504 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90 loss: 6.6655 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95 loss: 6.7359 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85%
```

[ghstack-poisoned]
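For reference, both norm variants above compute the same math: each element is scaled by the reciprocal root-mean-square of its vector, then multiplied by a learned weight; the fused Triton kernel just performs this in a single pass instead of separate ops. A minimal pure-Python sketch of that computation (the function name, `eps` value, and list-based inputs are illustrative, not torchtitan's actual kernel code):

```python
import math

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: y_i = weight_i * x_i / sqrt(mean(x^2) + eps).
    # Illustrative reference only -- the Triton kernel fuses the same math.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]

out = rms_norm([1.0, 2.0, 3.0], [1.0, 1.0, 1.0])
```

With unit weights the output has (approximately) unit mean square, which is a quick sanity check for any fused implementation.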
…nnotation" **Summary** This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. **Test Plan** Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings: ``` [job] dump_folder = "./outputs" description = "Llama 3 8B training" [profiling] enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 100 [metrics] log_freq = 5 enable_tensorboard = false save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm] tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model" [optimizer] name = "AdamW" lr = 3e-4 [training] batch_size = 4 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 100 data_parallel_degree = -1 tensor_parallel_degree = 4 pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" [checkpoint] enable_checkpoint = false folder = "checkpoint" interval_type = "steps" interval = 500 model_weights_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy ``` 1. 
with `norm_type = "rmsnorm"` ``` [rank2]:2024-06-13 00:47:55,607 - root - INFO - step: 1 loss: 12.2262 memory: 57.70GiB(72.89%) wps: 429 mfu: 7.96% [rank2]:2024-06-13 00:48:57,536 - root - INFO - step: 5 loss: 11.4801 memory: 65.53GiB(82.78%) wps: 529 mfu: 9.82% [rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10 loss: 10.2305 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15% [rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15 loss: 9.3287 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.09% [rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20 loss: 8.7126 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.19% [rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25 loss: 8.2011 memory: 65.53GiB(82.78%) wps: 591 mfu: 10.98% [rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30 loss: 7.7424 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35 loss: 7.4964 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20% [rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40 loss: 7.2799 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.16% [rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45 loss: 7.2280 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.10% [rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17% [rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55 loss: 6.9967 memory: 65.53GiB(82.78%) wps: 595 mfu: 11.04% [rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60 loss: 7.0763 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65 loss: 6.9260 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20% [rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70 loss: 6.9757 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15% [rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75 loss: 6.8074 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80 loss: 6.7362 memory: 65.53GiB(82.78%) wps: 597 mfu: 
11.08% [rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85 loss: 6.7016 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.09% [rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90 loss: 6.6640 memory: 65.53GiB(82.78%) wps: 596 mfu: 11.06% [rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95 loss: 6.7214 memory: 65.53GiB(82.78%) wps: 604 mfu: 11.20% [rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% ``` 2. with `norm_type = "fused_rmsnorm"` ``` [rank2]:2024-06-13 00:19:33,609 - root - INFO - step: 1 loss: 12.2194 memory: 57.31GiB(72.40%) wps: 412 mfu: 7.64% [rank2]:2024-06-13 00:20:29,175 - root - INFO - step: 5 loss: 11.4519 memory: 65.13GiB(82.29%) wps: 590 mfu: 10.95% [rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10 loss: 10.2199 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.79% [rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15 loss: 9.3509 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.92% [rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20 loss: 8.7972 memory: 65.13GiB(82.29%) wps: 629 mfu: 11.68% [rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25 loss: 8.2348 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.91% [rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30 loss: 7.7037 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80% [rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35 loss: 7.4639 memory: 65.13GiB(82.29%) wps: 641 mfu: 11.90% [rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40 loss: 7.2406 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80% [rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45 loss: 7.1822 memory: 65.13GiB(82.29%) wps: 640 mfu: 11.87% [rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55 loss: 6.9888 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60 loss: 7.0387 memory: 
65.13GiB(82.29%) wps: 638 mfu: 11.84% [rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65 loss: 6.9199 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70 loss: 6.9503 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75 loss: 6.7960 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80 loss: 6.6798 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85 loss: 6.6504 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90 loss: 6.6655 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.81% [rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95 loss: 6.7359 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.78% [rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85% ``` [ghstack-poisoned]
…nnotation" **Summary** This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. **Test Plan** Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings: ``` [job] dump_folder = "./outputs" description = "Llama 3 8B training" [profiling] enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 100 [metrics] log_freq = 5 enable_tensorboard = false save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm] tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model" [optimizer] name = "AdamW" lr = 3e-4 [training] batch_size = 4 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 100 data_parallel_degree = -1 tensor_parallel_degree = 4 pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" [checkpoint] enable_checkpoint = false folder = "checkpoint" interval_type = "steps" interval = 500 model_weights_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy ``` 1. 
with `norm_type = "rmsnorm"` ``` [rank2]:2024-06-13 00:47:55,607 - root - INFO - step: 1 loss: 12.2262 memory: 57.70GiB(72.89%) wps: 429 mfu: 7.96% [rank2]:2024-06-13 00:48:57,536 - root - INFO - step: 5 loss: 11.4801 memory: 65.53GiB(82.78%) wps: 529 mfu: 9.82% [rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10 loss: 10.2305 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15% [rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15 loss: 9.3287 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.09% [rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20 loss: 8.7126 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.19% [rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25 loss: 8.2011 memory: 65.53GiB(82.78%) wps: 591 mfu: 10.98% [rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30 loss: 7.7424 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35 loss: 7.4964 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20% [rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40 loss: 7.2799 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.16% [rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45 loss: 7.2280 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.10% [rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17% [rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55 loss: 6.9967 memory: 65.53GiB(82.78%) wps: 595 mfu: 11.04% [rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60 loss: 7.0763 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65 loss: 6.9260 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20% [rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70 loss: 6.9757 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15% [rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75 loss: 6.8074 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80 loss: 6.7362 memory: 65.53GiB(82.78%) wps: 597 mfu: 
11.08% [rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85 loss: 6.7016 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.09% [rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90 loss: 6.6640 memory: 65.53GiB(82.78%) wps: 596 mfu: 11.06% [rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95 loss: 6.7214 memory: 65.53GiB(82.78%) wps: 604 mfu: 11.20% [rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% ``` 2. with `norm_type = "fused_rmsnorm"` ``` [rank2]:2024-06-13 00:19:33,609 - root - INFO - step: 1 loss: 12.2194 memory: 57.31GiB(72.40%) wps: 412 mfu: 7.64% [rank2]:2024-06-13 00:20:29,175 - root - INFO - step: 5 loss: 11.4519 memory: 65.13GiB(82.29%) wps: 590 mfu: 10.95% [rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10 loss: 10.2199 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.79% [rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15 loss: 9.3509 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.92% [rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20 loss: 8.7972 memory: 65.13GiB(82.29%) wps: 629 mfu: 11.68% [rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25 loss: 8.2348 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.91% [rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30 loss: 7.7037 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80% [rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35 loss: 7.4639 memory: 65.13GiB(82.29%) wps: 641 mfu: 11.90% [rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40 loss: 7.2406 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80% [rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45 loss: 7.1822 memory: 65.13GiB(82.29%) wps: 640 mfu: 11.87% [rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55 loss: 6.9888 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60 loss: 7.0387 memory: 
65.13GiB(82.29%) wps: 638 mfu: 11.84% [rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65 loss: 6.9199 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70 loss: 6.9503 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75 loss: 6.7960 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80 loss: 6.6798 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85 loss: 6.6504 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90 loss: 6.6655 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.81% [rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95 loss: 6.7359 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.78% [rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85% ``` [ghstack-poisoned]
…nnotation" **Summary** This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. **Test Plan** Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings: ``` [job] dump_folder = "./outputs" description = "Llama 3 8B training" [profiling] enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 100 [metrics] log_freq = 5 enable_tensorboard = false save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm] tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model" [optimizer] name = "AdamW" lr = 3e-4 [training] batch_size = 4 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 100 data_parallel_degree = -1 tensor_parallel_degree = 4 pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" [checkpoint] enable_checkpoint = false folder = "checkpoint" interval_type = "steps" interval = 500 model_weights_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy ``` 1. 
with `norm_type = "rmsnorm"` ``` [rank2]:2024-06-13 00:47:55,607 - root - INFO - step: 1 loss: 12.2262 memory: 57.70GiB(72.89%) wps: 429 mfu: 7.96% [rank2]:2024-06-13 00:48:57,536 - root - INFO - step: 5 loss: 11.4801 memory: 65.53GiB(82.78%) wps: 529 mfu: 9.82% [rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10 loss: 10.2305 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15% [rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15 loss: 9.3287 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.09% [rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20 loss: 8.7126 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.19% [rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25 loss: 8.2011 memory: 65.53GiB(82.78%) wps: 591 mfu: 10.98% [rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30 loss: 7.7424 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35 loss: 7.4964 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20% [rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40 loss: 7.2799 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.16% [rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45 loss: 7.2280 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.10% [rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17% [rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55 loss: 6.9967 memory: 65.53GiB(82.78%) wps: 595 mfu: 11.04% [rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60 loss: 7.0763 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65 loss: 6.9260 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20% [rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70 loss: 6.9757 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15% [rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75 loss: 6.8074 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% [rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80 loss: 6.7362 memory: 65.53GiB(82.78%) wps: 597 mfu: 
11.08% [rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85 loss: 6.7016 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.09% [rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90 loss: 6.6640 memory: 65.53GiB(82.78%) wps: 596 mfu: 11.06% [rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95 loss: 6.7214 memory: 65.53GiB(82.78%) wps: 604 mfu: 11.20% [rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14% ``` 2. with `norm_type = "fused_rmsnorm"` ``` [rank2]:2024-06-13 00:19:33,609 - root - INFO - step: 1 loss: 12.2194 memory: 57.31GiB(72.40%) wps: 412 mfu: 7.64% [rank2]:2024-06-13 00:20:29,175 - root - INFO - step: 5 loss: 11.4519 memory: 65.13GiB(82.29%) wps: 590 mfu: 10.95% [rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10 loss: 10.2199 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.79% [rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15 loss: 9.3509 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.92% [rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20 loss: 8.7972 memory: 65.13GiB(82.29%) wps: 629 mfu: 11.68% [rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25 loss: 8.2348 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.91% [rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30 loss: 7.7037 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80% [rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35 loss: 7.4639 memory: 65.13GiB(82.29%) wps: 641 mfu: 11.90% [rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40 loss: 7.2406 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80% [rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45 loss: 7.1822 memory: 65.13GiB(82.29%) wps: 640 mfu: 11.87% [rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55 loss: 6.9888 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60 loss: 7.0387 memory: 
65.13GiB(82.29%) wps: 638 mfu: 11.84% [rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65 loss: 6.9199 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70 loss: 6.9503 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75 loss: 6.7960 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83% [rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80 loss: 6.6798 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85 loss: 6.6504 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86% [rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90 loss: 6.6655 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.81% [rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95 loss: 6.7359 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.78% [rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85% ``` [ghstack-poisoned]
…nnotation" **Summary** This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. **Test Plan** Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings: ``` [job] dump_folder = "./outputs" description = "Llama 3 8B training" [profiling] enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 100 [metrics] log_freq = 5 enable_tensorboard = false save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm] tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model" [optimizer] name = "AdamW" lr = 3e-4 [training] batch_size = 4 seq_len = 8192 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 100 data_parallel_degree = -1 tensor_parallel_degree = 4 pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" [checkpoint] enable_checkpoint = false folder = "checkpoint" interval_type = "steps" interval = 500 model_weights_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy ``` 1. 
Update on "enable TritonFusedRMSNorm with local_map annotation"

**Summary**

This PR enables TritonFusedRMSNorm with Tensor Parallel, giving a 7%-8% performance gain over RMSNorm with TP.

**Test Plan**

Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:

```toml
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```

1. with `norm_type = "rmsnorm"`

```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step: 1 loss: 12.2262 memory: 57.70GiB(72.89%) wps: 429 mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step: 5 loss: 11.4801 memory: 65.53GiB(82.78%) wps: 529 mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10 loss: 10.2305 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15 loss: 9.3287 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20 loss: 8.7126 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25 loss: 8.2011 memory: 65.53GiB(82.78%) wps: 591 mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30 loss: 7.7424 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35 loss: 7.4964 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40 loss: 7.2799 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45 loss: 7.2280 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50 loss: 7.0669 memory: 65.53GiB(82.78%) wps: 602 mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55 loss: 6.9967 memory: 65.53GiB(82.78%) wps: 595 mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60 loss: 7.0763 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65 loss: 6.9260 memory: 65.53GiB(82.78%) wps: 603 mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70 loss: 6.9757 memory: 65.53GiB(82.78%) wps: 601 mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75 loss: 6.8074 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80 loss: 6.7362 memory: 65.53GiB(82.78%) wps: 597 mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85 loss: 6.7016 memory: 65.53GiB(82.78%) wps: 598 mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90 loss: 6.6640 memory: 65.53GiB(82.78%) wps: 596 mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95 loss: 6.7214 memory: 65.53GiB(82.78%) wps: 604 mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100 loss: 6.5953 memory: 65.53GiB(82.78%) wps: 600 mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`

```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step: 1 loss: 12.2194 memory: 57.31GiB(72.40%) wps: 412 mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step: 5 loss: 11.4519 memory: 65.13GiB(82.29%) wps: 590 mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10 loss: 10.2199 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15 loss: 9.3509 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20 loss: 8.7972 memory: 65.13GiB(82.29%) wps: 629 mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25 loss: 8.2348 memory: 65.13GiB(82.29%) wps: 642 mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30 loss: 7.7037 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35 loss: 7.4639 memory: 65.13GiB(82.29%) wps: 641 mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40 loss: 7.2406 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45 loss: 7.1822 memory: 65.13GiB(82.29%) wps: 640 mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50 loss: 7.0580 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55 loss: 6.9888 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60 loss: 7.0387 memory: 65.13GiB(82.29%) wps: 638 mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65 loss: 6.9199 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70 loss: 6.9503 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75 loss: 6.7960 memory: 65.13GiB(82.29%) wps: 637 mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80 loss: 6.6798 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85 loss: 6.6504 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90 loss: 6.6655 memory: 65.13GiB(82.29%) wps: 636 mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95 loss: 6.7359 memory: 65.13GiB(82.29%) wps: 635 mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100 loss: 6.5410 memory: 65.13GiB(82.29%) wps: 639 mfu: 11.85%
```

[ghstack-poisoned]
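As a sanity check on the quoted gain (this arithmetic is mine, not a number stated in the PR), one can average the steady-state `wps` readings from the two logs above. The lists below are copied from steps 10 through 100 of each run:

```python
# Steady-state wps readings (steps 10-100) copied from the two logs above.
rmsnorm_wps = [601, 597, 603, 591, 600, 603, 601, 598, 602, 595,
               600, 603, 601, 600, 597, 598, 596, 604, 600]
fused_wps = [635, 642, 629, 642, 636, 641, 636, 640, 637, 639,
             638, 639, 637, 637, 639, 639, 636, 635, 639]

def avg(xs):
    return sum(xs) / len(xs)

# relative throughput gain of fused_rmsnorm over rmsnorm
speedup = avg(fused_wps) / avg(rmsnorm_wps) - 1
print(f"{speedup:.1%}")  # → 6.4%
```

On this particular 8-GPU, TP=4 run the steady-state gap works out to roughly 6-7% in wps (and similarly in mfu); the 7%-8% figure in the summary presumably also reflects other runs such as the one in the earlier commit message.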
This PR enables TritonFusedRMSNorm with Tensor Parallel, giving a 7%-8% performance gain over RMSNorm with TP.
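For reference, both `norm_type = "rmsnorm"` and `norm_type = "fused_rmsnorm"` compute the same RMSNorm function; only the kernel differs (eager PyTorch ops vs. a fused Triton kernel). A minimal pure-Python sketch of the math:

```python
# Reference RMSNorm: normalize by the root-mean-square of the input,
# then apply a learned per-element scale. Plain-Python sketch of the
# math that both norm_type implementations compute.
import math

def rms_norm(x, weight, eps=1e-6):
    # root-mean-square of the input vector (eps avoids division by zero)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]

print(rms_norm([3.0, 4.0], [1.0, 1.0]))  # ≈ [0.8485, 1.1314]
```

The fused Triton kernel computes exactly this per row, but in one pass over the data instead of several separate elementwise kernels.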
#364 is created from
I get an error
@fabianlim That's true. |
@fabianlim curious if you are using fused_rmsnorm or not? torchtitan is currently evolving and depends on nightlies; we'll start committing to BC with code releases. If you want to unblock, you can either use an older commit, or make some simple changes in this file: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py#L17
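The exact before/after snippets from this comment did not survive, but the general shape of such an unblocking change is to guard the nightly-only import and fall back to the eager implementation. A sketch of that pattern only (the module name `fused_kernels` here is hypothetical, not torchtitan's real import path):

```python
# Illustrative pattern: guard an optional fused-kernel import so builds
# without it fall back to the eager implementation instead of crashing
# at import time. `fused_kernels` is a hypothetical module name.
import math

try:
    from fused_kernels import fused_rms_norm_fn  # hypothetical import
except ImportError:
    fused_rms_norm_fn = None

def eager_rms_norm(x, eps=1e-6):
    # plain-Python stand-in for the eager RMSNorm path
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def rms_norm(x):
    # prefer the fused kernel when it is available
    if fused_rms_norm_fn is not None:
        return fused_rms_norm_fn(x)
    return eager_rms_norm(x)
```

With this shape, a stable-torch install simply never selects the fused path, so `norm_type = "rmsnorm"` keeps working.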
@XilunWu yea i brought it up because it is not clear from the requirements file that it is only for nightly torch https://github.com/pytorch/torchtitan/blob/main/.ci/docker/requirements.txt#L1 @wanchaol yea i downgraded and I'm fine now
Summary This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP. pytorch#364
Summary
This PR enables TritonFusedRMSNorm with Tensor Parallel, giving a 7%-8% performance gain over RMSNorm with TP. #364
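For readers unfamiliar with the annotation in the PR title: `local_map` (from `torch.distributed._tensor.experimental`) lets a function written against ordinary local tensors, such as a Triton kernel wrapper, be applied to DTensor arguments by running it on each rank's local shard under declared placements. A toy single-process sketch of that unwrap/run/rewrap pattern — the `ShardedRows` class and this simplified `local_map` are illustrative stand-ins, not the real DTensor API:

```python
# Toy analogy of the local_map pattern: a "sharded" value holds only this
# rank's rows, and local_map lets a plain per-shard function run on it.
# These classes are hypothetical stand-ins, NOT torch.distributed APIs.

class ShardedRows:
    """Holds one rank's shard of a row-sharded 2-D array."""
    def __init__(self, local_rows):
        self.local_rows = local_rows

def local_map(func):
    """Unwrap sharded inputs, run func on the local shard, rewrap the output."""
    def wrapper(t, *args, **kwargs):
        return ShardedRows(func(t.local_rows, *args, **kwargs))
    return wrapper

@local_map
def scale_rows(rows, factor):
    # stands in for a per-shard kernel such as the fused Triton RMSNorm,
    # which only ever sees the local shard, never the global tensor
    return [[factor * v for v in row] for row in rows]

shard = ShardedRows([[1.0, 2.0], [3.0, 4.0]])  # this "rank"'s rows
out = scale_rows(shard, 2.0)
print(out.local_rows)  # [[2.0, 4.0], [6.0, 8.0]]
```

The real `local_map` additionally takes the expected input/output placements and the device mesh, which is how the annotation tells DTensor that the fused kernel's sharded outputs are still valid sharded tensors.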