From 50aa30a4cdce6cd64465067d615db9b426890990 Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Fri, 10 Nov 2023 11:59:00 -0300 Subject: [PATCH] Implement truncated responses --- .pylintrc | 9 +++--- bambi/backend/terms.py | 31 ++++++++++++++++++++- bambi/terms/response.py | 3 +- bambi/terms/utils.py | 10 +++++++ bambi/transformations.py | 60 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 107 insertions(+), 6 deletions(-) diff --git a/.pylintrc b/.pylintrc index bf013ead7..c5bde5e22 100644 --- a/.pylintrc +++ b/.pylintrc @@ -261,7 +261,9 @@ good-names=b, nu, X_terms, Z_terms, - u + u, + ub, + lb # Include a hint for the correct naming format with invalid-name @@ -496,6 +498,5 @@ min-public-methods=2 [EXCEPTIONS] -# Exceptions that will emit a warning when being caught. Defaults to -# "Exception" -overgeneral-exceptions=Exception +# Exceptions that will emit a warning when being caught. +overgeneral-exceptions=builtins.Exception diff --git a/bambi/backend/terms.py b/bambi/backend/terms.py index fec60cae2..730a9af5f 100644 --- a/bambi/backend/terms.py +++ b/bambi/backend/terms.py @@ -276,6 +276,7 @@ def build_response_distribution(self, kwargs, pymc_backend): kwargs = self.robustify_dims(pymc_backend, kwargs) + # Handle censoring if self.term.is_censored: dims = kwargs.pop("dims", None) data_matrix = kwargs.pop("observed") @@ -295,6 +296,34 @@ def build_response_distribution(self, kwargs, pymc_backend): dist_rv = pm.Censored( self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims ) + + # Handle truncation + elif self.term.is_truncated: + dims = kwargs.pop("dims", None) + data_matrix = kwargs.pop("observed") + + # Get values of the response variable + observed = np.squeeze(data_matrix[:, 0]) + + # Get truncation values + lower = np.squeeze(data_matrix[:, 1]) + upper = np.squeeze(data_matrix[:, 2]) + + # Handle 'None' and scalars appropriately + if np.all(lower == -np.inf): + lower = None + elif np.all(lower == lower[0]): + lower 
def is_truncated_response(term):
    """Tell whether a formulae term represents a truncated response

    The answer is positive only when the term is made of a single component, that component
    is a call, and the call is of kind "truncated". Any other term gives ``False``.
    """
    if is_single_component(term):
        # A single-component term has exactly one entry in 'components'
        sole_component = term.components[0]
        if is_call_component(sole_component):
            return is_call_of_kind(sole_component, "truncated")
    return False
def _truncation_bound(bound, size, name, default):
    """Expand a truncation bound into a 1d array with one value per observation

    Parameters
    ----------
    bound : None, int, float, np.ndarray
        The bound as passed by the user. ``None`` means the variable is unbounded on this side.
    size : int
        The number of observations in the truncated variable.
    name : str
        The name of the argument being processed. Only used in error messages.
    default : float
        The value used to fill the array when ``bound`` is ``None``.

    Returns
    -------
    np.ndarray
        A 1d array of length ``size``.

    Raises
    ------
    ValueError
        If ``bound`` is a 1d array whose length differs from ``size`` or if it has more than
        one dimension.
    """
    if bound is None:
        return np.full(size, default)

    bound = np.asarray(bound)

    # A scalar bound is shared by all the observations
    if bound.ndim == 0:
        return np.full(size, bound)

    if bound.ndim == 1:
        # Raise explicitly instead of using 'assert', which is removed when running with -O
        if len(bound) != size:
            raise ValueError(
                f"The length of '{name}' ({len(bound)}) must be equal to the length of "
                f"'x' ({size})"
            )
        return bound

    raise ValueError(f"'{name}' must be a scalar or a 1-dimensional array")


def truncated(x, lb=None, ub=None):
    """Construct array for truncated response

    Parameters
    ----------
    x : np.ndarray
        The values of the truncated variable. It must be 1-dimensional.
    lb : None, int, float, np.ndarray
        A number or a 1d array indicating the lower truncation bound. When it is an array, it
        must have the same length as ``x``. ``None`` means there is no lower bound and it is
        represented with ``-np.inf``.
    ub : None, int, float, np.ndarray
        A number or a 1d array indicating the upper truncation bound. When it is an array, it
        must have the same length as ``x``. ``None`` means there is no upper bound and it is
        represented with ``np.inf``.

    Returns
    -------
    np.ndarray
        Array of shape (n, 3). The first column contains the values of the variable,
        the second column the values for the lower bound, and the third column
        the values for the upper bound.

    Raises
    ------
    ValueError
        If ``x`` is not 1-dimensional, if both ``lb`` and ``ub`` are ``None``, or if a bound is
        neither a scalar nor a 1d array with the same length as ``x``.
    """
    x = np.asarray(x)

    if x.ndim != 1:
        raise ValueError("'truncated' only works with 1-dimensional arrays")

    if lb is None and ub is None:
        raise ValueError("'lb' and 'ub' cannot both be None")

    # Both bounds end up as 1d arrays of the same length as 'x'
    size = len(x)
    lower = _truncation_bound(lb, size, "lb", -np.inf)
    upper = _truncation_bound(ub, size, "ub", np.inf)

    # Construct output matrix
    return np.column_stack([x, lower, upper])


# Tag identifying the kind of call; presumably what the "truncated" kind check on response
# terms relies on -- confirm against 'is_call_of_kind'.
truncated.__metadata__ = {"kind": "truncated"}