Skip to content

Commit

Permalink
Silencing some warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
dmnapolitano committed Oct 17, 2023
1 parent a67304f commit e54ba85
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/elexsolver/EITransitionSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pymc as pm
import pymc.sampling.jax as pmjax

from elexsolver.logging import initialize_logging
from elexsolver.TransitionSolver import TransitionSolver
Expand Down Expand Up @@ -64,7 +65,7 @@ def fit_predict(self, X, Y):
num_cols = Y.shape[0] # number of things in Y that are being transitioned "to"

# reshaping and rounding
Y_obs = np.swapaxes(Y * self._n, 0, 1).round()
Y_obs = np.transpose(Y * self._n).round()
X_extended = np.expand_dims(X, axis=2)
X_extended = np.repeat(X_extended, num_cols, axis=2)
X_extended = np.swapaxes(X_extended, 0, 1)
Expand All @@ -81,8 +82,8 @@ def fit_predict(self, X, Y):
shape=(num_units, num_cols),
)
try:
# TODO: allow other samplers; this one is very good but slow
model_trace = pm.sample(chains=self._chains, random_seed=self._seed, nuts_sampler="numpyro")
# TODO: keep trying to tune this for performance and speed
model_trace = pmjax.sample_numpyro_nuts(chains=self._chains, random_seed=self._seed, target_accept=0.95)
except Exception as e:
LOG.debug(model.debug())
raise e
Expand Down

0 comments on commit e54ba85

Please sign in to comment.