Skip to content

Commit

Permalink
Implement truncated responses
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Nov 10, 2023
1 parent ee266d9 commit 50aa30a
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 6 deletions.
9 changes: 5 additions & 4 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ good-names=b,
nu,
X_terms,
Z_terms,
u
u,
ub,
lb


# Include a hint for the correct naming format with invalid-name
Expand Down Expand Up @@ -496,6 +498,5 @@ min-public-methods=2

[EXCEPTIONS]

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception
# Exceptions that will emit a warning when being caught.
overgeneral-exceptions=builtins.Exception
31 changes: 30 additions & 1 deletion bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def build_response_distribution(self, kwargs, pymc_backend):

kwargs = self.robustify_dims(pymc_backend, kwargs)

# Handle censoring
if self.term.is_censored:
dims = kwargs.pop("dims", None)
data_matrix = kwargs.pop("observed")
Expand All @@ -295,6 +296,34 @@ def build_response_distribution(self, kwargs, pymc_backend):
dist_rv = pm.Censored(
self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims
)

# Handle truncation
elif self.term.is_truncated:
dims = kwargs.pop("dims", None)
data_matrix = kwargs.pop("observed")

# Get values of the response variable
observed = np.squeeze(data_matrix[:, 0])

# Get truncation values
lower = np.squeeze(data_matrix[:, 1])
upper = np.squeeze(data_matrix[:, 2])

# Handle 'None' and scalars appropriately
if np.all(lower == -np.inf):
lower = None
elif np.all(lower == lower[0]):
lower = lower[0]

if np.all(upper == np.inf):
upper = None
elif np.all(upper == upper[0]):
upper = upper[0]

stateless_dist = distribution.dist(**kwargs)
dist_rv = pm.Truncated(
self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims
)
else:
dist_rv = distribution(self.name, **kwargs)

Expand All @@ -316,7 +345,7 @@ def robustify_dims(self, pymc_backend, kwargs):
if isinstance(self.family, (Multinomial, DirichletMultinomial)):
return kwargs

if self.term.is_censored:
if self.term.is_censored or self.term.is_truncated:
return kwargs

dims, data = kwargs["dims"], kwargs["observed"]
Expand Down
3 changes: 2 additions & 1 deletion bambi/terms/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from bambi.terms.base import BaseTerm

from bambi.terms.utils import is_censored_response
from bambi.terms.utils import is_censored_response, is_truncated_response


class ResponseTerm(BaseTerm):
def __init__(self, response, family):
self.term = response.term.term
self.family = family
self.is_censored = is_censored_response(self.term)
self.is_truncated = is_truncated_response(self.term)

@property
def term(self):
Expand Down
10 changes: 10 additions & 0 deletions bambi/terms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,13 @@ def is_censored_response(term):
if not is_call_component(component):
return False
return is_call_of_kind(component, "censored")


def is_truncated_response(term):
"""Determines if a formulae term represents a truncated response"""
if not is_single_component(term):
return False
component = term.components[0] # get the first (and single) component
if not is_call_component(component):
return False
return is_call_of_kind(component, "truncated")
60 changes: 60 additions & 0 deletions bambi/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,65 @@ def censored(*args):
censored.__metadata__ = {"kind": "censored"}


def truncated(x, lb=None, ub=None):
"""Construct array for truncated response
Parameters
----------
x : np.ndarray
The values of the truncated variable.
lb : int, float, np.ndarray
A number or an array indicating the lower truncation bound.
ub : int, float, np.ndarray
A number or an array indicating the upper truncation bound.
Returns
-------
np.ndarray
Array of shape (n, 3). The first column contains the values of the variable,
the second column the values for the lower bound, and the third variable
the values for the upper bound.
"""
x = np.asarray(x)

if x.ndim != 1:
raise ValueError("'truncated' only works with 1-dimensional arrays")

if lb is None and ub is None:
raise ValueError("'lb' and 'ub' cannot both be None")

# Process lower bound so we get an 1d array with the adequate values
if lb is not None:
lower = np.asarray(lb)
if lower.ndim == 0:
lower = np.full(len(x), lower)
elif lower.ndim == 1:
assert len(lower) == len(x)
else:
raise ValueError("ups")
else:
lower = np.full(len(x), -np.inf)

# Process upper bound so we get an 1d array with the adequate values
if ub is not None:
upper = np.asarray(ub)
if upper.ndim == 0:
upper = np.full(len(x), upper)
elif upper.ndim == 1:
assert len(upper) == len(x)
else:
raise ValueError("ups 2")
else:
upper = np.full(len(x), np.inf)

# Construct output matrix
result = np.column_stack([x, lower, upper])

return result


truncated.__metadata__ = {"kind": "truncated"}

# pylint: disable = invalid-name
@register_stateful_transform
class HSGP: # pylint: disable = too-many-instance-attributes
Expand Down Expand Up @@ -316,6 +375,7 @@ def get_distance(x):
transformations_namespace = {
"c": c,
"censored": censored,
"truncated": truncated,
"log": np.log,
"log2": np.log2,
"log10": np.log10,
Expand Down

0 comments on commit 50aa30a

Please sign in to comment.