Skip to content

Commit

Permalink
Fix GPU-CPU device mismatch error in util filter_dilated_rows (pytorc…
Browse files Browse the repository at this point in the history
…h#633)

Summary:
## Types of changes

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue

The function `filter_dilated_rows` in `tensor_utils.py` converts a tensor to an ndarray, modifies the ndarray, and converts the modified ndarray back to a tensor.

**Bug:**
If the original tensor is not on the CPU, the conversion to ndarray will fail because tensor.cpu() is not called.
```
File "opacus/utils/tensor_utils.py", line 328, in filter_dilated_rows
    tensor_np = tensor.numpy()
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
```

**Fix:**
This PR directly modifies the tensor without ever converting it to an ndarray. This fixes the bug and is more efficient than the original implementation.

## How Has This Been Tested (if it applies)

Manually tested with the example provided in the function's DocString.

Also, `filter_dilated_rows` is called if the dilation of a 3d convolution is not 1. Thus, this function is implicitly tested by `tests/grad_samples/conv3d_test.py`.

## Checklist

- [x] The documentation is up-to-date with the changes I made.
- [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: pytorch#633

Reviewed By: karthikprasad

Differential Revision: D54199129

fbshipit-source-id: 56026a8f298517e27b67cf77de06f94ab63d0a9c
  • Loading branch information
tklausen authored and facebook-github-bot committed Mar 5, 2024
1 parent ac639af commit 32a465b
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions opacus/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,14 +322,15 @@ def filter_dilated_rows(
kernel_rank = len(kernel_size)

indices_to_keep = [
list(range(0, dilated_kernel_size[i], dilation[i])) for i in range(kernel_rank)
torch.arange(0, dilated_kernel_size[i], dilation[i], device=tensor.device)
for i in range(kernel_rank)
]

tensor_np = tensor.numpy()

axis_offset = len(tensor.shape) - kernel_rank

for dim in range(kernel_rank):
tensor_np = np.take(tensor_np, indices_to_keep[dim], axis=axis_offset + dim)
tensor = torch.index_select(
tensor, dim=axis_offset + dim, index=indices_to_keep[dim]
)

return torch.Tensor(tensor_np)
return tensor

0 comments on commit 32a465b

Please sign in to comment.