From ca2c2dab231ac0fdc993e04d92c7b83d14b8e605 Mon Sep 17 00:00:00 2001 From: samadpls Date: Sat, 10 Aug 2024 20:03:42 +0500 Subject: [PATCH] Refactored SBI Example: updated description and plots Co-authored-by: Nicholas Tolley <55253912+ntolley@users.noreply.github.com> Signed-off-by: samadpls --- ...r_inference.py => sbi_hnncore_tutorial.py} | 58 ++++++++++++++----- 1 file changed, 44 insertions(+), 14 deletions(-) rename examples/howto/{sbi_hnn_core_parameter_inference.py => sbi_hnncore_tutorial.py} (70%) diff --git a/examples/howto/sbi_hnn_core_parameter_inference.py b/examples/howto/sbi_hnncore_tutorial.py similarity index 70% rename from examples/howto/sbi_hnn_core_parameter_inference.py rename to examples/howto/sbi_hnncore_tutorial.py index 60615abd5..04eddae89 100644 --- a/examples/howto/sbi_hnn_core_parameter_inference.py +++ b/examples/howto/sbi_hnncore_tutorial.py @@ -15,6 +15,8 @@ # Mainak Jas ############################################################################### +# First, we need to import the necessary packages. The Simulation-Based +# Inference (SBI) package can be installed with ``pip install sbi``. # Let us import ``hnn_core`` and all the necessary libraries. import numpy as np @@ -35,6 +37,20 @@ ############################################################################### # This function sets the parameters for our neural network model. It adds an # evoked drive to the network with specific weights and delays. +# +# The `add_evoked_drive` function simulates external input to the network, +# mimicking sensory stimulation or other external events. +# +# - `evprox` indicates a proximal drive, targeting dendrites near the cell +# bodies. +# - `mu=40` and `sigma=5` define the timing (mean and spread) of the input. +# - `numspikes=1` means it's a single, brief stimulation. +# - `weights_ampa` and `synaptic_delays` control the strength and +# timing of the input. 
+# +# This evoked drive causes the initial positive deflection in the dipole +# signal, triggering a cascade of activity through the network and +# resulting in the complex waveforms observed. def set_params(param_values, net=None): @@ -56,13 +72,12 @@ def set_params(param_values, net=None): # the 'weight_pyr' parameter between 10^-4 and 10^-1. -rng = np.random.default_rng(seed=42) -val = rng.uniform(-4, -1, size=n_simulations) +val = np.linspace(-4, -1, n_simulations) param_grid = { 'weight_pyr': val.tolist() } -net = jones_2009_model(mesh_shape=(1, 1)) +net = jones_2009_model(mesh_shape=(3, 3)) batch_simulator = BatchSimulate(set_params=set_params, net=net, tstop=170) @@ -89,7 +104,9 @@ def extract_dipole_data(sim_results): ############################################################################### # Now we prepare our data for the SBI algorithm. 'thetas' are our parameters, -# and 'xs' are our observed data (the dipole activity). +# and 'xs' are our observed data (the dipole activity). These will be used by +# the SBI algorithm to learn the relationship between parameters and +# the resulting neural activity. thetas = torch.tensor(param_grid['weight_pyr'], dtype=torch.float32).reshape(-1, 1) @@ -106,6 +123,11 @@ def extract_dipole_data(sim_results): density_estimator = inference.append_simulations(thetas, xs).train() posterior = inference.build_posterior(density_estimator) +# The prior distribution represents our initial guess about the range of +# possible values for `weight_pyr`. The SBI algorithm will use this prior, +# along with the simulated data, to build a posterior distribution, which +# represents our updated belief about `weight_pyr` after seeing the data. + ############################################################################### # This function allows us to simulate data for a single parameter value. @@ -121,7 +143,7 @@ def simulator_batch(param): # a parameter value that we pretend we don't know. 
-unknown_param = torch.tensor([[rng.uniform(-4, -1)]]) +unknown_param = torch.tensor([[np.random.choice(np.linspace(-4, -1, 100))]]) x_o = simulator_batch(unknown_param.item()) samples = posterior.sample((1000,), x=x_o) @@ -140,27 +162,35 @@ def simulator_batch(param): plt.axvline(unknown_param.item(), color='r', linestyle='dashed', linewidth=2, label='True Parameter') plt.legend() -plt.savefig('posterior_distribution_log.png') +plt.xlim([-4, -1]) plt.show() +# This plot shows the posterior distribution of the inferred parameter values. +# If the inferred posterior distribution is centered around the true parameter +# value, it suggests that the SBI method is accurately capturing the underlying +# parameter. The red dashed line represents the true parameter value. + ############################################################################### # Finally, we'll evaluate the performance of our SBI method on multiple # unseen parameter values. -unseen_params = rng.uniform(-4, -1, size=10) +unseen_params = np.linspace(-4, -1, 10) unseen_data = [simulator_batch(param) for param in unseen_params] unseen_samples = [posterior.sample((100,), x=x) for x in unseen_data] plt.figure(figsize=(12, 6)) -for i, (param, samples) in enumerate(zip(unseen_params, unseen_samples)): - plt.scatter([param] * len(samples), samples, alpha=0.1, - label=f'Param {i+1}' if i == 0 else '') +plt.boxplot([samples.numpy().flatten() for samples in unseen_samples], + positions=unseen_params, widths=0.05, vert=True) plt.xlabel('True Parameter') plt.ylabel('Inferred Parameter') plt.title('SBI Performance on Unseen Data') -plt.savefig('sbi_performance_unseen.png') +plt.xticks(ticks=unseen_params, labels=[f'{param:.2f}' + for param in unseen_params]) +plt.xlim(-4.1, -0.9) plt.show() -############################################################################### -# In this plot, each color represents a different true parameter value. 
The
-# spread of points for each color shows the distribution of inferred values.
+# This boxplot visualizes the distribution of inferred parameters for each
+# unseen true parameter value. The true parameters are shown on the x-axis,
+# and the boxes represent the spread of inferred values. If each box is
+# centered on its corresponding true value (i.e. inferred = true), it
+# indicates that the SBI method is performing well.