[PyTorch] Fix fuse_wgrad_accumulation for GroupedLinear #1488

Merged · 4 commits into NVIDIA:main · Feb 19, 2025

Conversation

@yaox12 yaox12 (Collaborator) commented Feb 17, 2025

Description

Due to incorrect indentation, the wgrad computation is not called when ctx.fuse_wgrad_accumulation == True.

Also update the test to cover this case.
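
To make the shape of the fix concrete, here is a minimal, hypothetical sketch of the indentation bug; the function and variable names are illustrative only, not the actual GroupedLinear backward:

```python
# Hypothetical sketch of the bug, not the actual GroupedLinear code.
import torch

def compute_wgrad(fuse_wgrad_accumulation, grad_out, inp, main_grad):
    if fuse_wgrad_accumulation:
        wgrad = main_grad                    # accumulate directly into main_grad
    else:
        wgrad = torch.zeros_like(main_grad)  # fresh wgrad buffer
        # BUG: before this PR the GEMM below was indented into this else-branch,
        # so no wgrad was computed when fuse_wgrad_accumulation was True.
    wgrad += grad_out.t() @ inp              # FIX: runs for both branches
    return wgrad
```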

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Fix the indentation in the GroupedLinear backward so the wgrad computation runs when ctx.fuse_wgrad_accumulation == True
  • Update the GroupedLinear unit test to cover the fuse_wgrad_accumulation case

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 yaox12 added the bug and 2.1.0 labels Feb 17, 2025
@yaox12 yaox12 requested a review from timmoon10 February 17, 2025 08:22
@yaox12 yaox12 (Collaborator, Author) commented Feb 17, 2025

/te-ci pytorch

Signed-off-by: Xin Yao <[email protected]>
@yaox12 yaox12 (Collaborator, Author) commented Feb 17, 2025

/te-ci pytorch

ctx.weights_shape_1 = weights[0].shape[1]

tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects

ctx.weights_requires_grad = weights[0].requires_grad
if fuse_wgrad_accumulation and ctx.weights_requires_grad:
    ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]
@timmoon10 (Collaborator) commented:
It's recommended to use ctx.save_for_backward instead of storing tensors directly in ctx. They warn about messing up the grad graph and memory leaks, although I'm not sure what cases they are specifically worried about.

@yaox12 yaox12 (Collaborator, Author) commented Feb 19, 2025

I agree. We were saving the main_grad tensors with ctx.save_for_backward in TE 1.x, but I see there is a comment here:

# Since main_grad can be modified inplace, it should not be a part of saved_tensors

I'm wondering whether we have actually seen issues with ctx.save_for_backward?

@timmoon10 (Collaborator) commented:

Interesting, we should follow the example of Linear then.

@ksivaman This change is from commit 7e58678 in the internal repo. Do you remember why we can't store main_grad in saved_tensors?

@yaox12 (Collaborator, Author) commented:

Oh, I see: the previous prepare_for_saving saved tensor.data instead of the tensor itself, so for main_grad, which needs to be modified in place, this could be an issue.

Now that #1474 has changed prepare_for_saving to save the tensor itself, this is no longer a problem.
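
For readers following along, here is a minimal, self-contained sketch (hypothetical, not TE code) contrasting the two ways of stashing tensors discussed above:

```python
# Hypothetical sketch contrasting ctx.save_for_backward with storing tensors on ctx.
import torch

class _FusedWgradDemo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        # Preferred path: autograd tracks saved tensors and can detect in-place edits.
        ctx.save_for_backward(inp, weight)
        # main_grad is meant to be modified in place, so it is kept off saved_tensors
        # and stored as a plain attribute instead.
        ctx.main_grad = getattr(weight, "main_grad", None)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        grad_inp = grad_out @ weight
        wgrad = grad_out.t() @ inp
        if ctx.main_grad is not None:
            ctx.main_grad += wgrad  # fused wgrad accumulation
            wgrad = None            # nothing left for autograd to put in .grad
        return grad_inp, wgrad
```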

outputs.append(p.grad)
if getattr(p, "main_grad", None) is not None:
    outputs.append(p.main_grad)
    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
@timmoon10 (Collaborator) commented:
It turns out Mcore expects p.grad to not be None: #1474 (comment)
#1474 sets grad to an uninitialized tensor and assumes Mcore will ignore it.
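
A hypothetical sketch of the kind of check being adjusted in the test (names are illustrative, not the actual test code):

```python
# Hypothetical sketch, not the actual test: gather the effective weight gradients.
import torch

def collect_wgrads(module: torch.nn.Module) -> list:
    grads = []
    for p in module.parameters():
        main_grad = getattr(p, "main_grad", None)
        if main_grad is not None:
            # With fuse_wgrad_accumulation the real gradient lives in main_grad;
            # p.grad may be an uninitialized placeholder kept for Mcore compatibility.
            grads.append(main_grad)
        else:
            grads.append(p.grad)
    return grads
```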

@timmoon10 timmoon10 self-requested a review February 18, 2025 23:33
@timmoon10 timmoon10 (Collaborator) commented:
/te-ci pytorch

@timmoon10 timmoon10 merged commit fceff07 into NVIDIA:main Feb 19, 2025
11 of 12 checks passed