[PyTorch] Fix fuse_wgrad_accumulation for GroupedLinear #1488
Conversation
Signed-off-by: Xin Yao <[email protected]>
/te-ci pytorch
ctx.weights_shape_1 = weights[0].shape[1]

tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects

ctx.weights_requires_grad = weights[0].requires_grad
if fuse_wgrad_accumulation and ctx.weights_requires_grad:
    ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]
It's recommended to use ctx.save_for_backward instead of storing tensors directly on ctx. The PyTorch docs warn about messing up the grad graph and memory leaks, although I'm not sure what cases they are specifically worried about.
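For reference, a minimal generic-PyTorch sketch (not TE code) of the recommended pattern: tensors needed in backward go through ctx.save_for_backward, and only non-tensor state is stashed directly on ctx.

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        # Tensors used in backward go through save_for_backward so autograd can
        # track their versions, avoid reference cycles, and apply saved-tensor hooks.
        ctx.save_for_backward(inp, weight)
        # Non-tensor state can safely live on ctx.
        ctx.fuse_wgrad_accumulation = False
        return inp * weight

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        return grad_out * weight, grad_out * inp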
I agree. We were saving main_grad tensors using ctx.save_for_backward in TE 1.x, but I see there is a comment here:

# Since main_grad can be modified inplace, it should not be a part of saved_tensors

I'm wondering if we have seen issues with ctx.save_for_backward?
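As a hedged illustration of what that comment is likely guarding against (generic PyTorch with assumed names, not TE code): a tensor saved with save_for_backward is version-checked when backward unpacks it, so an in-place update to a persistent buffer between forward and backward can make autograd raise.

import torch

class StashBuffer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, buf):
        ctx.save_for_backward(buf)   # buf stands in for a main_grad-like buffer
        return inp * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        (buf,) = ctx.saved_tensors   # saved-tensor version check happens on unpack
        return grad_out * 2.0, None

x = torch.randn(4, requires_grad=True)
buf = torch.zeros(4)
y = StashBuffer.apply(x, buf)
buf.add_(1.0)                        # in-place update, as wgrad accumulation would do
try:
    y.sum().backward()
except RuntimeError as err:
    # Recent PyTorch versions reject the modified saved tensor here.
    print("autograd flagged the in-place update:", err)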
Interesting, we should follow the example of Linear then.
@ksivaman This change is from commit 7e58678 in the internal repo. Do you remember why we can't store main_grad in saved_tensors?
Oh I know: the previous prepare_for_saving saved tensor.data instead of the tensor itself, so for main_grad, which needs to be modified in place, this could be an issue. Now that #1474 changed prepare_for_saving to save the tensor itself, this is no longer a problem.
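As a generic PyTorch aside (an illustration, not TE's prepare_for_saving): tensor.data is a different tensor object that aliases the same storage but is detached from the original's autograd bookkeeping, which matters for a buffer like main_grad that other code holds on to by identity.

import torch

t = torch.zeros(4)
alias = t.data

print(alias is t)                        # False: a different tensor object
print(alias.data_ptr() == t.data_ptr())  # True: it aliases the same storage
alias.add_(1.0)                          # in-place writes through the alias still land in t
print(t)                                 # tensor([1., 1., 1., 1.])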
outputs.append(p.grad)
if getattr(p, "main_grad", None) is not None:
    outputs.append(p.main_grad)
    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
It turns out Mcore expects p.grad to not be None: #1474 (comment)
#1474 sets grad to an uninitialized tensor and assumes Mcore will ignore it.
/te-ci pytorch
Description
Due to a wrong indent, the wgrad computation was not called when ctx.fuse_wgrad_accumulation == True. Also update the test to cover this case.
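To make the fix concrete, an illustrative sketch (hypothetical helper names, not the actual GroupedLinear backward): the wgrad GEMM had slipped under the else branch, so it only ran when the fusion was disabled.

import torch

def backward_wgrad(ctx, inp, grad_out, weight):
    # Illustrative only: pick the output buffer depending on the fusion flag.
    if ctx.fuse_wgrad_accumulation:
        wgrad = weight.main_grad          # accumulate into the persistent buffer
        beta = 1.0
    else:
        wgrad = torch.zeros_like(weight)
        beta = 0.0
    # Before the fix, this call was indented under the else branch and was
    # therefore skipped whenever fuse_wgrad_accumulation was True.
    wgrad.addmm_(grad_out.t(), inp, beta=beta, alpha=1.0)
    return wgrad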
Type of change

- Bug fix

Changes

Please list the changes introduced in this PR:

- Fix the indentation so the wgrad computation is called when ctx.fuse_wgrad_accumulation == True.
- Update the GroupedLinear test to cover fuse_wgrad_accumulation.
Checklist: