diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py
index e29eede07d8751..17a6aebd8f9361 100644
--- a/test/distributed/_tensor/test_dtensor.py
+++ b/test/distributed/_tensor/test_dtensor.py
@@ -331,6 +331,11 @@ def test_to_local(self):
         except RuntimeError:
             self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])
 
+        # test that under no_grad we directly return the local tensor
+        with torch.no_grad():
+            local_no_grad = sharded_tensor.to_local()
+            assert local_no_grad is sharded_tensor._local_tensor
+
     @with_comms
     def test_to_local_grad_hint(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py
index 49fe7267c6340e..be887f3ce6ca8c 100644
--- a/torch/distributed/_tensor/api.py
+++ b/torch/distributed/_tensor/api.py
@@ -418,6 +418,9 @@ def to_local(
         .. note:: `to_local` is differentiable, the `requires_grad` of the local
             tensor returned will depend on if the `DTensor` requires_grad or not.
         """
+        if not torch.is_grad_enabled():
+            return self._local_tensor
+
         if grad_placements is not None and not isinstance(grad_placements, tuple):
             grad_placements = tuple(grad_placements)
         return _ToTorchTensor.apply(
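
For context, a minimal sketch of what this fast path means for callers (not part of the diff). It assumes a process group is already initialized, e.g. via torchrun, that each rank has a GPU, and it uses the `DeviceMesh` / `Shard` / `distribute_tensor` APIs from `torch.distributed._tensor`; the `_local_tensor` attribute is the same private field exercised by the test above.

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# Assumes dist.init_process_group(...) was already called (e.g. by torchrun).
world_size = dist.get_world_size()
mesh = DeviceMesh("cuda", list(range(world_size)))

# Shard a global tensor along dim 0 across the mesh.
global_tensor = torch.randn(4 * world_size, 8, requires_grad=True)
dtensor = distribute_tensor(global_tensor, mesh, [Shard(0)])

# With grad enabled, to_local() dispatches through the _ToTorchTensor autograd
# function so the returned local view stays connected to the autograd graph.
local_with_grad = dtensor.to_local()

# With grad disabled, the new early return hands back the stored local shard
# directly, skipping the autograd-function dispatch.
with torch.no_grad():
    local_no_grad = dtensor.to_local()
    assert local_no_grad is dtensor._local_tensor
```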