-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactored SBI Example: updated description and plots
Co-authored-by: Nicholas Tolley <[email protected]> Signed-off-by: samadpls <[email protected]>
- Loading branch information
Showing
1 changed file
with
44 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,8 @@ | |
# Mainak Jas <[email protected]> | ||
|
||
############################################################################### | ||
# First, we need to import the necessary packages. The Simulation-Based | ||
# Inference (SBI) package can be installed with ``pip install sbi``. | ||
# Let us import ``hnn_core`` and all the necessary libraries. | ||
|
||
import numpy as np | ||
|
@@ -35,6 +37,20 @@ | |
############################################################################### | ||
# This function sets the parameters for our neural network model. It adds an | ||
# evoked drive to the network with specific weights and delays. | ||
# | ||
# The `add_evoked_drive` function simulates external input to the network, | ||
# mimicking sensory stimulation or other external events. | ||
# | ||
# - `evprox` indicates a proximal drive, targeting dendrites near the cell | ||
# bodies. | ||
# - `mu=40` and `sigma=5` define the timing (mean and spread) of the input. | ||
# - `numspikes=1` means it's a single, brief stimulation. | ||
# - `weights_ampa` and `synaptic_delays` control the strength and | ||
# timing of the input. | ||
# | ||
# This evoked drive causes the initial positive deflection in the dipole | ||
# signal, triggering a cascade of activity through the network and | ||
# resulting in the complex waveforms observed. | ||
|
||
|
||
def set_params(param_values, net=None): | ||
|
@@ -56,13 +72,12 @@ def set_params(param_values, net=None): | |
# the 'weight_pyr' parameter between 10^-4 and 10^-1. | ||
|
||
|
||
rng = np.random.default_rng(seed=42) | ||
val = rng.uniform(-4, -1, size=n_simulations) | ||
val = np.linspace(-4, -1, n_simulations) | ||
param_grid = { | ||
'weight_pyr': val.tolist() | ||
} | ||
|
||
net = jones_2009_model(mesh_shape=(1, 1)) | ||
net = jones_2009_model(mesh_shape=(3, 3)) | ||
batch_simulator = BatchSimulate(set_params=set_params, | ||
net=net, | ||
tstop=170) | ||
|
@@ -89,7 +104,9 @@ def extract_dipole_data(sim_results): | |
|
||
############################################################################### | ||
# Now we prepare our data for the SBI algorithm. 'thetas' are our parameters, | ||
# and 'xs' are our observed data (the dipole activity). | ||
# and 'xs' are our observed data (the dipole activity). These will be used by | ||
# the SBI algorithm to learn the relationship between parameters and | ||
# the resulting neural activity. | ||
|
||
thetas = torch.tensor(param_grid['weight_pyr'], | ||
dtype=torch.float32).reshape(-1, 1) | ||
|
@@ -106,6 +123,11 @@ def extract_dipole_data(sim_results): | |
density_estimator = inference.append_simulations(thetas, xs).train() | ||
posterior = inference.build_posterior(density_estimator) | ||
|
||
# The prior distribution represents our initial guess about the range of | ||
# possible values for `weight_pyr`. The SBI algorithm will use this prior, | ||
# along with the simulated data, to build a posterior distribution, which | ||
# represents our updated belief about `weight_pyr` after seeing the data. | ||
|
||
############################################################################### | ||
# This function allows us to simulate data for a single parameter value. | ||
|
||
|
@@ -121,7 +143,7 @@ def simulator_batch(param): | |
# a parameter value that we pretend we don't know. | ||
|
||
|
||
unknown_param = torch.tensor([[rng.uniform(-4, -1)]]) | ||
unknown_param = torch.tensor([[np.random.choice(np.linspace(-4, -1, 100))]]) | ||
x_o = simulator_batch(unknown_param.item()) | ||
samples = posterior.sample((1000,), x=x_o) | ||
|
||
|
@@ -140,27 +162,35 @@ def simulator_batch(param): | |
plt.axvline(unknown_param.item(), color='r', linestyle='dashed', | ||
linewidth=2, label='True Parameter') | ||
plt.legend() | ||
plt.savefig('posterior_distribution_log.png') | ||
plt.xlim([-4, -1]) | ||
plt.show() | ||
|
||
# This plot shows the posterior distribution of the inferred parameter values. | ||
# If the inferred posterior distribution is centered around the true parameter | ||
# value, it suggests that the SBI method is accurately capturing the underlying | ||
# parameter. The red dashed line represents the true parameter value. | ||
|
||
############################################################################### | ||
# Finally, we'll evaluate the performance of our SBI method on multiple | ||
# unseen parameter values. | ||
|
||
unseen_params = rng.uniform(-4, -1, size=10) | ||
unseen_params = np.linspace(-4, -1, 10) | ||
unseen_data = [simulator_batch(param) for param in unseen_params] | ||
unseen_samples = [posterior.sample((100,), x=x) for x in unseen_data] | ||
|
||
plt.figure(figsize=(12, 6)) | ||
for i, (param, samples) in enumerate(zip(unseen_params, unseen_samples)): | ||
plt.scatter([param] * len(samples), samples, alpha=0.1, | ||
label=f'Param {i+1}' if i == 0 else '') | ||
plt.boxplot([samples.numpy().flatten() for samples in unseen_samples], | ||
positions=unseen_params, widths=0.05, vert=True) | ||
plt.xlabel('True Parameter') | ||
plt.ylabel('Inferred Parameter') | ||
plt.title('SBI Performance on Unseen Data') | ||
plt.savefig('sbi_performance_unseen.png') | ||
plt.xticks(ticks=unseen_params, labels=[f'{param:.2f}' | ||
for param in unseen_params]) | ||
plt.xlim(-4.1, -0.9) | ||
plt.show() | ||
|
||
############################################################################### | ||
# In this plot, each color represents a different true parameter value. The | ||
# spread of points for each color shows the distribution of inferred values. | ||
# This boxplot visualizes the distribution of inferred parameters for each | ||
# unseen true parameter value. The true parameters are shown on the x-axis, | ||
# and the boxes represent the spread of inferred values. If the inferred | ||
# parameters closely follow the diagonal (where inferred = true), it indicates | ||
# that the SBI method is performing well. |