Fix issues for MCore DDP. #1474

Merged (11 commits) on Feb 19, 2025
transformer_engine/pytorch/cpu_offload.py (3 additions, 1 deletion)
@@ -137,7 +137,9 @@ def __init__(
         super().__init__()

     def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
-        retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)
+        retrieve_identifier = self.offload_handler.tensor_push(
+            tensor.data, **self.handler_extra_kwargs
+        )
         return retrieve_identifier

     def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
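Note on the change above: `tensor_push` now receives `tensor.data` rather than the tensor object itself. A small standalone sketch (assumed setup, not part of the diff) of what that distinction buys when the saved tensor is an `nn.Parameter` decorated by MCore DDP:

# Minimal sketch: tensor.data is a plain torch.Tensor view of the same storage,
# detached from the nn.Parameter wrapper that MCore DDP decorates with extra
# attributes such as main_grad (attribute names assumed here).
import torch

param = torch.nn.Parameter(torch.randn(4, 4))
param.main_grad = torch.zeros_like(param)  # assumed MCore-style buffer

plain = param.data

assert isinstance(param, torch.nn.Parameter)
assert not isinstance(plain, torch.nn.Parameter)
assert plain.data_ptr() == param.data_ptr()   # same memory, no copy
assert not hasattr(plain, "main_grad")        # wrapper attributes not carried over

# Offloading `plain` moves only the raw values; the Parameter object on the
# module (and its DDP attributes) is left alone.
cpu_copy = plain.to("cpu")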
transformer_engine/pytorch/module/layernorm_linear.py (12 additions, 7 deletions)
@@ -441,7 +441,7 @@ def backward(
         ( # pylint: disable=unbalanced-tuple-unpacking
             inputmat,
             weight,
-            _,
+            origin_weight,
             bias,
             ln_weight,
             ln_out,
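Note: the backward now keeps `origin_weight` (the registered Parameter) alongside `weight` (the tensor actually used in the kernel). A standalone sketch, with an assumed MCore-style setup, of why the DDP attribute checks must go through the original Parameter:

# Standalone sketch (assumed MCore-style setup, not part of the diff): the
# hasattr guard in the backward matches the registered Parameter, not a cast
# copy of the weight.
import torch

origin_weight = torch.nn.Parameter(torch.randn(8, 8))
origin_weight.main_grad = torch.zeros_like(origin_weight)  # assumed DDP buffer
origin_weight.grad_added_to_main_grad = False              # assumed DDP flag

weight = origin_weight.detach().to(torch.bfloat16)  # stand-in for a cast copy

assert hasattr(origin_weight, "grad_added_to_main_grad")
assert not hasattr(weight, "grad_added_to_main_grad")  # copy has no DDP flags
assert not hasattr(weight, "main_grad")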
@@ -722,17 +722,22 @@ def backward(

         if ctx.requires_wgrad:
             # Handle custom DDP from mcore.
-            if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"):
-                weight.grad_added_to_main_grad = True
-                if getattr(weight, "zero_out_wgrad", False):
+            if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"):
+                origin_weight.grad_added_to_main_grad = True
+                if getattr(origin_weight, "zero_out_wgrad", False):
                     wgrad = torch.zeros(
-                        weight.main_grad.shape,
-                        dtype=weight.dtype,
+                        origin_weight.main_grad.shape,
+                        dtype=origin_weight.dtype,
                         device=torch.cuda.current_device(),
                         requires_grad=False,
                     )
                 else:
-                    wgrad = None
+                    wgrad = torch.empty(
+                        origin_weight.main_grad.shape,
+                        dtype=origin_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
             elif ctx.fuse_wgrad_accumulation:
                 wgrad = None
             else:
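The new else-branch returns an uninitialized placeholder instead of `None` when the weight gradient was already fused into `main_grad`. A rough sketch of that contract as a helper function; the names, device handling, and rationale comments are assumptions rather than TE code:

# Rough sketch of the placeholder-wgrad contract (assumed helper, not TE code).
import torch

def make_placeholder_wgrad(origin_weight: torch.nn.Parameter) -> torch.Tensor:
    """Return the dummy tensor handed back to autograd when the real weight
    gradient has already been accumulated into origin_weight.main_grad."""
    if getattr(origin_weight, "zero_out_wgrad", False):
        # Caller wants a well-defined .grad of zeros.
        return torch.zeros(
            origin_weight.main_grad.shape,
            dtype=origin_weight.dtype,
            device=origin_weight.device,  # simplification of the diff's device choice
            requires_grad=False,
        )
    # Contents are never read; an uninitialized buffer avoids a memset, and
    # returning a real tensor (rather than None) presumably keeps MCore DDP's
    # grad hooks firing, which is the issue this PR targets.
    return torch.empty(
        origin_weight.main_grad.shape,
        dtype=origin_weight.dtype,
        device=origin_weight.device,
        requires_grad=False,
    )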
transformer_engine/pytorch/module/linear.py (6 additions, 1 deletion)
@@ -606,7 +606,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                         requires_grad=False,
                     )
                 else:
-                    wgrad = None
+                    wgrad = torch.empty(
+                        weight.main_grad.shape,
+                        dtype=weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
             elif ctx.fuse_wgrad_accumulation:
                 wgrad = None
             else:
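Same placeholder pattern as in layernorm_linear.py. For illustration only, here is an assumed sketch of how a custom DDP gradient hook might consume `grad_added_to_main_grad` downstream; this is not Megatron-Core's actual implementation:

# Illustrative only (assumed logic, not Megatron-Core code): how a custom DDP
# grad hook could consume the flags set in Linear's backward pass.
import torch

def on_param_grad_ready(param: torch.nn.Parameter, grad: torch.Tensor) -> None:
    if getattr(param, "grad_added_to_main_grad", False):
        # Fused path: the real wgrad already sits in param.main_grad, and the
        # `grad` seen here is only the placeholder tensor; do not add it again.
        param.grad_added_to_main_grad = False
    else:
        # Unfused path: accumulate the autograd-produced gradient normally.
        param.main_grad.add_(grad)
    # ...then mark the param ready so its bucket can be all-reduced...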
transformer_engine/pytorch/tensor/quantized_tensor.py (6 additions, 16 deletions)
@@ -28,7 +28,7 @@ def prepare_for_saving(
             tensor_list.append(None)
             tensor_objects_list.append(None)
         elif type(tensor) in (torch.Tensor, torch.nn.Parameter):
-            tensor_list.append(tensor.data)
+            tensor_list.append(tensor)
             tensor_objects_list.append(None)
         else:
             t, t_obj = tensor.prepare_for_saving()
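The functional change in this hunk appends the tensor object itself instead of its `.data` view. A standalone sketch (assumed attributes) of the difference that makes when the saved object is a Parameter carrying DDP state:

# Standalone sketch: appending the object itself preserves its identity, so a
# saved nn.Parameter round-trips with any DDP-style attributes still attached.
import torch

param = torch.nn.Parameter(torch.randn(2, 2))
param.main_grad = torch.zeros_like(param)  # assumed MCore-style attribute

saved_as_object = param       # what the patched code appends
saved_as_data = param.data    # what the previous code appended

assert saved_as_object is param
assert hasattr(saved_as_object, "main_grad")
assert saved_as_data is not param
assert not hasattr(saved_as_data, "main_grad")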
@@ -116,10 +116,7 @@ def update_quantized(
         """Quantize tensor in-place"""

     def quantize(
-        self,
-        tensor: torch.Tensor,
-        *,
-        out: Optional[QuantizedTensor] = None,
+        self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None
     ) -> QuantizedTensor:
         """Quantize tensor"""
         if out is not None:
@@ -159,10 +156,7 @@ def calibrate(self, tensor: torch.Tensor) -> None:
         """

     def set_usage(
-        self,
-        *,
-        rowwise: Optional[bool] = None,
-        columnwise: Optional[bool] = None,
+        self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
     ) -> None:
         """Set how the quantized tensor is expected to be used

@@ -194,8 +188,7 @@ def forward(

     @staticmethod
     def backward(
-        _ctx: torch.autograd.function.FunctionCtx,  # unused
-        grad: torch.Tensor,
+        _ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor  # unused
     ) -> Tuple[Optional[torch.Tensor], ...]:
         # pylint: disable=missing-function-docstring
         # Assume that we want gradients in full precision
@@ -212,9 +205,7 @@ class _IdentityFunc(torch.autograd.Function):

     @staticmethod
     def forward(
-        ctx,
-        tensor: QuantizedTensor,
-        init_kwargs: Optional[Dict[str, Any]] = None,
+        ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None
     ) -> QuantizedTensor:
         # pylint: disable=missing-function-docstring

@@ -408,8 +399,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
         return torch._C._disabled_torch_function_impl(func, types, args, kwargs)

     def contiguous(
-        self,
-        memory_format: torch.memory_format = torch.contiguous_format,
+        self, memory_format: torch.memory_format = torch.contiguous_format
     ) -> QuantizedTensor:
         # pylint: disable=missing-function-docstring
         raise NotImplementedError(