diff --git a/CHANGES.rst b/CHANGES.rst index e79b6373..e1b2c8f6 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,7 +4,7 @@ Changelog v0.9.0 (unreleased) ------------------- -Contributors to this version: Trevor James Smith (:user:`Zeitsperre`), Pascal Bourgault (:user:`aulemahal`), Gabriel Rondeau-Genesse (:user:`RondeauG`). +Contributors to this version: Trevor James Smith (:user:`Zeitsperre`), Pascal Bourgault (:user:`aulemahal`), Gabriel Rondeau-Genesse (:user:`RondeauG`), Juliette Lavoie (:user: `juliettelavoie`). Internal changes ^^^^^^^^^^^^^^^^ @@ -12,11 +12,16 @@ Internal changes * Addresses a handful of misconfigurations in the GitHub Workflows. * Added a few free `grep`-based hooks for finding unwanted artifacts in the code base. * Updated `ruff` to v0.2.0 and `black` to v24.2.0. +* Added tests for biasadjust. (:pull:`366`). Bug fixes ^^^^^^^^^ * Fix ``unstack_dates`` for the new frequency syntax introduced by pandas v2.2. (:pull:`359`). * ``subset_warming_level`` will not return partial subsets if the warming level is reached at the end of the timeseries. (:issue:`360`, :pull:`359`). +* Loading of training in `adjust` is now done outside of the periods loop. (:pull:`366`). +* Fixed bug for adding the preprocessing attributes inside the `adjust` function. (:pull:`366`). +* Fixed a bug to accept `group = False` in `adjust` function. (:pull:`366`). + v0.8.3 (2024-02-28) ------------------- diff --git a/tests/test_biasadjust.py b/tests/test_biasadjust.py new file mode 100644 index 00000000..5a51d6c0 --- /dev/null +++ b/tests/test_biasadjust.py @@ -0,0 +1,319 @@ +import numpy as np +import pytest +import xarray as xr +import xclim as xc +from conftest import notebooks +from xclim.testing.helpers import test_timeseries as timeseries + +import xscen as xs + +xc.set_options( + sdba_encode_cf=False +) # FIXME: A temporary bug fix waiting for xclim 0.49 + + +class TestTrain: + dref = timeseries( + np.ones(365 * 3), variable="tas", start="2001-01-01", freq="D", as_dataset=True + ) + + dhist = timeseries( + np.concatenate([np.ones(365 * 2) * 2, np.ones(365) * 3]), + variable="tas", + start="2001-01-01", + freq="D", + as_dataset=True, + ) + dhist.attrs["cat:xrfreq"] = "D" + dhist.attrs["cat:domain"] = "one_point" + dhist.attrs["cat:id"] = "fake_id" + + @pytest.mark.parametrize( + "var, period", + [("tas", ["2001", "2002"]), (["tas"], ["2003", "2003"])], + ) + def test_basic_train(self, var, period): + out = xs.train(self.dref, self.dhist, var=var, period=period) + + assert out.attrs["cat:xrfreq"] == "D" + assert out.attrs["cat:domain"] == "one_point" + assert out.attrs["cat:id"] == "fake_id" + assert out.attrs["cat:processing_level"] == "training_tas" + + assert "dayofyear" in out + assert "quantiles" in out + result = [-1] * 365 if period == ["2001", "2002"] else [-2] * 365 + np.testing.assert_array_equal(out["scaling"], result) + + def test_preprocess(self): + + dref360 = xc.core.calendar.convert_calendar( + self.dref, "360_day", align_on="year" + ) + + out = xs.train( + dref360, + self.dhist, + var="tas", + period=["2001", "2002"], + adapt_freq={"thresh": "2 K"}, + jitter_over={"upper_bnd": "3 K", "thresh": "2 K"}, + jitter_under={"thresh": "2 K"}, + ) + + assert out.attrs["train_params"] == { + "maximal_calendar": "noleap", + "adapt_freq": {"thresh": "2 K"}, + "jitter_over": {"upper_bnd": "3 K", "thresh": "2 K"}, + "jitter_under": {"thresh": "2 K"}, + "var": ["tas"], + } + + assert "pth" in out + assert "dP0" in out + assert "dayofyear" in out + assert "quantiles" in out + + def test_group(self): + out = xs.train( + self.dref, + self.dhist, + var="tas", + period=["2001", "2002"], + group={"group": "time.month", "window": 1}, + ) + + out1 = xs.train( + self.dref, + self.dhist, + var="tas", + period=["2001", "2002"], + group="time.month", + ) + + assert "month" in out + assert "quantiles" in out + assert out1.equals(out) + + def test_errors(self): + with pytest.raises(ValueError): + xs.train(self.dref, self.dhist, var=["tas", "pr"], period=["2001", "2002"]) + + +class TestAdjust: + dref = timeseries( + np.ones((365 * 3) + 1), # leap year + variable="tas", + start="2001-01-01", + freq="D", + as_dataset=True, + ) + + dsim = timeseries( + np.concatenate([np.ones(365 * 3) * 2, np.ones((365 * 3) + 1) * 4]), + variable="tas", + start="2001-01-01", + freq="D", + as_dataset=True, + ) + dsim.attrs["cat:xrfreq"] = "D" + dsim.attrs["cat:domain"] = "one_point" + dsim.attrs["cat:id"] = "fake_id" + + @pytest.mark.parametrize( + "periods, to_level, bias_adjust_institution, bias_adjust_project", + [ + (["2001", "2006"], None, None, None), + ([["2001", "2001"], ["2006", "2006"]], "test", "i", "p"), + ], + ) + def test_basic( + self, periods, to_level, bias_adjust_institution, bias_adjust_project + ): + dtrain = xs.train( + self.dref, + self.dsim.sel(time=slice("2001", "2003")), + var="tas", + period=["2001", "2003"], + ) + + out = xs.adjust( + dtrain, + self.dsim, + periods=periods, + to_level=to_level, + bias_adjust_institution=bias_adjust_institution, + bias_adjust_project=bias_adjust_project, + ) + assert out.attrs["cat:processing_level"] == to_level or "biasadjusted" + assert out.attrs["cat:variable"] == ("tas",) + assert out.attrs["cat:id"] == "fake_id" + assert ( + out["tas"].attrs["bias_adjustment"] + == "DetrendedQuantileMapping(group=Grouper(" + "name='time.dayofyear', window=31), kind='+'" + ").adjust(sim, )" + ) + assert xc.core.calendar.get_calendar(out) == "noleap" + + if bias_adjust_institution is not None: + assert out.attrs["cat:bias_adjust_institution"] == "i" + if bias_adjust_project is not None: + assert out.attrs["cat:bias_adjust_project"] == "p" + + assert out.time.dt.year.values[0] == 2001 + assert out.time.dt.year.values[-1] == 2006 + + if periods == ["2001", "2006"]: + np.testing.assert_array_equal( + out["tas"].values, + np.concatenate( + [np.ones(365 * 3) * 1, np.ones(365 * 3) * 3] + ), # -1 for leap year + ) + else: # periods==[['2001','2001'], ['2006','2006']] + np.testing.assert_array_equal( + out["tas"].values, + np.concatenate([np.ones(365 * 1) * 1, np.ones(365 * 1) * 3]), + ) + + def test_write_train(self): + dtrain = xs.train( + self.dref, + self.dsim.sel(time=slice("2001", "2003")), + var="tas", + period=["2001", "2003"], + adapt_freq={"thresh": "2 K"}, + jitter_over={"upper_bnd": "3 K", "thresh": "2 K"}, + jitter_under={"thresh": "2 K"}, + ) + + root = str(notebooks / "_data") + xs.save_to_zarr(dtrain, f"{root}/test.zarr", mode="o") + dtrain2 = xr.open_dataset( + f"{root}/test.zarr", chunks={"dayofyear": 365, "quantiles": 15} + ) + + out = xs.adjust( + dtrain, + self.dsim, + periods=["2001", "2006"], + xclim_adjust_args={ + "detrend": { + "LoessDetrend": {"f": 0.2, "niter": 1, "d": 0, "weights": "tricube"} + } + }, + ) + + out2 = xs.adjust( + dtrain2, + self.dsim, + periods=["2001", "2006"], + xclim_adjust_args={ + "detrend": { + "LoessDetrend": {"f": 0.2, "niter": 1, "d": 0, "weights": "tricube"} + } + }, + ) + + assert ( + out.tas.attrs["bias_adjustment"] + == "DetrendedQuantileMapping(group=Grouper(name='time.dayofyear'," + " window=31), kind='+').adjust(sim, detrend=)," + " ref and hist were prepared with jitter_under_thresh(ref, hist," + " {'thresh': '2 K'}) and jitter_over_thresh(ref, hist, {'upper_bnd':" + " '3 K', 'thresh': '2 K'}) and adapt_freq(ref, hist, {'thresh': '2 K'})" + ) + + assert ( + out2.tas.attrs["bias_adjustment"] + == "DetrendedQuantileMapping(group=Grouper(name='time.dayofyear'," + " window=31), kind='+').adjust(sim, detrend=), ref and" + " hist were prepared with jitter_under_thresh(ref, hist, {'thresh':" + " '2 K'}) and jitter_over_thresh(ref, hist, {'upper_bnd': '3 K'," + " 'thresh': '2 K'}) and adapt_freq(ref, hist, {'thresh': '2 K'})" + ) + + assert out.equals(out2) + + def test_xclim_vs_xscen( + self, + ): # should give the same results using xscen and xclim + dref = ( + timeseries( + np.random.randint(0, high=10, size=(365 * 3) + 1), + variable="pr", + start="2001-01-01", + freq="D", + as_dataset=True, + ) + .astype("float32") + .chunk({"time": -1}) + ) + + dsim = ( + timeseries( + np.random.randint(0, high=10, size=365 * 6 + 1), + variable="pr", + start="2001-01-01", + freq="D", + as_dataset=True, + ) + .astype("float32") + .chunk({"time": -1}) + ) + dhist = dsim.sel(time=slice("2001", "2003")).chunk({"time": -1}) + + # xscen version + dtrain_xscen = xs.train( + dref, + dhist, + var="pr", + period=["2001", "2003"], + adapt_freq={"thresh": "1 mm d-1"}, + xclim_train_args={"kind": "*", "nquantiles": 50}, + ) + + out_xscen = xs.adjust( + dtrain_xscen, + dsim, + periods=["2001", "2006"], + xclim_adjust_args={ + "detrend": { + "LoessDetrend": {"f": 0.2, "niter": 1, "d": 0, "weights": "tricube"} + }, + "interp": "nearest", + "extrapolation": "constant", + }, + ) + + # xclim version + with xc.set_options(sdba_extra_output=True): + group = xc.sdba.Grouper(group="time.dayofyear", window=31) + + drefx = xc.core.calendar.convert_calendar( + dref.sel(time=slice("2001", "2003")), "noleap" + ) + dhistx = xc.core.calendar.convert_calendar( + dhist.sel(time=slice("2001", "2003")), "noleap" + ) + dsimx = xc.core.calendar.convert_calendar( + dsim.sel(time=slice("2001", "2006")), "noleap" + ) + + dhist_ad, pth, dP0 = xc.sdba.processing.adapt_freq( + drefx["pr"], dhistx["pr"], group=group, thresh="1 mm d-1" + ) + + QM = xc.sdba.DetrendedQuantileMapping.train( + drefx["pr"], dhist_ad, group=group, kind="*", nquantiles=50 + ) + + detrend = xc.sdba.detrending.LoessDetrend( + f=0.2, niter=1, d=0, weights="tricube", group=group, kind="*" + ) + out_xclim = QM.adjust( + dsimx["pr"], detrend=detrend, interp="nearest", extrapolation="constant" + ).rename({"scen": "pr"}) + + assert out_xscen.equals(out_xclim) diff --git a/xscen/biasadjust.py b/xscen/biasadjust.py index ab083d47..63306489 100644 --- a/xscen/biasadjust.py +++ b/xscen/biasadjust.py @@ -8,7 +8,6 @@ import xclim as xc from xclim import sdba from xclim.core.calendar import convert_calendar, get_calendar -from xclim.sdba import construct_moving_yearly_window, unpack_moving_yearly_window from .catutils import parse_from_ds from .config import parse_config @@ -52,6 +51,7 @@ def _add_preprocessing_attr(scen, train_kwargs): scen.attrs[ "bias_adjustment" ] += ", ref and hist were prepared with " + " and ".join(preproc) + return scen @parse_config @@ -123,7 +123,10 @@ def train( ref = dref[var[0]] hist = dhist[var[0]] - group = group or {"group": "time.dayofyear", "window": 31} + # we want to put default if group is None, but not if group is False + if group is None: + group = {"group": "time.dayofyear", "window": 31} + xclim_train_args = xclim_train_args or {} if method == "DetrendedQuantileMapping": xclim_train_args.setdefault("nquantiles", 15) @@ -197,7 +200,6 @@ def adjust( to_level: str = "biasadjusted", bias_adjust_institution: Optional[str] = None, bias_adjust_project: Optional[str] = None, - moving_yearly_window: Optional[dict] = None, align_on: Optional[str] = "year", ) -> xr.Dataset: """ @@ -220,12 +222,6 @@ def adjust( The institution to assign to the output. bias_adjust_project : str, optional The project to assign to the output. - moving_yearly_window: dict, optional - Arguments to pass to `xclim.sdba.construct_moving_yearly_window`. - If not None, `construct_moving_yearly_window` will be called on dsim (and scen in xclim_adjust_args if it exists) - before adjusting and `unpack_moving_yearly_window` will be called on the output after the adjustment. - `construct_moving_yearly_window` stacks windows of the dataArray in a new 'movingwin' dimension. - `unpack_moving_yearly_window` unpacks it to a normal time series. align_on: str, optional `align_on` argument for the fonction `xclim.core.calendar.convert_calendar`. @@ -239,18 +235,9 @@ def adjust( xclim.sdba.adjustment.DetrendedQuantileMapping, xclim.sdba.adjustment.ExtremeValues """ - # TODO: To be adequately fixed later - xclim_adjust_args = deepcopy(xclim_adjust_args) xclim_adjust_args = xclim_adjust_args or {} - if moving_yearly_window: - dsim = construct_moving_yearly_window(dsim, **moving_yearly_window) - if "scen" in xclim_adjust_args: - xclim_adjust_args["scen"] = construct_moving_yearly_window( - xclim_adjust_args["scen"], **moving_yearly_window - ) - # evaluate the dict that was stored as a string if not isinstance(dtrain.attrs["train_params"], dict): dtrain.attrs["train_params"] = eval(dtrain.attrs["train_params"]) @@ -270,33 +257,31 @@ def adjust( if simcal != mincal: sim = convert_calendar(sim, mincal, align_on=align_on) + # adjust + ADJ = sdba.adjustment.TrainAdjust.from_dataset(dtrain) + + if ("detrend" in xclim_adjust_args) and ( + isinstance(xclim_adjust_args["detrend"], dict) + ): + name, kwargs = list(xclim_adjust_args["detrend"].items())[0] + kwargs = kwargs or {} + kwargs.setdefault("group", ADJ.group) + kwargs.setdefault("kind", ADJ.kind) + xclim_adjust_args["detrend"] = getattr(sdba.detrending, name)(**kwargs) + # do the adjustment for all the simulation_period lists periods = standardize_periods(periods) slices = [] for period in periods: sim_sel = sim.sel(time=slice(period[0], period[1])) - # adjust - ADJ = sdba.adjustment.TrainAdjust.from_dataset(dtrain) - - if ("detrend" in xclim_adjust_args) and ( - isinstance(xclim_adjust_args["detrend"], dict) - ): - name, kwargs = list(xclim_adjust_args["detrend"].items())[0] - kwargs = kwargs or {} - kwargs.setdefault("group", ADJ.group) - kwargs.setdefault("kind", ADJ.kind) - xclim_adjust_args["detrend"] = getattr(sdba.detrending, name)(**kwargs) - out = ADJ.adjust(sim_sel, **xclim_adjust_args) slices.extend([out]) # put all the adjusted period back together dscen = xr.concat(slices, dim="time") - _add_preprocessing_attr(dscen, dtrain.attrs["train_params"]) + dscen = _add_preprocessing_attr(dscen, dtrain.attrs["train_params"]) dscen = xr.Dataset(data_vars={var: dscen}, attrs=dsim.attrs) - # TODO: History, attrs, etc. (TODO kept from previous version of `biasadjust`) - # TODO: Check for variables to add (grid_mapping, etc.) (TODO kept from previous version of `biasadjust`) dscen.attrs["cat:processing_level"] = to_level dscen.attrs["cat:variable"] = parse_from_ds(dscen, ["variable"])["variable"] if bias_adjust_institution is not None: @@ -304,7 +289,4 @@ def adjust( if bias_adjust_project is not None: dscen.attrs["cat:bias_adjust_project"] = bias_adjust_project - if moving_yearly_window: - dscen = unpack_moving_yearly_window(dscen) - return dscen