Skip to content

Commit

Permalink
Handle latest PyMC/PyTensor breaking changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 26, 2024
1 parent dfe3fe0 commit 94790c7
Show file tree
Hide file tree
Showing 13 changed files with 62 additions and 51 deletions.
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ dependencies:
- xhistogram
- statsmodels
- pip:
- pymc>=5.13.0 # CI was failing to resolve
- pymc>=5.16.1 # CI was failing to resolve
- blackjax
- scikit-learn
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ dependencies:
- xhistogram
- statsmodels
- pip:
- pymc>=5.13.0 # CI was failing to resolve
- pymc>=5.16.1 # CI was failing to resolve
- blackjax
- scikit-learn
7 changes: 3 additions & 4 deletions pymc_experimental/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
The imports from pymc are not fully replicated here: add imports as necessary.
"""

from typing import List, Tuple, Union
from typing import Tuple, Union

import numpy as np
import pytensor.tensor as pt
Expand All @@ -37,8 +37,7 @@

class GenExtremeRV(RandomVariable):
name: str = "Generalized Extreme Value"
ndim_supp: int = 0
ndims_params: List[int] = [0, 0, 0]
signature = "(),(),()->()"
dtype: str = "floatX"
_print_name: Tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}")

Expand Down Expand Up @@ -275,7 +274,7 @@ def chi_dist(nu: TensorVariable, size: TensorVariable) -> TensorVariable:

def __new__(cls, name, nu, **kwargs):
if "observed" not in kwargs:
kwargs.setdefault("transform", transforms.log)
kwargs.setdefault("default_transform", transforms.log)
return CustomDist(name, nu, dist=cls.chi_dist, class_name="Chi", **kwargs)

@classmethod
Expand Down
3 changes: 1 addition & 2 deletions pymc_experimental/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def log1mexp(x):

class GeneralizedPoissonRV(RandomVariable):
name = "generalized_poisson"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "int64"
_print_name = ("GeneralizedPoisson", "\\operatorname{GeneralizedPoisson}")

Expand Down
36 changes: 26 additions & 10 deletions pymc_experimental/model/marginal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pytensor.graph.replace import graph_replace, vectorize_graph
from pytensor.scan import map as scan_map
from pytensor.tensor import TensorType, TensorVariable
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.shape import Shape
from pytensor.tensor.special import log_softmax

Expand Down Expand Up @@ -598,7 +598,18 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
fg = FunctionGraph(outputs=output_rvs, clone=False)

non_elemwise_blockers = [
o for node in fg.apply_nodes if not isinstance(node.op, Elemwise) for o in node.outputs
o
for node in fg.apply_nodes
if not (
isinstance(node.op, Elemwise)
# Allow expand_dims on the left
or (
isinstance(node.op, DimShuffle)
and not node.op.drop
and node.op.shuffle == sorted(node.op.shuffle)
)
)
for o in node.outputs
]
blocker_candidates = [rv_to_marginalize] + other_input_rvs + non_elemwise_blockers
blockers = [var for var in blocker_candidates if var not in output_rvs]
Expand Down Expand Up @@ -698,16 +709,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs

def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
op = rv.owner.op
dist_params = rv.owner.op.dist_params(rv.owner)
if isinstance(op, Bernoulli):
return (0, 1)
elif isinstance(op, Categorical):
p_param = rv.owner.inputs[3]
[p_param] = dist_params
return tuple(range(pt.get_vector_length(p_param)))
elif isinstance(op, DiscreteUniform):
lower, upper = constant_fold(rv.owner.inputs[3:])
lower, upper = constant_fold(dist_params)
return tuple(np.arange(lower, upper + 1))
elif isinstance(op, DiscreteMarkovChain):
P = rv.owner.inputs[0]
P, *_ = dist_params
return tuple(range(pt.get_vector_length(P[-1])))

raise NotImplementedError(f"Cannot compute domain for op {op}")
Expand Down Expand Up @@ -827,11 +839,15 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
# This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
# We do it entirely in logs, though.

# To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
# the initial distribution. This is robust to everything the user can throw at it.
batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")(
batch_chain_value[..., 0]
)
# To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
# under the initial distribution. This is robust to everything the user can throw at it.
init_dist_value = init_dist_.type()
logp_init_dist = logp(init_dist_, init_dist_value)
# There is a degenerate batch dim for lags=1 (the only supported case),
# that we have to work around, by expanding the batch value and then squeezing it out of the logp
batch_logp_init_dist = vectorize_graph(
logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]}
).squeeze(1)
log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]

def step_alpha(logp_emission, log_alpha, log_P):
Expand Down
18 changes: 11 additions & 7 deletions pymc_experimental/model/transforms/autoreparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import pytensor
import pytensor.tensor as pt
import scipy.special
from pymc.distributions import SymbolicRandomVariable
from pymc.exceptions import NotConstantValueError
from pymc.logprob.transforms import Transform
from pymc.model.fgraph import (
ModelDeterministic,
Expand All @@ -17,7 +19,7 @@
model_from_fgraph,
model_named,
)
from pymc.pytensorf import toposort_replace
from pymc.pytensorf import constant_fold, toposort_replace
from pytensor.graph.basic import Apply, Variable
from pytensor.tensor.random.op import RandomVariable

Expand Down Expand Up @@ -170,14 +172,16 @@ def vip_reparam_node(
dims: List[Variable],
transform: Optional[Transform],
) -> Tuple[ModelDeterministic, ModelNamed]:
if not isinstance(node.op, RandomVariable):
if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
raise TypeError("Op should be RandomVariable type")
size = node.inputs[1]
if not isinstance(size, pt.TensorConstant):
rv = node.default_output()
try:
[rv_shape] = constant_fold([rv.shape])
except NotConstantValueError:
raise ValueError("Size should be static for autoreparametrization.")
logit_lam_ = pytensor.shared(
np.zeros(size.data),
shape=size.data,
np.zeros(rv_shape),
shape=rv_shape,
name=f"{name}::lam_logit__",
)
logit_lam = model_named(logit_lam_, *dims)
Expand Down Expand Up @@ -216,7 +220,7 @@ def _(
transform: Optional[Transform],
lam: pt.TensorVariable,
) -> ModelDeterministic:
rng, size, _, loc, scale = node.inputs
rng, size, loc, scale = node.inputs
if transform is not None:
raise NotImplementedError("Reparametrization of Normal with Transform is not implemented")
vip_rv_ = pm.Normal.dist(
Expand Down
25 changes: 9 additions & 16 deletions pymc_experimental/tests/model/test_marginal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import pytensor.tensor as pt
import pytest
from arviz import InferenceData, dict_to_dataset
from pymc import ImputationWarning, inputvars
from pymc.distributions import transforms
from pymc.logprob.abstract import _logprob
from pymc.model.fgraph import fgraph_from_model
from pymc.pytensorf import inputvars
from pymc.util import UNSET
from scipy.special import log_softmax, logsumexp
from scipy.stats import halfnorm, norm
Expand Down Expand Up @@ -45,9 +45,7 @@ def disaster_model():
early_rate = pm.Exponential("early_rate", 1.0, initval=3)
late_rate = pm.Exponential("late_rate", 1.0, initval=1)
rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
with pytest.warns(ImputationWarning), pytest.warns(
RuntimeWarning, match="invalid value encountered in cast"
):
with pytest.warns(Warning):
disasters = pm.Poisson("disasters", rate, observed=disaster_data)

return disaster_model, years
Expand Down Expand Up @@ -294,7 +292,7 @@ def test_recover_marginals_basic():

with m:
prior = pm.sample_prior_predictive(
samples=20,
draws=20,
random_seed=rng,
return_inferencedata=False,
)
Expand Down Expand Up @@ -337,7 +335,7 @@ def test_recover_marginals_coords():

with m:
prior = pm.sample_prior_predictive(
samples=20,
draws=20,
random_seed=rng,
return_inferencedata=False,
)
Expand All @@ -364,7 +362,7 @@ def test_recover_batched_marginal():

with m:
prior = pm.sample_prior_predictive(
samples=20,
draws=20,
random_seed=rng,
return_inferencedata=False,
)
Expand Down Expand Up @@ -394,7 +392,7 @@ def test_nested_recover_marginals():

with m:
prior = pm.sample_prior_predictive(
samples=20,
draws=20,
random_seed=rng,
return_inferencedata=False,
)
Expand Down Expand Up @@ -565,7 +563,7 @@ def test_marginalized_transforms(transform, expected_warning):
w=w,
comp_dists=pm.HalfNormal.dist([1, 2, 3]),
initval=initval,
transform=transform,
default_transform=transform,
)
y = pm.Normal("y", 0, sigma, observed=data)

Expand All @@ -583,7 +581,7 @@ def test_marginalized_transforms(transform, expected_warning):
),
),
initval=initval,
transform=transform,
default_transform=transform,
)
y = pm.Normal("y", 0, sigma, observed=data)

Expand Down Expand Up @@ -710,12 +708,7 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):

@pytest.mark.parametrize(
"categorical_emission",
[
False,
# Categorical has a core vector parameter,
# so it is not possible to build a graph that uses elemwise operations exclusively
pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError)),
],
[False, True],
)
def test_marginalized_hmm_categorical_emission(categorical_emission):
"""Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
Expand Down
2 changes: 1 addition & 1 deletion pymc_experimental/tests/statespace/test_SARIMAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def test_interpretable_raises_if_d_nonzero():

def test_interpretable_states_are_interpretable(arima_mod_interp, pymc_mod_interp):
with pymc_mod_interp:
prior = pm.sample_prior_predictive(samples=10)
prior = pm.sample_prior_predictive(draws=10)

prior_outputs = arima_mod_interp.sample_unconditional_prior(prior)
ar_lags = prior.prior.coords["ar_lag"].values - 1
Expand Down
2 changes: 1 addition & 1 deletion pymc_experimental/tests/statespace/test_VARMAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def pymc_mod(varma_mod, data):
@pytest.fixture(scope="session")
def idata(pymc_mod, rng):
with pymc_mod:
idata = pm.sample_prior_predictive(samples=10, random_seed=rng)
idata = pm.sample_prior_predictive(draws=10, random_seed=rng)

return idata

Expand Down
6 changes: 3 additions & 3 deletions pymc_experimental/tests/statespace/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_lgss_distribution_from_steps(output_name, ss_mod_me, pymc_model_2):
latent_states, obs_states = LinearGaussianStateSpace("states", *matrices, steps=100)
# pylint: enable=unpacking-non-sequence

idata = pm.sample_prior_predictive(samples=10)
idata = pm.sample_prior_predictive(draws=10)
delete_rvs_from_model(["states_latent", "states_observed", "states_combined"])

assert idata.prior.coords["states_latent_dim_0"].shape == (101,)
Expand All @@ -144,7 +144,7 @@ def test_lgss_distribution_with_dims(output_name, ss_mod_me, pymc_model_2):
"states", *matrices, steps=100, dims=[TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
)
# pylint: enable=unpacking-non-sequence
idata = pm.sample_prior_predictive(samples=10)
idata = pm.sample_prior_predictive(draws=10)
delete_rvs_from_model(["states_latent", "states_observed", "states_combined"])

assert idata.prior.coords["time"].shape == (101,)
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_lgss_with_time_varying_inputs(output_name, rng):
dims=[TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
)
# pylint: enable=unpacking-non-sequence
idata = pm.sample_prior_predictive(samples=10)
idata = pm.sample_prior_predictive(draws=10)

assert idata.prior.coords["time"].shape == (10,)
assert all(
Expand Down
4 changes: 2 additions & 2 deletions pymc_experimental/tests/statespace/test_statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def exog_pymc_mod(exog_ss_mod, rng):
def idata(pymc_mod, rng):
with pymc_mod:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
idata_prior = pm.sample_prior_predictive(samples=10, random_seed=rng)
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)

idata.extend(idata_prior)
return idata
Expand All @@ -145,7 +145,7 @@ def idata(pymc_mod, rng):
def idata_exog(exog_pymc_mod, rng):
with exog_pymc_mod:
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
idata_prior = pm.sample_prior_predictive(samples=10, random_seed=rng)
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
idata.extend(idata_prior)
return idata

Expand Down
4 changes: 2 additions & 2 deletions pymc_experimental/tests/statespace/test_structural.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def test_filter_scans_time_varying_design_matrix(rng):
x0, P0, c, d, T, Z, R, H, Q = mod.unpack_statespace()
pm.Deterministic("Z", Z)

prior = pm.sample_prior_predictive(samples=10)
prior = pm.sample_prior_predictive(draws=10)

prior_Z = prior.prior.Z.values
assert prior_Z.shape == (1, 10, 100, 1, 2)
Expand Down Expand Up @@ -790,7 +790,7 @@ def test_extract_components_from_idata(rng):
mod.build_statespace_graph(y)

x0, P0, c, d, T, Z, R, H, Q = mod.unpack_statespace()
prior = pm.sample_prior_predictive(samples=10)
prior = pm.sample_prior_predictive(draws=10)

filter_prior = mod.sample_conditional_prior(prior)
comp_prior = mod.extract_components_from_idata(filter_prior)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pymc>=5.13.0
pymc>=5.16.1
scikit-learn

0 comments on commit 94790c7

Please sign in to comment.