Fix issues for MCore DDP. #1474

Merged: 11 commits merged into NVIDIA:main from the denliu/fix_mcore_ddp branch on Feb 19, 2025

Conversation

@Victarry (Contributor) commented Feb 11, 2025

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

4: [rank4]: Traceback (most recent call last):
4: [rank4]:   File "/workspace/megatron-lm/pretrain_gpt.py", line 245, in <module>
4: [rank4]:     pretrain(
4: [rank4]:   File "/workspace/megatron-lm/megatron/training/training.py", line 313, in pretrain
4: [rank4]:     iteration, num_floating_point_operations_so_far = train(
4: [rank4]:                                                       ^^^^^^
4: [rank4]:   File "/workspace/megatron-lm/megatron/training/training.py", line 1157, in train
4: [rank4]:     train_step(forward_step_func,
4: [rank4]:   File "/workspace/megatron-lm/megatron/training/training.py", line 631, in train_step
4: [rank4]:     losses_reduced = forward_backward_func(
4: [rank4]:                      ^^^^^^^^^^^^^^^^^^^^^^
4: [rank4]:   File "/workspace/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 456, in forward_backward_no_pipelining
4: [rank4]:     backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
4: [rank4]:   File "/workspace/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 356, in backward_step
4: [rank4]:     custom_backward(output_tensor[0], output_tensor_grad[0])
4: [rank4]:   File "/workspace/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 155, in custom_backward
4: [rank4]:     Variable._execution_engine.run_backward(
4: [rank4]:   File "/workspace/megatron-lm/megatron/core/distributed/distributed_data_parallel.py", line 223, in param_hook
4: [rank4]:     param.grad is not None
4: [rank4]: AssertionError: param.grad being None is not safe when overlap_grad_reduce is True

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change prepare_for_saving from tensor_list.append(tensor.data) to tensor_list.append(tensor), since appending tensor.data drops parameter attributes such as grad_added_to_main_grad (see the sketch after this list)
  • Add .data to the CPU offload hook (details of the reasoning in the comment below, #1474 (comment))
  • Revert the wgrad return value to an empty tensor instead of None, since MCore DDP with backward overlap requires a tensor value for wgrad
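
A minimal sketch of why the first change matters (plain PyTorch, not TE code): .data returns a new Tensor object that shares storage with the original but not the Python attributes MCore sets on its parameters, so saving tensor.data silently drops markers like grad_added_to_main_grad.

    import torch

    w = torch.nn.Parameter(torch.randn(4, 4))
    w.grad_added_to_main_grad = True  # attribute MCore attaches to params

    alias = w.data                    # new Tensor object, same storage
    print(hasattr(alias, "grad_added_to_main_grad"))  # False: attribute is lost
    print(alias.data_ptr() == w.data_ptr())           # True: storage is shared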

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Victarry Victarry force-pushed the denliu/fix_mcore_ddp branch from ec1d3ec to 4997b56 Compare February 11, 2025 04:55
@ksivaman (Member)

@Victarry Could you sign-off your commits? Here is the guide.

Signed-off-by: Dennis Liu <[email protected]>
@Victarry Victarry force-pushed the denliu/fix_mcore_ddp branch from 4997b56 to 3ffd732 Compare February 11, 2025 05:24
@Victarry (Contributor, Author)

> @Victarry Could you sign-off your commits? Here is the guide.

Thanks. Done

@ksivaman (Member)

/te-ci pytorch

@timmoon10 (Collaborator) commented Feb 11, 2025

@Victarry Just to confirm, MCore now requires param.grad to be allocated when gradient_accumulation_fusion=True? This avoids some race conditions with backward hooks (hooks are launched on a different thread if grad is None), but also adds unnecessary memory usage. Also, does the distributed optimizer also have this requirement?

@Victarry (Contributor, Author)

> MCore now requires param.grad to be allocated when gradient_accumulation_fusion=True?

MCore has always required param.grad to be allocated when gradient_accumulation_fusion=True, but TE 2.0 changed the return value from an empty tensor to None (a minimal sketch of the convention is below):
https://github.com/NVIDIA/TransformerEngine/blame/49a4535d1addd2c5743a7e280e2f4f2640f0bedf/transformer_engine/pytorch/module/linear.py#L609
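
For reference, a minimal sketch of the convention (names and shapes are assumptions, not the actual TE or MCore code): the fused kernel has already accumulated the weight gradient into weight.main_grad, so the backward only needs to hand autograd an empty placeholder tensor; returning None instead leaves param.grad unset and trips the DDP hook's assertion shown in the traceback.

    import torch

    def wgrad_placeholder(weight: torch.nn.Parameter) -> torch.Tensor:
        # Placeholder returned for the weight gradient when
        # gradient_accumulation_fusion is on; the real gradient already lives
        # in weight.main_grad. TE 2.0 started returning None here instead.
        return torch.empty(
            weight.shape, dtype=weight.dtype, device=weight.device,
            requires_grad=False,
        )

    def param_hook(param: torch.nn.Parameter) -> None:
        # Simplified form of the MCore DDP check from the traceback above.
        assert param.grad is not None, (
            "param.grad being None is not safe when overlap_grad_reduce is True"
        )
        param.grad = None  # gradient has already been fused into param.main_grad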

> Also, does the distributed optimizer also have this requirement?

I'm not familiar with the distributed optimizer. Maybe @deepakn94 can provide some comments?

@Victarry Victarry force-pushed the denliu/fix_mcore_ddp branch from ddd2c1a to 6a2d88a Compare February 13, 2025 06:55
@Victarry Victarry force-pushed the denliu/fix_mcore_ddp branch from 6a2d88a to 594ea31 Compare February 13, 2025 06:57
@yaox12 (Collaborator) commented Feb 13, 2025

/te-ci pytorch

@Victarry (Contributor, Author) commented Feb 13, 2025

> Change prepare_for_saving from tensor_list.append(tensor.data) to tensor_list.append(tensor), since appending tensor.data drops parameter attributes such as grad_added_to_main_grad

I found that the change above causes unit-test failures with CPU offloading, for the following reasons:

    tensors_to_save, tensor_objects = prepare_for_saving(
        saved_inputmat,
        weightmat,
        weight,
        bias,
    )

  1. With BF16 training, weightmat and weight point to the same tensor, so the CPU offload hook is applied to the same object twice.
  2. When the offload hook runs on weightmat, its data is copied to the CPU and the device-side data is then replaced with a blank tensor in
    tensor_on_device.data = torch.Tensor() # Force to release memory
  3. When the offload hook then runs on weight, only a blank tensor is saved, which causes a size mismatch when the tensor is restored in the backward pass.

In the original code, tensor.data created two distinct tensor objects, so force-releasing one did not affect the other; the underlying tensor data, however, was offloaded twice.
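
A minimal repro sketch of the aliasing issue (the names and the simplified hook are assumptions, not the actual TE offload handler):

    import torch

    weight = torch.randn(8, 8)  # stands in for the BF16 parameter
    weightmat = weight          # in BF16 training both names refer to one object

    def offload_hook(t: torch.Tensor) -> torch.Tensor:
        cpu_copy = t.detach().cpu().clone()  # simulate offloading to host memory
        t.data = torch.Tensor()              # force-release the original storage
        return cpu_copy

    # Saving the same Python object twice: the second hook only sees a blank tensor.
    first = offload_hook(weightmat)
    second = offload_hook(weight)
    print(first.shape, second.shape)  # torch.Size([8, 8]) torch.Size([0])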

@Victarry (Contributor, Author) commented Feb 13, 2025

To keep this fix MR simple and get MCore working as soon as possible, I added .data to the save_for_backward hook in the CPU offload handler.

@deepakn94 (Contributor)

> @Victarry Just to confirm, MCore now requires param.grad to be allocated when gradient_accumulation_fusion=True? This avoids some race conditions with backward hooks, but also adds unnecessary memory usage. Also, does the distributed optimizer also have this requirement?

Yes, the distributed optimizer also has this requirement, for the same reason.

@deepakn94 (Contributor)

> MCore now requires param.grad to be allocated when gradient_accumulation_fusion=True?
>
> MCore has always required param.grad to be allocated when gradient_accumulation_fusion=True, but TE 2.0 changed the return value from an empty tensor to None. https://github.com/NVIDIA/TransformerEngine/blame/49a4535d1addd2c5743a7e280e2f4f2640f0bedf/transformer_engine/pytorch/module/linear.py#L609

Yup, exactly. MCore has had this requirement for more than a year; changing the return value to None is a breaking change for us.

@timmoon10 (Collaborator)

/te-ci pytorch

@ptrendx ptrendx added the 2.1.0 label Feb 15, 2025
@yaox12 (Collaborator) commented Feb 17, 2025

/te-ci pytorch

@Victarry (Contributor, Author)

@timmoon10 @ksivaman could I have your approval for this bug fix? The UT failures appear to be caused by unrelated bugs. Thanks!

@ksivaman (Member)

/te-ci pytorch

@timmoon10 (Collaborator) left a comment:

LGTM

@timmoon10 merged commit 978f1d7 into NVIDIA:main on Feb 19, 2025
11 of 12 checks passed