-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnets_jax.py
336 lines (268 loc) · 9.81 KB
/
nets_jax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
import jax
import jax.numpy as jnp
import equinox as eqx
from functools import partial
from typing import Any, Callable, List, Optional
from inspect import isfunction
from equinox import field
from metrics_jax import pad2d_circular, pad2d_reflect
import diffrax
from reg_lib_jax import RegularizedODEfunc
from einops import rearrange
import math
def zero_init(model):
    """Return a copy of ``model`` with every array leaf replaced by zeros.

    Each zeroed array keeps the shape and dtype of the array it replaces.
    Non-array leaves (for example Python callables or scalars stored as
    ordinary module fields) are passed through unchanged.

    Args:
        model: Any pytree, typically an ``eqx.Module`` such as
            ``eqx.nn.Linear`` or ``eqx.nn.Conv2d``.

    Returns:
        A pytree with the same structure as ``model``.
    """
    # Only array leaves are touched: calling `.shape` on a non-array leaf
    # would raise. `zeros_like` preserves each leaf's dtype.
    return jax.tree_util.tree_map(
        lambda leaf: jnp.zeros_like(leaf) if eqx.is_array(leaf) else leaf,
        model,
    )
class DefaultConv2d(eqx.nn.Conv2d):
    """3x3 convolution that pads its input circularly before convolving.

    Built as a subclass of ``eqx.nn.Conv2d`` because equinox v0.11.3 has no
    ``padding_mode`` argument. The convolution itself is created with
    ``padding=0``, and ``__call__`` pads the input with ``pad2d_circular``
    first. NOTE(review): the output keeps the input's spatial size only if
    ``pad2d_circular`` adds one pixel per side — confirm in metrics_jax.
    """

    def __init__(self, dim, dim_out, *, key):
        # in_channels=dim, out_channels=dim_out, 3x3 kernel, no built-in padding
        super().__init__(dim, dim_out, kernel_size=3, padding=0, key=key)

    def __call__(self, x: jax.Array, *, key: Any | None = None) -> jax.Array:
        padded = pad2d_circular(x)
        return super().__call__(padded, key=key)
class SpatialLinear(eqx.nn.Conv2d):
    """Per-pixel linear map over channels, implemented as an unpadded 1x1 convolution."""

    def __init__(self, dim, dim_out, *, key):
        super().__init__(
            in_channels=dim,
            out_channels=dim_out,
            kernel_size=1,
            padding=0,
            key=key,
        )
def exists(x):
    """Return ``True`` unless ``x`` is ``None``; falsy values such as 0 or ``[]`` still count."""
    return x is not None
def default(val, d):
    """Return ``val`` if it is not ``None``; otherwise return a fallback derived from ``d``.

    ``d`` is called with no arguments only when ``inspect.isfunction`` reports
    it as a plain Python function (``def`` or ``lambda``). Any other object,
    including classes, builtins and ``functools.partial`` objects, is
    returned as-is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
class Residual(eqx.Module):
    """Skip connection: call ``fn`` and add its input back onto the result.

    ``fn(x, ...)`` must return something that can be added to ``x``.
    """

    # Wrapped sub-module; the dataclass __init__ generated by eqx.Module takes
    # it as the single (positional or keyword) argument.
    fn: eqx.Module

    def __call__(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x
def image_resize(x, factor):
    """Bilinearly rescale the last two axes of ``x`` by ``factor``.

    Leading axes are left untouched. The target size along each of the last
    two axes is ``int(size * factor)``, truncated toward zero. Downsampling an
    odd size by 0.5 therefore rounds down (e.g. 5 -> 2).

    Args:
        x: Array whose last two axes are resized.
        factor: Scale factor, int or float (e.g. ``2``, ``2.0``, ``1.5``,
            ``0.5``). A factor equal to 1 returns ``x`` unchanged.

    Returns:
        The resized array, or ``x`` itself when ``factor == 1``.
    """
    if factor == 1:
        return x
    h, w = x.shape[-2:]
    # jax.image.resize needs integer target sizes. Truncating in both
    # directions also makes float upsampling factors such as 2.0 or 1.5 work.
    new_shape = (*x.shape[:-2], int(h * factor), int(w * factor))
    return jax.image.resize(x, new_shape, method="bilinear")
class Resample(eqx.Module):
    """Rescale the last two axes with ``image_resize``, then apply a ``DefaultConv2d``.

    ``factor`` < 1 downsamples and ``factor`` > 1 upsamples (NDAE uses 0.5
    and 2). The convolution maps ``dim_in`` to ``dim_out`` channels.
    """

    # Scale factor forwarded to `image_resize`. Fractional values such as 0.5
    # are valid, hence `float`. Static because it determines output shapes.
    factor: float = eqx.field(static=True)
    conv: DefaultConv2d

    def __init__(self, dim_in, dim_out, factor, *, key):
        self.factor = factor
        self.conv = DefaultConv2d(dim_in, dim_out, key=key)

    def __call__(self, x):
        x = image_resize(x, self.factor)
        return self.conv(x)
class SinusoidalPosEmb(eqx.Module):
    """Sinusoidal embedding of an input value (used for the time in NDAE).

    With ``half_dim = dim // 2``, the frequencies are
    ``f_i = exp(-i * log(10000) / (half_dim - 1))`` for ``i = 0 .. half_dim - 1``.
    The output is ``concat(sin(x * f), cos(x * f))`` along the last axis, so a
    scalar ``x`` yields a vector of length ``2 * (dim // 2)``. Note that this
    is ``dim - 1`` when ``dim`` is odd.
    """

    # Set as a static field due to compatibility reason with old equinox versions
    emb: jax.Array = eqx.field(static=True)

    def __init__(self, dim):
        half_dim = dim // 2
        # With half_dim == 1 (dim 2 or 3) the denominator would be 0. Using 1
        # gives a single frequency of 1.0. All other dims are unaffected.
        emb = math.log(10000) / max(half_dim - 1, 1)
        self.emb = jnp.exp(jnp.arange(half_dim) * -emb)

    def __call__(self, x):
        emb = x * self.emb
        emb = jnp.concatenate((jnp.sin(emb), jnp.cos(emb)), axis=-1)
        return emb
class ConvBlock(eqx.Module):
    """Projection -> GroupNorm -> optional embedding scale/shift -> activation.

    Args (constructor):
        dim: Input channel count.
        dim_out: Output channel count.
        kernel_size: 3 selects a circularly padded ``DefaultConv2d``; 1 selects
            a ``SpatialLinear`` (1x1) projection. Any other value fails the
            assert.
        emb_dim: Size of an optional conditioning vector. When given:
            - the vector goes through SiLU and a zero-initialised Linear layer
              to produce a per-channel ``scale`` and ``shift``;
            - the GroupNorm is built without its own channelwise affine
              parameters.
        act: Activation applied last (default ``jax.nn.silu``).
        key: PRNG key.
    """

    proj: eqx.nn.Conv2d  # DefaultConv2d (3x3) or SpatialLinear (1x1)
    norm: eqx.nn.GroupNorm
    mlp: Optional[eqx.nn.Sequential]  # None when emb_dim is None
    act: Callable

    def __init__(
        self, dim, dim_out, kernel_size=3, emb_dim=None, act=jax.nn.silu, *, key
    ):
        super().__init__()
        keys = jax.random.split(key, 2)
        assert kernel_size == 1 or kernel_size == 3, "kernel size must be 1 or 3"
        if kernel_size == 3:
            self.proj = DefaultConv2d(dim, dim_out, key=keys[0])
        else:
            self.proj = SpatialLinear(dim, dim_out, key=keys[0])
        # GroupNorm requires the group count to divide the channel count.
        # The target is min(dim_out // 4, 32). gcd leaves it unchanged
        # whenever it already divides dim_out, and otherwise falls back to the
        # largest count that does:
        #   - 144 channels -> 16 groups instead of an invalid 32;
        #   - fewer than 4 channels -> one group per channel instead of 0.
        groups = math.gcd(dim_out, min(dim_out // 4, 32))
        self.norm = eqx.nn.GroupNorm(
            groups, dim_out, channelwise_affine=not exists(emb_dim)
        )
        self.mlp = (
            eqx.nn.Sequential(
                [
                    eqx.nn.Lambda(jax.nn.silu),
                    # zero-initialised, so scale = shift = 0 at the start
                    zero_init(eqx.nn.Linear(emb_dim, dim_out * 2, key=keys[1])),
                ]
            )
            if exists(emb_dim)
            else None
        )
        self.act = act

    def __call__(self, x, emb=None):
        """Apply the block to ``x``, conditioning on ``emb`` when possible.

        ``emb`` is used only if the block was built with ``emb_dim``. It must
        be 1-D (the ``"c -> c 1 1"`` rearrange below requires a 1-D vector).
        NOTE(review): if the block has an ``mlp`` but ``emb`` is None, the
        conditioning is silently skipped and the norm has no affine params.
        """
        x = self.proj(x)
        x = self.norm(x)
        if exists(self.mlp) and exists(emb):
            scale_shift = self.mlp(emb)
            # (2 * dim_out,) -> (2 * dim_out, 1, 1) so it broadcasts over the
            # last two axes
            scale_shift = rearrange(scale_shift, "c -> c 1 1")
            scale, shift = jnp.split(scale_shift, 2, axis=0)
            # scale + 1 to avoid random scale at initialization
            x = x * (scale + 1) + shift
        x = self.act(x)
        return x
class LinearTimeSelfAttention(eqx.Module):
    """Multi-head linear attention over all positions of a rank-3 ``(C, H, W)`` input.

    - A 1x1 conv produces ``3 * heads * dim_head`` channels, which are split
      into queries, keys and values per head.
    - Keys are softmax-normalised over the ``H * W`` positions.
    - Each head forms a ``(dim_head x dim_head)`` context from keys and
      values, and the queries read from it. Both einsums are therefore linear
      in ``H * W``.
    - ``to_out`` is zero-initialised (weights and bias), so the module outputs
      zeros at initialisation.
    """

    group_norm: eqx.nn.GroupNorm
    # NOTE(review): `heads` is an ordinary (non-static) field, so this Python
    # int is a pytree leaf. It is left that way so the pytree structure (and
    # anything serialised from it) does not change.
    heads: int
    to_qkv: eqx.nn.Conv2d
    to_out: eqx.nn.Conv2d

    def __init__(self, dim, heads=4, dim_head=32, *, key):
        keys = jax.random.split(key, 2)
        # GroupNorm needs the group count to divide `dim`. The target is
        # min(dim // 4, 32); gcd keeps it whenever it already divides `dim`,
        # and otherwise falls back to the largest count that does.
        self.group_norm = eqx.nn.GroupNorm(math.gcd(dim, min(dim // 4, 32)), dim)
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = eqx.nn.Conv2d(dim, hidden_dim * 3, 1, key=keys[0])
        self.to_out = eqx.nn.Conv2d(hidden_dim, dim, 1, key=keys[1])
        # model surgery: zero init for better training
        self.to_out = zero_init(self.to_out)

    def __call__(self, x):
        _, h, w = x.shape
        x = self.group_norm(x)
        qkv = self.to_qkv(x)
        # -> (3, heads, dim_head, h*w), unpacked along the leading axis
        q, k, v = rearrange(
            qkv, "(qkv heads c) h w -> qkv heads c (h w)", heads=self.heads, qkv=3
        )
        k = jax.nn.softmax(k, axis=-1)
        context = jnp.einsum("hdn,hen->hde", k, v)
        out = jnp.einsum("hde,hdn->hen", context, q)
        out = rearrange(
            out, "heads c (h w) -> (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)
# Neural Differential Appearance Equations
class NDAE(eqx.Module):
    """U-Net-style network conditioned on a time value; vmapped over a batch of inputs.

    Structure, as built below:
    - A 1x1 ``ConvBlock`` lifts the input to ``dim`` channels; this is the
      first skip tensor.
    - Each down stage runs a time-conditioned ``ConvBlock`` and optional
      residual linear attention, stores a skip tensor, then halves the
      spatial size with ``Resample(..., 0.5)``.
    - A middle stage runs at ``dims[-1]`` channels.
    - Each up stage doubles the spatial size with ``Resample(..., 2)``,
      concatenates the matching skip on the channel axis (axis 0 per sample),
      and repeats block + attention.
    - Finally the first skip is concatenated again, and a 1x1 sigmoid
      ``ConvBlock`` plus a zero-initialised ``SpatialLinear`` produce
      ``out_dim`` channels.

    NOTE(review): each skip concatenation needs the down/up round trip to
    restore the original size. Because ``image_resize`` truncates, H and W
    presumably must be divisible by ``2 ** (len(dim_mults) - 1)`` — confirm.
    NOTE(review): ``time_mlp`` expects ``dim`` inputs, while
    ``SinusoidalPosEmb(dim)`` emits ``2 * (dim // 2)`` values, so ``dim``
    presumably must be even.
    """

    init_conv: ConvBlock  # 1x1 ConvBlock: in_dim -> dim channels
    sinusoidal_pos_emb: SinusoidalPosEmb  # time -> sinusoidal embedding
    time_mlp: eqx.nn.MLP  # embedding -> time_dim (= 2 * dim) conditioning vector
    downs: List[List[eqx.Module]]  # per level: [ConvBlock, attention or Identity, Resample(0.5)]
    mid: List[eqx.Module]  # [ConvBlock, attention or Identity] at the coarsest level
    ups: List[List[eqx.Module]]  # per level: [Resample(2), ConvBlock, attention or Identity]
    final_conv: List[eqx.Module]  # [1x1 sigmoid ConvBlock, zero-initialised SpatialLinear]

    def __init__(
        self,
        dim,
        in_dim=None,
        out_dim=None,
        dim_mults=(1, 2),
        use_attn=True,
        attn_heads=4,
        attn_head_dim=8,
        *,
        key
    ):
        """Build the network.

        Args:
            dim: Base channel count; level ``i`` uses ``dim * dim_mults[i]``.
            in_dim: Input channels (defaults to ``dim``).
            out_dim: Output channels (defaults to ``dim``).
            dim_mults: Channel multipliers per resolution level. The first
                must be 1. ``len(dim_mults) - 1`` down and up stages are built.
            use_attn: If false, every attention slot is ``eqx.nn.Identity``.
            attn_heads: Number of heads in ``LinearTimeSelfAttention``.
            attn_head_dim: Channels per attention head.
            key: PRNG key. It is split into six sub-keys that are consumed in
                a fixed order, so reordering the construction below changes
                the initial weights.
        """
        super().__init__()
        assert dim_mults[0] == 1, "first dim_mult must be 1"
        keys = jax.random.split(key, 6)
        in_dim = default(in_dim, dim)
        out_dim = default(out_dim, dim)
        self.init_conv = ConvBlock(in_dim, dim, kernel_size=1, key=keys[0])
        # time embeddings
        time_dim = dim * 2
        self.sinusoidal_pos_emb = SinusoidalPosEmb(dim)
        # MLP(in_size=dim, out_size=time_dim, width_size=time_dim, depth=1)
        self.time_mlp = eqx.nn.MLP(
            dim, time_dim, time_dim, 1, activation=jax.nn.silu, key=keys[1]
        )
        # down, mid, and up layers
        # e.g. dim_mults=(1, 2) -> dims=[dim, 2*dim], in_out=[(dim, 2*dim)]
        dims = [dim * mult for mult in dim_mults]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.downs = []
        self.ups = []
        # every ConvBlock inside the U-Net is conditioned on the time embedding
        convblock = partial(ConvBlock, emb_dim=time_dim)
        attn = partial(
            LinearTimeSelfAttention, heads=attn_heads, dim_head=attn_head_dim
        )
        down_keys = jax.random.split(keys[2], len(in_out))
        for ind, (dim_in, dim_out) in enumerate(in_out):
            _keys = jax.random.split(down_keys[ind], 3)
            down = [
                convblock(dim_in, dim_in, key=_keys[0]),
                Residual(attn(dim_in, key=_keys[1])) if use_attn else eqx.nn.Identity(),
                # halve the spatial size while mapping dim_in -> dim_out channels
                Resample(dim_in, dim_out, 0.5, key=_keys[2]),
            ]
            self.downs.append(down)
        mid_dim = dims[-1]
        mid_keys = jax.random.split(keys[3], 2)
        self.mid = [
            convblock(mid_dim, mid_dim, key=mid_keys[0]),
            Residual(attn(mid_dim, key=mid_keys[1])) if use_attn else eqx.nn.Identity(),
        ]
        up_keys = jax.random.split(keys[4], len(in_out))
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            _keys = jax.random.split(up_keys[ind], 3)
            up = [
                # double the spatial size while mapping dim_out -> dim_in channels
                Resample(dim_out, dim_in, 2, key=_keys[0]),
                # input = dim_in upsampled channels + dim_in skip channels
                convblock(dim_in * 2, dim_in, key=_keys[1]),
                Residual(attn(dim_in, key=_keys[2])) if use_attn else eqx.nn.Identity(),
            ]
            self.ups.append(up)
        final_keys = jax.random.split(keys[5], 2)
        self.final_conv = [
            # use sigmoid activation to avoid unbounded ODE
            ConvBlock(
                dim * 2, dim, kernel_size=1, act=jax.nn.sigmoid, key=final_keys[0]
            ),
            SpatialLinear(dim, out_dim, key=final_keys[1]),
        ]
        # model surgery: zero init for better training
        self.final_conv[-1] = zero_init(self.final_conv[-1])

    @partial(eqx.filter_vmap, in_axes=(None, None, 0, None))
    def __call__(self, time, x, args=None):
        """Evaluate the network at ``time`` for a batch ``x``.

        ``eqx.filter_vmap`` maps over the leading axis of ``x`` only.
        ``self``, ``time`` and ``args`` are shared across the batch (their
        ``in_axes`` entry is ``None``), and ``args`` is accepted but unused.
        Within one sample, axis 0 is the channel axis: skips are concatenated
        with ``axis=0``.
        """
        t = self.time_mlp(self.sinusoidal_pos_emb(time))
        x = self.init_conv(x)
        # skip tensors, consumed last-in-first-out by the up path
        h = []
        h.append(x)
        for block, attn, downsample in self.downs:
            x = block(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)
        block, attn = self.mid
        x = block(x, t)
        x = attn(x)
        for upsample, block, attn in self.ups:
            x = upsample(x)
            x = jnp.concatenate((x, h.pop()), axis=0)
            x = block(x, t)
            x = attn(x)
        # the remaining skip is the init_conv output: dim + dim channels go
        # into final_conv
        x = jnp.concatenate((x, h.pop()), axis=0)
        for layers in self.final_conv:
            x = layers(x)
        assert len(h) == 0, "all hidden states should be used"
        return x
class NeuralODE(eqx.Module):
    """Wrap a vector field in ``RegularizedODEfunc`` and integrate it with diffrax.

    Args (constructor):
        odefunc: The vector field to integrate.
        reg_fns: Regularisation functions handed to ``RegularizedODEfunc``.
            Their count sets the length of the ``"reg"`` state.
    """

    odefunc: eqx.Module
    n_reg: int = field(static=True)  # len(reg_fns)

    def __init__(self, odefunc, reg_fns=()):
        self.n_reg = len(reg_fns)
        self.odefunc = RegularizedODEfunc(odefunc, reg_fns)

    def __call__(self, t0, t1, y0, get_reg=False, key=None, **kwargs):
        """Solve from ``t0`` to ``t1`` starting at ``y0``.

        Args:
            t0: Start of the integration interval.
            t1: End of the integration interval.
            y0: Initial value of the ``"x"`` state.
            get_reg: Passed to the vector field as ``args["get_reg"]``. If
                true and regularisers were given:
                - a ``"reg"`` state (zeros of length ``n_reg``) is integrated
                  as well;
                - standard-normal noise shaped like ``y0`` is passed as
                  ``args["_e"]``.
            key: PRNG key for that noise. Required when ``get_reg`` is true
                and ``n_reg > 0``; ignored otherwise.
            **kwargs: Forwarded to ``diffrax.diffeqsolve``. ``solver`` and
                ``dt0`` must be supplied here.

        Returns:
            The solution object returned by ``diffrax.diffeqsolve``.

        Raises:
            TypeError: If regularisation is requested but ``key`` is None.
        """
        states = {"x": y0}
        args = {"get_reg": get_reg}
        if get_reg and self.n_reg:
            if key is None:
                # Fail with a clear message instead of an opaque error from
                # jax.random.normal(None, ...).
                raise TypeError(
                    "NeuralODE: a PRNG `key` is required when get_reg=True "
                    "and regularisation functions are present"
                )
            states["reg"] = jnp.zeros(self.n_reg)
            args["_e"] = jax.random.normal(key, y0.shape)
        # diffrax's default `max_steps` applies unless a caller passes one
        # through kwargs.
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.odefunc),
            t0=t0,
            t1=t1,
            y0=states,
            args=args,
            **kwargs,
        )
        return solution