diff --git a/neurobayes/flax_nets/configs.py b/neurobayes/flax_nets/configs.py
index 0a8431f..2fe874d 100644
--- a/neurobayes/flax_nets/configs.py
+++ b/neurobayes/flax_nets/configs.py
@@ -53,6 +53,9 @@ def extract_mlp_configs(
     layer_name = f"Dense{len(mlp.hidden_dims)}"
     configs.append({
         "features": mlp.target_dim,
+        # Note: activation is explicitly None here, overriding any softmax
+        # in the original FlaxMLP. For classification tasks, softmax will
+        # be applied later in PartialBNN.model()
         "activation": None,
         "is_probabilistic": layer_name in probabilistic_layers,
         "layer_type": "fc",
diff --git a/neurobayes/models/partial_bnn.py b/neurobayes/models/partial_bnn.py
index 8f3e1df..8a4611a 100644
--- a/neurobayes/models/partial_bnn.py
+++ b/neurobayes/models/partial_bnn.py
@@ -106,13 +106,14 @@ def prior(name, shape):
             }
             current_input = layer.apply(params, current_input)
 
-        if self.is_regression:
-            # Regression case
+        if self.is_regression:  # Regression case
             mu = numpyro.deterministic("mu", net(current_input))
             sig = numpyro.sample("sig", self.noise_prior)
             numpyro.sample("y", dist.Normal(mu, sig), obs=y)
-        else:
-            # Classification case
+        else:  # Classification case
+            # Note: even if the original deterministic_nn had a softmax, its
+            # activation was overridden to None in extract_mlp_configs, so
+            # softmax has to be applied here
             probs = numpyro.deterministic("probs", softmax(current_input, axis=-1))
             numpyro.sample("y", dist.Categorical(probs=probs), obs=y)
 
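
Aside for reviewers: below is a minimal, self-contained sketch of the pattern
these comments describe. It is illustrative code, not part of NeuroBayes; the
model function toy_partial_bnn, the layer sizes, and the site names "Dense0"
and "Dense1" are made up for the example. The head emits raw logits (its
activation set to None), one layer is made probabilistic with numpyro's
random_flax_module (the mechanism the is_probabilistic config flag selects),
and softmax is applied exactly once, inside the model:

    import jax.numpy as jnp
    from jax.nn import softmax
    import flax.linen as nn
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.module import flax_module, random_flax_module

    def toy_partial_bnn(X, y=None, n_classes=3):
        # Deterministic hidden layer: fixed weights registered via numpyro.param
        dense0 = flax_module("Dense0", nn.Dense(features=16),
                             input_shape=(1, X.shape[-1]))
        h = jnp.tanh(dense0(X))
        # Probabilistic output layer: weights sampled from a Normal prior
        dense1 = random_flax_module("Dense1", nn.Dense(features=n_classes),
                                    prior=dist.Normal(0.0, 1.0),
                                    input_shape=(1, 16))
        logits = dense1(h)  # raw logits; no activation on the head
        # softmax applied once, here in the model, mirroring PartialBNN.model()
        probs = numpyro.deterministic("probs", softmax(logits, axis=-1))
        numpyro.sample("y", dist.Categorical(probs=probs), obs=y)

One design note: dist.Categorical also accepts logits= directly, which is
numerically safer than passing explicit probabilities; the explicit softmax is
kept here, as in the diff, so that "probs" is recorded as a deterministic site
and shows up in posterior samples.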