clean up equinox mlp implementation
grfrederic committed Feb 28, 2024
1 parent 4cfc2ca commit a12c8b9
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions src/bmi/estimators/neural/_critics.py
@@ -41,22 +41,23 @@ def __init__(
           - 8 -> 12
           - 12 -> 1
         """
-        # We have in total the following dimensionalities:
-        dim_sizes = [dim_x + dim_y] + list(hidden_layers) + [1]
-        # ... and one layer less:
-        keys = jax.random.split(key, len(hidden_layers) + 1)
+        key_hidden, key_final = jax.random.split(key)
+        keys_hidden = jax.random.split(key_hidden, len(hidden_layers))
 
-        self.layers = []
-        for i, key in enumerate(keys):
-            self.layers.append(eqx.nn.Linear(dim_sizes[i], dim_sizes[i + 1], key=key))
+        dim_ins = [dim_x + dim_y] + list(hidden_layers)[:-1]
+        dim_outs = list(hidden_layers)
+        self.layers = []
+        for dim_in, dim_out, key in zip(dim_ins, dim_outs, keys_hidden):
+            self.layers.append(eqx.nn.Linear(dim_in, dim_out, key=key))
+            self.layers.append(jax.nn.relu)
 
-        # This is an additional trainable parameter.
-        self.extra_bias = jax.numpy.ones(1)
+        self.layers.append(eqx.nn.Linear(dim_outs[-1], 1, key=key_final))
 
-    def __call__(self, x: Point, y: Point) -> float:
+    def __call__(self, x: Point, y: Point) -> jax.Array:
         z = jnp.concatenate([x, y])
 
-        for layer in self.layers[:-1]:
-            z = jax.nn.relu(layer(z))
-        return jnp.mean(self.layers[-1](z) + self.extra_bias)
+        for layer in self.layers:
+            z = layer(z)
+
+        return z[..., 0]  # return scalar
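
For context, here is a minimal, self-contained sketch of how the critic reads after this commit, reconstructed from the diff alone. The class name MLP, the constructor signature, and the Point alias are assumptions made for illustration; the actual surrounding module in bmi defines its own names and imports.

    # Hedged sketch reconstructed from the diff; names outside the
    # changed lines (MLP, Point, constructor signature) are assumptions.
    import equinox as eqx
    import jax
    import jax.numpy as jnp

    Point = jax.Array  # assumed alias; the real Point type lives elsewhere in bmi


    class MLP(eqx.Module):
        """Critic mapping a pair of points (x, y) to a scalar score."""

        layers: list

        def __init__(self, key, dim_x: int, dim_y: int, hidden_layers=(8, 12)):
            # One key per hidden layer, plus a separate key for the final projection.
            key_hidden, key_final = jax.random.split(key)
            keys_hidden = jax.random.split(key_hidden, len(hidden_layers))

            # Input/output widths of the hidden layers; with dim_x=1, dim_y=2
            # and hidden_layers=(8, 12) this gives 3 -> 8 and 8 -> 12.
            dim_ins = [dim_x + dim_y] + list(hidden_layers)[:-1]
            dim_outs = list(hidden_layers)

            self.layers = []
            for dim_in, dim_out, k in zip(dim_ins, dim_outs, keys_hidden):
                self.layers.append(eqx.nn.Linear(dim_in, dim_out, key=k))
                self.layers.append(jax.nn.relu)

            # Final projection to a single score, e.g. 12 -> 1.
            self.layers.append(eqx.nn.Linear(dim_outs[-1], 1, key=key_final))

        def __call__(self, x: Point, y: Point) -> jax.Array:
            z = jnp.concatenate([x, y])
            for layer in self.layers:
                z = layer(z)
            return z[..., 0]  # drop the trailing length-1 axis: a scalar score


    critic = MLP(jax.random.PRNGKey(0), dim_x=1, dim_y=2)
    print(critic(jnp.ones(1), jnp.zeros(2)).shape)  # () -- an unbatched scalar

Compared to the parent commit, the forward pass no longer special-cases the last layer or an extra_bias term: the ReLU activations are stored in self.layers alongside the linear maps, so __call__ is a plain fold over the list, and the scalar output comes from indexing the length-1 output of the final projection.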
