Skip to content

Commit

Permalink
Add clarifying notes on softmax activation
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Jan 5, 2025
1 parent 2127ac1 commit 465f7d1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
3 changes: 3 additions & 0 deletions neurobayes/flax_nets/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def extract_mlp_configs(
layer_name = f"Dense{len(mlp.hidden_dims)}"
configs.append({
"features": mlp.target_dim,
# Note: activation is explicitly None here, overriding any softmax
# in the original FlaxMLP. For classification tasks, softmax will
# be applied later in PartialBNN.model()
"activation": None,
"is_probabilistic": layer_name in probabilistic_layers,
"layer_type": "fc",
Expand Down
11 changes: 6 additions & 5 deletions neurobayes/models/partial_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def prior(name, shape):
**({"input_dim": config['input_dim'],
"kernel_size": config['kernel_size']} if config["layer_type"] == "conv" else {})
)

print(layer)
if config['is_probabilistic']:
net = random_flax_module(
layer_name, layer,
Expand All @@ -106,13 +106,14 @@ def prior(name, shape):
}
current_input = layer.apply(params, current_input)

if self.is_regression:
# Regression case
if self.is_regression: # Regression case
mu = numpyro.deterministic("mu", net(current_input))
sig = numpyro.sample("sig", self.noise_prior)
numpyro.sample("y", dist.Normal(mu, sig), obs=y)
else:
# Classification case
else: # Classification case
# Note: Even if the original deterministic_nn had softmax,
# it was overridden to None in extract_mlp_configs, so we
# need to apply softmax here
probs = numpyro.deterministic("probs", softmax(current_input, axis=-1))
numpyro.sample("y", dist.Categorical(probs=probs), obs=y)

Expand Down

0 comments on commit 465f7d1

Please sign in to comment.