Skip to content

Commit

Permalink
more flexible shapes and protect against div0
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Jan 7, 2024
1 parent 36781a6 commit 5fec663
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions rex/bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ def _clean_params(self, params, arr_shape):

msg = f'params must be 2D array but received {type(params)}'
assert isinstance(params, np.ndarray), msg

if len(params.shape) == 1:
params = np.expand_dims(params, 0)

msg = (f'params must be 2D array of shape ({arr_shape[1]}, N) '
f'but received shape {params.shape}')
assert len(params.shape) == 2, msg
Expand Down Expand Up @@ -356,18 +360,23 @@ def __call__(self, arr):
Bias corrected copy of the input array with same shape.
"""

if len(arr.shape) == 1:
arr = np.expand_dims(arr, 1)

params_oh = self._clean_params(self.params_oh, arr.shape)
params_mh = self._clean_params(self.params_mh, arr.shape)
params_mf = self._clean_params(self.params_mf, arr.shape)

p_mf = self.cdf(arr, params_mf)
x_oh = self.ppf(p_mf, params_oh)
q_mf = self.cdf(arr, params_mf) # Tau_m_p
x_oh = self.ppf(q_mf, params_oh) # x^_o:m_h:p
x_mh_mf = self.ppf(q_mf, params_mh) # F-1_mh[Tau_m_p]

if self.relative:
delta = arr / self.ppf(p_mf, params_mh)
x_mh_mf[x_mh_mf == 0] = 0.001 # arbitrary limit to prevent div 0
delta = arr / x_mh_mf
arr_bc = x_oh * delta
else:
delta = arr - self.ppf(p_mf, params_mh)
delta = arr - x_mh_mf
arr_bc = x_oh + delta

msg = ('Input shape {} does not match QDM bias corrected output '
Expand Down Expand Up @@ -554,6 +563,16 @@ def qdm_ws(ws, params_oh, params_mh, params_mf=None, dist='empirical',
qdm = _QuantileDeltaMapping(params_oh, params_mh, params_mf, dist=dist,
relative=relative, sampling=sampling,
log_base=log_base)

# This will prevent inverse CDF functions from returning zero resulting in
# a divide by zero error in the calculation of the QDM delta. These zeros
# get fixed later
ws_zeros = ws == 0
ws[ws_zeros] = 0.01

ws = qdm(ws)

ws = np.maximum(ws, 0)
ws[ws_zeros] = 0

return ws

0 comments on commit 5fec663

Please sign in to comment.