Skip to content

Commit

Permalink
Implement truncated responses (#752)
Browse files Browse the repository at this point in the history
* Implement truncated responses

* Remove event type from publish-docs.yml
  • Loading branch information
tomicapretto authored Nov 11, 2023
1 parent ee266d9 commit c6e5dbb
Show file tree
Hide file tree
Showing 12 changed files with 337 additions and 52 deletions.
1 change: 0 additions & 1 deletion .github/workflows/publish-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ on:
- opened
- reopened
- synchronize
- closed

jobs:
build:
Expand Down
9 changes: 5 additions & 4 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ good-names=b,
nu,
X_terms,
Z_terms,
u
u,
ub,
lb


# Include a hint for the correct naming format with invalid-name
Expand Down Expand Up @@ -496,6 +498,5 @@ min-public-methods=2

[EXCEPTIONS]

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception
# Exceptions that will emit a warning when being caught.
overgeneral-exceptions=builtins.Exception
31 changes: 30 additions & 1 deletion bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def build_response_distribution(self, kwargs, pymc_backend):

kwargs = self.robustify_dims(pymc_backend, kwargs)

# Handle censoring
if self.term.is_censored:
dims = kwargs.pop("dims", None)
data_matrix = kwargs.pop("observed")
Expand All @@ -295,6 +296,34 @@ def build_response_distribution(self, kwargs, pymc_backend):
dist_rv = pm.Censored(
self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims
)

# Handle truncation
elif self.term.is_truncated:
dims = kwargs.pop("dims", None)
data_matrix = kwargs.pop("observed")

# Get values of the response variable
observed = np.squeeze(data_matrix[:, 0])

# Get truncation values
lower = np.squeeze(data_matrix[:, 1])
upper = np.squeeze(data_matrix[:, 2])

# Handle 'None' and scalars appropriately
if np.all(lower == -np.inf):
lower = None
elif np.all(lower == lower[0]):
lower = lower[0]

if np.all(upper == np.inf):
upper = None
elif np.all(upper == upper[0]):
upper = upper[0]

stateless_dist = distribution.dist(**kwargs)
dist_rv = pm.Truncated(
self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims
)
else:
dist_rv = distribution(self.name, **kwargs)

Expand All @@ -316,7 +345,7 @@ def robustify_dims(self, pymc_backend, kwargs):
if isinstance(self.family, (Multinomial, DirichletMultinomial)):
return kwargs

if self.term.is_censored:
if self.term.is_censored or self.term.is_truncated:
return kwargs

dims, data = kwargs["dims"], kwargs["observed"]
Expand Down
3 changes: 2 additions & 1 deletion bambi/terms/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from bambi.terms.base import BaseTerm

from bambi.terms.utils import is_censored_response
from bambi.terms.utils import is_censored_response, is_truncated_response


class ResponseTerm(BaseTerm):
def __init__(self, response, family):
self.term = response.term.term
self.family = family
self.is_censored = is_censored_response(self.term)
self.is_truncated = is_truncated_response(self.term)

@property
def term(self):
Expand Down
10 changes: 10 additions & 0 deletions bambi/terms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,13 @@ def is_censored_response(term):
if not is_call_component(component):
return False
return is_call_of_kind(component, "censored")


def is_truncated_response(term):
    """Determines if a formulae term represents a truncated response

    The term qualifies only when it has a single component, that component is a call
    component, and ``is_call_of_kind`` reports the call to be of kind ``"truncated"``.
    Any other term yields ``False``.
    """
    if is_single_component(term):
        # A single-component term: its only component is the first one
        only_component = term.components[0]
        if is_call_component(only_component):
            return is_call_of_kind(only_component, "truncated")
    return False
60 changes: 60 additions & 0 deletions bambi/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,65 @@ def censored(*args):
censored.__metadata__ = {"kind": "censored"}


def _expand_truncation_bound(bound, size, name, default):
    """Return a truncation bound as a 1d array of length ``size``.

    ``None`` becomes an array filled with ``default``, a scalar is repeated ``size``
    times, and a 1d array is returned as is after checking its length.
    """
    if bound is None:
        return np.full(size, default)

    values = np.asarray(bound)
    if values.ndim == 0:
        return np.full(size, values)
    if values.ndim == 1:
        # Raise explicitly instead of using 'assert', which is stripped under 'python -O'
        if len(values) != size:
            raise ValueError(f"The length of '{name}' must be equal to the one of 'x'")
        return values
    raise ValueError(f"'{name}' must be 0 or 1 dimensional.")


def truncated(x, lb=None, ub=None):
    """Construct array for truncated response

    Parameters
    ----------
    x : np.ndarray
        The values of the truncated variable. Must be 1-dimensional.
    lb : int, float, np.ndarray
        A number or an array indicating the lower truncation bound. When it is ``None``,
        the lower bound is filled with ``-np.inf``.
    ub : int, float, np.ndarray
        A number or an array indicating the upper truncation bound. When it is ``None``,
        the upper bound is filled with ``np.inf``.

    Returns
    -------
    np.ndarray
        Array of shape (n, 3). The first column contains the values of the variable,
        the second column the values for the lower bound, and the third column
        the values for the upper bound.

    Raises
    ------
    ValueError
        If ``x`` is not 1-dimensional, if both ``lb`` and ``ub`` are ``None``, if a bound
        has more than one dimension, or if a 1-dimensional bound does not have the same
        length as ``x``.
    """
    x = np.asarray(x)

    if x.ndim != 1:
        raise ValueError("'truncated' only works with 1-dimensional arrays")

    if lb is None and ub is None:
        raise ValueError("'lb' and 'ub' cannot both be None")

    # Process both bounds so we get 1d arrays with the adequate values
    size = len(x)
    lower = _expand_truncation_bound(lb, size, "lb", -np.inf)
    upper = _expand_truncation_bound(ub, size, "ub", np.inf)

    # Construct output matrix
    return np.column_stack([x, lower, upper])


truncated.__metadata__ = {"kind": "truncated"}

# pylint: disable = invalid-name
@register_stateful_transform
class HSGP: # pylint: disable = too-many-instance-attributes
Expand Down Expand Up @@ -316,6 +375,7 @@ def get_distance(x):
transformations_namespace = {
"c": c,
"censored": censored,
"truncated": truncated,
"log": np.log,
"log2": np.log2,
"log10": np.log10,
Expand Down
26 changes: 26 additions & 0 deletions tests/test_built_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,3 +1047,29 @@ def test_censored_response():
idata = model.fit(tune=100, draws=100, random_seed=121195)
model.predict(idata, kind="pps")
model.predict(idata, data=data, kind="pps")


def test_truncated_response():
    """Fit a model with a response truncated to [-5, 5] and draw posterior predictive samples."""
    rng = np.random.default_rng(12345)
    slope, intercept, sigma, N = 1, 0, 2, 200
    x = rng.uniform(-10, 10, N)
    y = rng.normal(loc=slope * x + intercept, scale=sigma)

    # Keep only the observations whose response falls within the truncation bounds
    bounds = [-5, 5]
    keep = (y >= bounds[0]) & (y <= bounds[1])
    xt = x[keep]
    yt = y[keep]

    df = pd.DataFrame({"x": xt, "y": yt})
    priors = {
        "Intercept": bmb.Prior("Normal", mu=0, sigma=1),
        "x": bmb.Prior("Normal", mu=0, sigma=1),
        "sigma": bmb.Prior("HalfNormal", sigma=1),
    }
    model = bmb.Model("truncated(y, -5, 5) ~ x", df, priors=priors)
    idata = model.fit(tune=100, draws=100, random_seed=1234)
    model.predict(idata, kind="pps")
2 changes: 1 addition & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ def test_config():
config.INTERPRET_VERBOSE = "anything"

with pytest.raises(KeyError, match="'DOESNT_EXIST' is not a valid configuration option"):
config.DOESNT_EXIST = "anything"
config.DOESNT_EXIST = "anything"
3 changes: 3 additions & 0 deletions tests/test_interpret_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ def mtcars():
idata = model.fit(tune=500, draws=500, random_seed=1234)
return model, idata


# Use caplog fixture to capture log messages generated by the interpret logger


def test_predictions_list(mtcars, caplog):
model, idata = mtcars
caplog.set_level("INFO", logger="__bambi_interpret__")
Expand Down Expand Up @@ -51,6 +53,7 @@ def test_predictions_list_unspecified(mtcars, caplog):
assert unspecified_msg in interpret_log_msgs
assert len(caplog.records) == 3


def test_predictions_dict_unspecified(mtcars, caplog):
model, idata = mtcars
caplog.set_level("INFO", logger="__bambi_interpret__")
Expand Down
9 changes: 7 additions & 2 deletions tests/test_model_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,6 @@ def test_data_is_copied():
assert all(adults.dtypes[:3] == "object")


@pytest.mark.skip(reason="Censored still not ported")
def test_response_is_censored():
df = pd.DataFrame(
{
Expand All @@ -435,7 +434,13 @@ def test_response_is_censored():
}
)
dm = bmb.Model("censored(x, status) ~ 1", df)
assert dm.response.is_censored
assert dm.response_component.response_term.is_censored is True


def test_response_is_truncated():
    """A response wrapped in ``truncated(...)`` is flagged as truncated on its response term."""
    data = pd.DataFrame({"x": [1, 2, 3, 4, 5]})
    model = bmb.Model("truncated(x, 5.5) ~ 1", data)
    response_term = model.response_component.response_term
    assert response_term.is_truncated is True


def test_custom_likelihood_function():
Expand Down
Loading

0 comments on commit c6e5dbb

Please sign in to comment.