Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding stresstest function to MechanisticInferer #199

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions mechanistic_model/mechanistic_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
import numpy as np
import numpyro # type: ignore
from diffrax import Solution
from jax import random
from jax.random import PRNGKey
from numpyro import distributions as Dist
from numpyro.diagnostics import summary # type: ignore
from numpyro.handlers import seed, trace # type: ignore
from numpyro.infer import MCMC, NUTS # type: ignore
from numpyro.infer.util import potential_energy

import mechanistic_model.utils as utils
from config.config import Config
Expand Down Expand Up @@ -408,8 +411,7 @@ def load_posterior_particle(
"run self.infer() first to produce posterior particles or pass externally produced particles"
)
if tf is None:
# run for same amount of timesteps as given in inference
# given to exists since self.infer_complete is True
# run for same amount of timesteps as given print(e) # given to exists since self.infer_complete is True
if hasattr(self, "inference_timesteps"):
tf = self.inference_timesteps
# unless user is using external_posterior, we may have not inferred yet
Expand Down Expand Up @@ -466,3 +468,58 @@ def _load_posterior_single_particle(
sol_dct = substituted_model(tf=tf, infer_mode=False)
sol_dct["posteriors"] = substituted_model.data
return sol_dct

def stresstest(
self, N: int, scale: float = 1, **kwargs
) -> jax.typing.ArrayLike:
"""
Perform a stress test on the model by generating random parameter values and
checking if the model fails for each parameter set. Model calls use `numpyro.infer.util.potential_energy`
with random parameters in unconstrained domain. Any parameter set causing a sample fail,
or returning `NaN` or `Inf` potential are returned.

Parameters
------------
N (int):
The number of random parameter sets to generate for stress testing.
scale (float, optional):
A scaling factor to apply to the random parameter values. Defaults to 1.
kwargs:
Key word arguments passed to `loglikelihood`.

Returns
---------------
List[Dict[str, Any]]: A list of failing parameter sets, where each parameter set is a dictionary
mapping parameter keys to their corresponding values.
"""
# Execute the model to collect parameter keys
exec_trace = trace(
seed(
self.likelihood,
jax.random.PRNGKey(self.config.INFERENCE_PRNGKEY),
)
).get_trace(kwargs)
# Generate random parameter values with cauchy distribution
rand_vars = [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets maybe look for a way to only randomly vary parameters we actually sample, rather than all numpyro parameters (which includes data)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can do this by passing the data you want to pass into kwargs, sampling from the model prior and using that to create the stress test params.

random.cauchy(rk, (len(exec_trace.keys()),))
for rk in random.split(
jax.random.PRNGKey(self.config.INFERENCE_PRNGKEY), N
)
]
rand_params = [
{key: scale * x[i] for i, key in enumerate(exec_trace.keys())}
for x in rand_vars
]
failing_params = []
for param in rand_params:
try:
# potential_energy should raise an exception if the model fails
# and ingests parameters on the unconstrained domain
pe = potential_energy(self.likelihood, {}, {}, param)
if bool(jnp.isnan(pe)):
failing_params.append(param)
if bool(jnp.isinf(pe)):
failing_params.append(param)
except Exception as _:
failing_params.append(param)
return failing_params
7 changes: 7 additions & 0 deletions tests/test_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ def test_load_posterior_particle():
), "load_posterior_particle produced different timeline shapes than what was fit on"


def test_stresstest_runs():
failed_params = inferer.stresstest(1000, tf=10)
assert (
len(failed_params) >= 0
), "Params causing failure not returning as a list"


def test_external_posteriors():
load_across_chains = [
(chain, 0) for chain in range(inferer.config.INFERENCE_NUM_CHAINS)
Expand Down
Loading