Skip to content

Commit

Permalink
Refactored SBI Example: updated description and plots
Browse files Browse the repository at this point in the history
Co-authored-by: Nicholas Tolley <[email protected]>
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls and ntolley committed Aug 17, 2024
1 parent 57a5b47 commit ca2c2da
Showing 1 changed file with 44 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# Mainak Jas <[email protected]>

###############################################################################
# First, we need to import the necessary packages. The Simulation-Based
# Inference (SBI) package can be installed with ``pip install sbi``.
# Let us import ``hnn_core`` and all the necessary libraries.

import numpy as np
Expand All @@ -35,6 +37,20 @@
###############################################################################
# This function sets the parameters for our neural network model. It adds an
# evoked drive to the network with specific weights and delays.
#
# The `add_evoked_drive` function simulates external input to the network,
# mimicking sensory stimulation or other external events.
#
# - `evprox` indicates a proximal drive, targeting dendrites near the cell
# bodies.
# - `mu=40` and `sigma=5` define the timing (mean and spread) of the input.
# - `numspikes=1` means it's a single, brief stimulation.
# - `weights_ampa` and `synaptic_delays` control the strength and
# timing of the input.
#
# This evoked drive causes the initial positive deflection in the dipole
# signal, triggering a cascade of activity through the network and
# resulting in the complex waveforms observed.


def set_params(param_values, net=None):
Expand All @@ -56,13 +72,12 @@ def set_params(param_values, net=None):
# the 'weight_pyr' parameter between 10^-4 and 10^-1.


rng = np.random.default_rng(seed=42)
val = rng.uniform(-4, -1, size=n_simulations)
val = np.linspace(-4, -1, n_simulations)
param_grid = {
'weight_pyr': val.tolist()
}

net = jones_2009_model(mesh_shape=(1, 1))
net = jones_2009_model(mesh_shape=(3, 3))
batch_simulator = BatchSimulate(set_params=set_params,
net=net,
tstop=170)
Expand All @@ -89,7 +104,9 @@ def extract_dipole_data(sim_results):

###############################################################################
# Now we prepare our data for the SBI algorithm. 'thetas' are our parameters,
# and 'xs' are our observed data (the dipole activity).
# and 'xs' are our observed data (the dipole activity). These will be used by
# the SBI algorithm to learn the relationship between parameters and
# the resulting neural activity.

thetas = torch.tensor(param_grid['weight_pyr'],
dtype=torch.float32).reshape(-1, 1)
Expand All @@ -106,6 +123,11 @@ def extract_dipole_data(sim_results):
density_estimator = inference.append_simulations(thetas, xs).train()
posterior = inference.build_posterior(density_estimator)

# The prior distribution represents our initial guess about the range of
# possible values for `weight_pyr`. The SBI algorithm will use this prior,
# along with the simulated data, to build a posterior distribution, which
# represents our updated belief about `weight_pyr` after seeing the data.

###############################################################################
# This function allows us to simulate data for a single parameter value.

Expand All @@ -121,7 +143,7 @@ def simulator_batch(param):
# a parameter value that we pretend we don't know.


unknown_param = torch.tensor([[rng.uniform(-4, -1)]])
unknown_param = torch.tensor([[np.random.choice(np.linspace(-4, -1, 100))]])
x_o = simulator_batch(unknown_param.item())
samples = posterior.sample((1000,), x=x_o)

Expand All @@ -140,27 +162,35 @@ def simulator_batch(param):
plt.axvline(unknown_param.item(), color='r', linestyle='dashed',
linewidth=2, label='True Parameter')
plt.legend()
plt.savefig('posterior_distribution_log.png')
plt.xlim([-4, -1])
plt.show()

# This plot shows the posterior distribution of the inferred parameter values.
# If the inferred posterior distribution is centered around the true parameter
# value, it suggests that the SBI method is accurately capturing the underlying
# parameter. The red dashed line represents the true parameter value.

###############################################################################
# Finally, we'll evaluate the performance of our SBI method on multiple
# unseen parameter values.

unseen_params = rng.uniform(-4, -1, size=10)
unseen_params = np.linspace(-4, -1, 10)
unseen_data = [simulator_batch(param) for param in unseen_params]
unseen_samples = [posterior.sample((100,), x=x) for x in unseen_data]

plt.figure(figsize=(12, 6))
for i, (param, samples) in enumerate(zip(unseen_params, unseen_samples)):
plt.scatter([param] * len(samples), samples, alpha=0.1,
label=f'Param {i+1}' if i == 0 else '')
plt.boxplot([samples.numpy().flatten() for samples in unseen_samples],
positions=unseen_params, widths=0.05, vert=True)
plt.xlabel('True Parameter')
plt.ylabel('Inferred Parameter')
plt.title('SBI Performance on Unseen Data')
plt.savefig('sbi_performance_unseen.png')
plt.xticks(ticks=unseen_params, labels=[f'{param:.2f}'
for param in unseen_params])
plt.xlim(-4.1, -0.9)
plt.show()

###############################################################################
# In this plot, each color represents a different true parameter value. The
# spread of points for each color shows the distribution of inferred values.
# This boxplot visualizes the distribution of inferred parameters for each
# unseen true parameter value. The true parameters are shown on the x-axis,
# and the boxes represent the spread of inferred values. If the inferred
# parameters closely follow the diagonal (where inferred = true), it indicates
# that the SBI method is performing well.

0 comments on commit ca2c2da

Please sign in to comment.