From a12c8b9154881a90b02310b615f490686e9df720 Mon Sep 17 00:00:00 2001
From: Frederic Grabowski
Date: Wed, 28 Feb 2024 13:55:43 +0100
Subject: [PATCH] clean up equinox mlp implementation

---
 src/bmi/estimators/neural/_critics.py | 27 ++++++++++++++-------------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/src/bmi/estimators/neural/_critics.py b/src/bmi/estimators/neural/_critics.py
index a2b9af5e..a84443ad 100644
--- a/src/bmi/estimators/neural/_critics.py
+++ b/src/bmi/estimators/neural/_critics.py
@@ -41,22 +41,23 @@ def __init__(
         - 8 -> 12
         - 12 -> 1
         """
-        # We have in total the following dimensionalities:
-        dim_sizes = [dim_x + dim_y] + list(hidden_layers) + [1]
-        # ... and one layer less:
-        keys = jax.random.split(key, len(hidden_layers) + 1)
-        self.layers = []
+        key_hidden, key_final = jax.random.split(key)
+        keys_hidden = jax.random.split(key_hidden, len(hidden_layers))
 
-        for i, key in enumerate(keys):
-            self.layers.append(eqx.nn.Linear(dim_sizes[i], dim_sizes[i + 1], key=key))
+        dim_ins = [dim_x + dim_y] + list(hidden_layers)[:-1]
+        dim_outs = list(hidden_layers)
 
+        self.layers = []
+        for dim_in, dim_out, key in zip(dim_ins, dim_outs, keys_hidden):
+            self.layers.append(eqx.nn.Linear(dim_in, dim_out, key=key))
+            self.layers.append(jax.nn.relu)
 
-        # This is ann additional trainable parameter.
-        self.extra_bias = jax.numpy.ones(1)
+        self.layers.append(eqx.nn.Linear(dim_outs[-1], 1, key=key_final))
 
-    def __call__(self, x: Point, y: Point) -> float:
+    def __call__(self, x: Point, y: Point) -> jax.Array:
         z = jnp.concatenate([x, y])
-        for layer in self.layers[:-1]:
-            z = jax.nn.relu(layer(z))
-        return jnp.mean(self.layers[-1](z) + self.extra_bias)
+        for layer in self.layers:
+            z = layer(z)
+
+        return z[..., 0]  # return scalar
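
For reference, below is a minimal self-contained sketch of how the critic reads after this patch. The class name MLP, the Point alias, and the default hidden_layers=(8, 12) are assumptions for illustration and may differ from the actual definitions in src/bmi/estimators/neural/_critics.py.

# A sketch only, not the repository's exact code: MLP, Point, and the
# hidden_layers default are assumed names/values for illustration.
from typing import Sequence

import equinox as eqx
import jax
import jax.numpy as jnp

Point = jax.Array  # assumed alias for a single sample


class MLP(eqx.Module):
    layers: list

    def __init__(
        self,
        key: jax.Array,
        dim_x: int,
        dim_y: int,
        hidden_layers: Sequence[int] = (8, 12),
    ) -> None:
        key_hidden, key_final = jax.random.split(key)
        keys_hidden = jax.random.split(key_hidden, len(hidden_layers))

        # Input and output sizes of the hidden linear layers, e.g. for
        # dim_x=2, dim_y=3: 5 -> 8, 8 -> 12, and finally 12 -> 1.
        dim_ins = [dim_x + dim_y] + list(hidden_layers)[:-1]
        dim_outs = list(hidden_layers)

        self.layers = []
        for dim_in, dim_out, subkey in zip(dim_ins, dim_outs, keys_hidden):
            self.layers.append(eqx.nn.Linear(dim_in, dim_out, key=subkey))
            self.layers.append(jax.nn.relu)

        # Final projection to a single critic value.
        self.layers.append(eqx.nn.Linear(dim_outs[-1], 1, key=key_final))

    def __call__(self, x: Point, y: Point) -> jax.Array:
        z = jnp.concatenate([x, y])
        for layer in self.layers:
            z = layer(z)
        return z[..., 0]  # drop the trailing singleton dimension


# Usage: a critic for 2-dimensional x and 3-dimensional y.
critic = MLP(jax.random.PRNGKey(0), dim_x=2, dim_y=3)
value = critic(jnp.zeros(2), jnp.zeros(3))  # scalar critic value

Storing jax.nn.relu between the Linear modules makes __call__ a plain fold over self.layers; Equinox treats the function as a static (non-array) leaf of the module pytree, so it is ignored by filtered gradient transformations while the Linear parameters remain trainable.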