reduce_python.py
from typing import Optional, Tuple

import torch
from torch import Tensor
def _reduce_grad(
    gradient: Tensor, keys: Tensor, indexes: Tensor
) -> Tuple[Tensor, Tensor]:
    device = gradient.device
    dtype = gradient.dtype

    # map the first column of the gradient keys (the row in `values`) to the
    # corresponding row in the reduced values
    new_keys = keys.clone()
    for grad_i, grad_key in enumerate(keys):
        new_keys[grad_i, 0] = indexes[grad_key[0]]

    reduced_keys, grad_indexes = torch.unique(new_keys, dim=0, return_inverse=True)

    new_shape = (len(reduced_keys),) + gradient.shape[1:]
    reduced_gradient = torch.zeros(new_shape, dtype=dtype, device=device)
    reduced_gradient.index_add_(0, grad_indexes.to(gradient.device), gradient)

    return reduced_gradient, reduced_keys


def reduce(
    values: Tensor,
    keys: Tensor,
    dim: int,
    positions_grad: Optional[Tensor] = None,
    positions_grad_keys: Optional[Tensor] = None,
    cell_grad: Optional[Tensor] = None,
    cell_grad_keys: Optional[Tensor] = None,
) -> Tuple[
    Tuple[Tensor, Tensor],
    Tuple[Optional[Tensor], Optional[Tensor]],
    Tuple[Optional[Tensor], Optional[Tensor]],
]:
    # `keys` contains a description of the rows of `values`. Each row in `keys`
    # corresponds to a row in `values` (and respectively for `reduced_values` /
    # `reduced_keys`).
    #
    # `positions_grad` contains the gradients of `values` with respect to
    # positions, and `positions_grad_keys` describes the rows of
    # `positions_grad`. The first column in `positions_grad_keys` is always the
    # row in `values` we are taking the gradient of.
    #
    # Similar considerations apply to `cell_grad` / `cell_grad_keys`.
    device = values.device
    dtype = values.dtype

    assert keys.dim() == 2, "keys should have only two dimensions"
    reduced_keys, indexes = torch.unique(keys[:, dim], return_inverse=True)

    new_shape = (len(reduced_keys),) + values.shape[1:]
    reduced_values = torch.zeros(new_shape, dtype=dtype, device=device)
    reduced_values.index_add_(0, indexes.to(values.device), values)

    if positions_grad is not None:
        assert positions_grad_keys is not None
        assert positions_grad.device == device
        assert positions_grad.dtype == dtype

        result = _reduce_grad(positions_grad, positions_grad_keys, indexes)
        reduced_positions_grad, reduced_positions_grad_keys = result
    else:
        reduced_positions_grad = None
        reduced_positions_grad_keys = None

    if cell_grad is not None:
        assert cell_grad_keys is not None
        assert cell_grad.device == device
        assert cell_grad.dtype == dtype

        result = _reduce_grad(cell_grad, cell_grad_keys, indexes)
        reduced_cell_grad, reduced_cell_grad_keys = result
    else:
        reduced_cell_grad = None
        reduced_cell_grad_keys = None

    reduced_values = (reduced_values, reduced_keys.reshape(-1, 1))
    reduced_positions_grad = (reduced_positions_grad, reduced_positions_grad_keys)
    reduced_cell_grad = (reduced_cell_grad, reduced_cell_grad_keys)

    return (reduced_values, reduced_positions_grad, reduced_cell_grad)
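

# A minimal usage sketch for `reduce` (illustrative only, not part of the
# original file; the tensors and keys below are made-up values). Rows 0 and 2
# of `values` share the key `0` in column `dim=0`, so they are summed together:
#
#     values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
#     keys = torch.tensor([[0, 7], [1, 7], [0, 8]])
#     (reduced, reduced_keys), _, _ = reduce(values, keys, dim=0)
#     # reduced == [[6.0, 8.0], [3.0, 4.0]], reduced_keys == [[0], [1]]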


class ReduceValuesAutograd(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, values: Tensor, keys: Tensor, dim: int
    ) -> Tuple[Tensor, Tensor, Tensor]:
        device = values.device
        dtype = values.dtype
        assert keys.dim() == 2, "keys should have only two dimensions"

        reduced_keys = torch.unique(keys[:, dim])

        indexes = torch.empty(values.shape[0], dtype=torch.int32, device=device)
        mapping = []
        for i, reduced_key in enumerate(reduced_keys):
            idx = torch.where(keys[:, dim] == reduced_key)[0]
            indexes.index_put_(
                (idx,), torch.tensor(i, dtype=torch.int32, device=device)
            )
            mapping.append(idx)

        new_shape = (len(reduced_keys),) + values.shape[1:]
        reduced_values = torch.zeros(new_shape, dtype=dtype, device=device)
        reduced_values.index_add_(0, indexes, values)

        ctx.save_for_backward(values)
        ctx.reduce_mapping = mapping
        ctx.mark_non_differentiable(reduced_keys)

        return reduced_values, reduced_keys.reshape(-1, 1), indexes

    @staticmethod
    def backward(ctx, grad_reduced_values, _grad_reduced_keys, _grad_indexes):
        (values,) = ctx.saved_tensors

        values_grad = None
        if values.requires_grad:
            values_grad = torch.zeros_like(values)
            for i, idx in enumerate(ctx.reduce_mapping):
                values_grad[idx] = grad_reduced_values[i]

        return values_grad, None, None


class ReduceGradientAutograd(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, gradient: Tensor, keys: Tensor, indexes: Tensor
    ) -> Tuple[Tensor, Tensor]:
        device = gradient.device
        dtype = gradient.dtype

        new_keys = keys.clone()
        for grad_i, grad_key in enumerate(keys):
            new_keys[grad_i, 0] = indexes[grad_key[0]]

        reduced_keys = torch.unique(new_keys, dim=0)

        indexes = torch.empty(gradient.shape[0], dtype=torch.int32, device=device)
        mapping = []
        for i, reduced_key in enumerate(reduced_keys):
            # FIXME: this might be slow?
            idx = torch.all(new_keys == reduced_key[None, :], dim=1)
            mapping.append(idx)
            indexes.index_put_(
                (idx,), torch.tensor(i, dtype=torch.int32, device=device)
            )

        new_shape = (len(reduced_keys),) + gradient.shape[1:]
        reduced_gradient = torch.zeros(new_shape, dtype=dtype, device=device)
        reduced_gradient.index_add_(0, indexes, gradient)

        ctx.save_for_backward(gradient)
        ctx.reduce_mapping = mapping
        ctx.mark_non_differentiable(reduced_keys)

        return reduced_gradient, reduced_keys

    @staticmethod
    def backward(ctx, grad_reduced_gradient, _grad_reduced_keys):
        (gradient,) = ctx.saved_tensors

        gradient_grad = None
        if gradient.requires_grad:
            gradient_grad = torch.zeros_like(gradient)
            for i, idx in enumerate(ctx.reduce_mapping):
                # TODO: use index_put here as well?
                gradient_grad[idx] = grad_reduced_gradient[i]

        return gradient_grad, None, None
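

# A gradient-check sketch (illustrative, not from the original file): the
# hand-written backward passes above can be verified numerically with
# `torch.autograd.gradcheck` on small double-precision inputs, e.g.:
#
#     values = torch.rand(4, 3, dtype=torch.float64, requires_grad=True)
#     keys = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]])
#     torch.autograd.gradcheck(
#         lambda v: ReduceValuesAutograd.apply(v, keys, 0)[0], (values,)
#     )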


def reduce_custom_autograd(
    values: Tensor,
    keys: Tensor,
    dim: int,
    positions_grad: Optional[Tensor] = None,
    positions_grad_keys: Optional[Tensor] = None,
    cell_grad: Optional[Tensor] = None,
    cell_grad_keys: Optional[Tensor] = None,
) -> Tuple[
    Tuple[Tensor, Tensor],
    Tuple[Optional[Tensor], Optional[Tensor]],
    Tuple[Optional[Tensor], Optional[Tensor]],
]:
    device = values.device
    dtype = values.dtype

    values, keys, indexes = ReduceValuesAutograd.apply(values, keys, dim)

    if positions_grad is not None:
        assert positions_grad_keys is not None
        assert positions_grad.device == device
        assert positions_grad.dtype == dtype

        results = ReduceGradientAutograd.apply(
            positions_grad, positions_grad_keys, indexes
        )
        positions_grad, positions_grad_keys = results

    if cell_grad is not None:
        assert cell_grad_keys is not None
        assert cell_grad.device == device
        assert cell_grad.dtype == dtype

        results = ReduceGradientAutograd.apply(cell_grad, cell_grad_keys, indexes)
        cell_grad, cell_grad_keys = results

    reduced_values = (values, keys)
    reduced_positions_grad = (positions_grad, positions_grad_keys)
    reduced_cell_grad = (cell_grad, cell_grad_keys)

    return (reduced_values, reduced_positions_grad, reduced_cell_grad)
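

# An illustrative consistency check (a sketch, not part of the original file):
# on a tiny made-up input, the plain `reduce` and the custom-autograd version
# should produce the same reduced values, keys, and gradients.
if __name__ == "__main__":
    values = torch.rand(4, 3, dtype=torch.float64, requires_grad=True)
    keys = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]])

    (reference, reference_keys), _, _ = reduce(values, keys, dim=0)
    (reduced, reduced_keys), _, _ = reduce_custom_autograd(values, keys, dim=0)

    assert torch.equal(reference_keys, reduced_keys)
    assert torch.allclose(reference, reduced)

    # both implementations should also back-propagate the same gradient
    reference.sum().backward()
    reference_grad = values.grad.clone()
    values.grad = None

    reduced.sum().backward()
    assert torch.allclose(reference_grad, values.grad)

    print("reduce and reduce_custom_autograd agree")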