[dtensor] directly return local_tensor under no_grad (pytorch#128145)
As titled: skip the autograd function and directly return the
local_tensor when under a no_grad context. This avoids creating
views.

Pull Request resolved: pytorch#128145
Approved by: https://github.com/awgu
ghstack dependencies: pytorch#128112
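
A minimal sketch of the pattern this change applies (the `_Passthrough` function and `to_local_like` helper below are illustrative stand-ins, not the actual DTensor internals): when grad mode is off there is nothing for autograd to record, so the stored tensor can be handed back as-is instead of going through an autograd.Function, which would otherwise return a view of it.

import torch

class _Passthrough(torch.autograd.Function):
    # illustrative stand-in for an autograd function like _ToTorchTensor
    @staticmethod
    def forward(ctx, tensor):
        # return a view so autograd can attach a grad_fn to the output
        # without touching the input tensor's autograd metadata
        return tensor.view_as(tensor)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

def to_local_like(local_tensor):
    # fast path (as in this commit): under no_grad, skip the autograd
    # function and return the stored tensor directly, so no view is created
    if not torch.is_grad_enabled():
        return local_tensor
    return _Passthrough.apply(local_tensor)

t = torch.randn(2, 2, requires_grad=True)
with torch.no_grad():
    print(to_local_like(t) is t)   # True: the same tensor object comes back
print(to_local_like(t) is t)       # False: the autograd path returns a view
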
wanchaol authored and pytorchmergebot committed Jun 7, 2024
1 parent 747fc35 commit 3df53c2
Showing 2 changed files with 8 additions and 0 deletions.
test/distributed/_tensor/test_dtensor.py: 5 additions & 0 deletions
@@ -331,6 +331,11 @@ def test_to_local(self):
except RuntimeError:
self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])

# test the case under no-grad we directly return the local tensor
with torch.no_grad():
local_no_grad = sharded_tensor.to_local()
assert local_no_grad is sharded_tensor._local_tensor

@with_comms
def test_to_local_grad_hint(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
torch/distributed/_tensor/api.py: 3 additions & 0 deletions
@@ -418,6 +418,9 @@ def to_local(
.. note:: `to_local` is differentiable, the `requires_grad` of the local tensor returned
will depend on if the `DTensor` requires_grad or not.
"""
if not torch.is_grad_enabled():
return self._local_tensor

if grad_placements is not None and not isinstance(grad_placements, tuple):
grad_placements = tuple(grad_placements)
return _ToTorchTensor.apply(
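
A hypothetical usage sketch of the new fast path at the DTensor level, assuming a process group has already been initialized (e.g. under torchrun); the mesh, placement, and tensor shapes below are illustrative:

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# illustrative setup: assumes dist.init_process_group(...) has already run
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))
dt = distribute_tensor(torch.randn(8, 8, requires_grad=True), mesh, [Shard(0)])

# grad mode: to_local() goes through the autograd function so gradients can
# flow back to the DTensor; the result is a view over the local shard
local_view = dt.to_local()

# no_grad: with this change, to_local() returns the stored local tensor
# directly, mirroring the assertion added to test_to_local above
with torch.no_grad():
    local_direct = dt.to_local()
    assert local_direct is dt._local_tensor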
