Adding stress testing utility for model inference to help isolate cause of stochastic failure in inference #197

Open
SamuelBrand1 opened this issue Jul 16, 2024 · 3 comments
Labels
enhancement New feature or request wontfix This will not be worked on

Comments

@SamuelBrand1
Contributor

Following a face-to-face discussion: there are occasional stochastic failures in the NUTS sampling procedure. These failures have the following characteristics:

  • They occur during NUTS warm-up.
  • They cause either errors from the ODE solver reaching its maximum number of steps (`max_steps`), or opaque errors from XLA calls.
  • They occur more often for models with vaccination than for models without.

These suggest that the sampler moves into a numerically unstable region of parameter space during warm-up, that there is some numerical instability associated with vaccination rates, or both.

A good first step towards isolating the problem would be a stress test utility, e.g. like this. It's worth investigating the existing NumPyro utilities before rolling our own solution.
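As a rough illustration of what such a stress test utility might look like, here is a minimal sketch that reruns an inference call over many seeds and records which seeds fail and with what error type. Everything named here (`stress_test`, `TrialResult`, `summarize`, and the toy `_toy_run` stand-in) is hypothetical and not part of this project or of NumPyro; in practice `run_inference` would wrap the real model fit.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Callable, Iterable, List, Optional


@dataclass
class TrialResult:
    """Outcome of one inference attempt for a given seed."""
    seed: int
    ok: bool
    error_type: Optional[str] = None
    error_message: Optional[str] = None


def stress_test(run_inference: Callable[[int], object],
                seeds: Iterable[int]) -> List[TrialResult]:
    """Call run_inference(seed) for each seed and record failures instead of aborting."""
    results = []
    for seed in seeds:
        try:
            run_inference(seed)
            results.append(TrialResult(seed=seed, ok=True))
        except Exception as exc:  # deliberately broad: we want to log every failure mode
            results.append(
                TrialResult(seed, False, type(exc).__name__, str(exc))
            )
    return results


def summarize(results: List[TrialResult]) -> Counter:
    """Count failures by exception type."""
    return Counter(r.error_type for r in results if not r.ok)


# Toy stand-in for the real fit (assumption): fails for every third seed.
def _toy_run(seed: int) -> int:
    if seed % 3 == 0:
        raise RuntimeError("The maximum number of solver steps was reached.")
    return seed


results = stress_test(_toy_run, range(6))
failing_seeds = [r.seed for r in results if not r.ok]
print(failing_seeds, summarize(results))
```

Recording the failing seeds makes a stochastic failure reproducible, which is the main point: a failing seed can then be rerun with and without vaccination to compare.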

@SamuelBrand1 SamuelBrand1 added the enhancement New feature or request label Jul 16, 2024
@SamuelBrand1
Contributor Author

I've raised an issue with NumPyro, so if a solution or utility already exists it should get flagged there: pyro-ppl/numpyro#1833

@kokbent
Collaborator

kokbent commented Jul 17, 2024

Some more context on the `XlaRuntimeError` and `max_steps` errors:

Some failures appear early:

Running chain 3:   0%|          | 0/2000 [02:29<?, ?it/s]
jax.pure_callback failed
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/jax/_src/callback.py", line 77, in pure_callback_impl
    return callback(*args)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/callback.py", line 65, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
  File "/usr/local/lib/python3.10/site-packages/equinox/_errors.py", line 70, in raises
    raise EqxRuntimeError(msgs[_index.item()])
equinox._errors.EqxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.

Sometimes they appear late (but still during warm-up?):

Running chain 1:  45%|####5     | 900/2000 [2:19:32<1:31:54,  5.01s/it]
jax.pure_callback failed
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/jax/_src/callback.py", line 77, in pure_callback_impl
    return callback(*args)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/callback.py", line 65, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
  File "/usr/local/lib/python3.10/site-packages/equinox/_errors.py", line 70, in raises
    raise EqxRuntimeError(msgs[_index.item()])
equinox._errors.EqxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.
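To check whether a given failure happened during warm-up, one option is to parse the progress-bar line printed just before the error and compare the iteration count against the number of warm-up iterations. The sketch below is a hypothetical helper, not an existing utility. It assumes the progress counter covers warm-up plus sampling iterations, and `NUM_WARMUP = 1000` is an assumed value, since the run configuration is not stated in this thread.

```python
import re
from typing import Optional, Tuple

# Matches lines such as:
#   Running chain 1:  45%|####5     | 900/2000 [2:19:32<1:31:54,  5.01s/it]
PROGRESS_RE = re.compile(r"Running chain (\d+):\s+\d+%\|.*?\|\s*(\d+)/(\d+)")

NUM_WARMUP = 1000  # assumption: replace with the num_warmup used for the run


def parse_progress(line: str) -> Optional[Tuple[int, int, int]]:
    """Return (chain, iterations_done, iterations_total), or None if no match."""
    match = PROGRESS_RE.search(line)
    if match is None:
        return None
    chain, done, total = (int(group) for group in match.groups())
    return chain, done, total


def in_warmup(iterations_done: int, num_warmup: int = NUM_WARMUP) -> bool:
    """True if the failure occurred before warm-up finished."""
    return iterations_done < num_warmup


early = parse_progress("Running chain 3:   0%|          | 0/2000 [02:29<?, ?it/s]")
late = parse_progress(
    "Running chain 1:  45%|####5     | 900/2000 [2:19:32<1:31:54,  5.01s/it]"
)
print(early, late, in_warmup(late[1]))
```

Under the assumed `NUM_WARMUP`, iteration 900 of 2000 would still fall inside warm-up; with a different warm-up length the conclusion could change.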

Sometimes only one or two of the four chains fail; sometimes all four fail. Regardless, the fit ends with the following errors:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/input/exp/fifty_state_6strain_2202_2407/smh_6str_prelim_5/run_task.py", line 264, in <module>
    runner.process_state(state, jobid, jobid_in_path=True)
  File "/input/exp/fifty_state_6strain_2202_2407/smh_6str_prelim_5/run_task.py", line 221, in process_state
    inferer.infer(
  File "/input/exp/fifty_state_6strain_2202_2407/smh_6str_prelim_5/inferer_smh.py", line 59, in infer
    self.inference_algo.run(
  File "/usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py", line 678, in run
    states_flat = tree_map(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 321, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 321, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py", line 680, in <lambda>
    lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:]),
  File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 787, in reshape
    return a.reshape(newshape, order=order)  # type: ignore[call-overload,union-attr]
  File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 150, in _reshape
    return lax.reshape(a, newshape, None)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 892, in reshape
    return reshape_p.bind(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 422, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 425, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/array.py", line 940, in _array_shard_arg
    return shard_sharded_device_array_slow_path(x, devices, indices, sharding)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/array.py", line 908, in shard_sharded_device_array_slow_path
    return pxla.shard_arg(x._value, sharding, canonicalize=False)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/array.py", line 640, in _value
    npy_value[ind] = arr._single_device_array_to_np_array()
jaxlib.xla_extension.XlaRuntimeError

@arik-shurygin
Collaborator

Tagging as wontfix, since its parent issue #244 is also tagged wontfix for now.
