From 0f49e4ad8ab2d9089c8e3fd4d5b2a5fe58807127 Mon Sep 17 00:00:00 2001
From: Bouwe Andela
Date: Mon, 9 Sep 2024 09:39:37 +0200
Subject: [PATCH 1/2] Merge input cubes only once when computing lazy
 multimodel statistics

---
 esmvalcore/preprocessor/_multimodel.py | 47 +++++++++++++-------------
 1 file changed, 24 insertions(+), 23 deletions(-)

diff --git a/esmvalcore/preprocessor/_multimodel.py b/esmvalcore/preprocessor/_multimodel.py
index d1e0d90e74..25c8e9df82 100644
--- a/esmvalcore/preprocessor/_multimodel.py
+++ b/esmvalcore/preprocessor/_multimodel.py
@@ -480,7 +480,8 @@ def _compute_eager(
             input_slices = cubes  # scalar cubes
         else:
             input_slices = [cube[chunk] for cube in cubes]
-        result_slice = _compute(input_slices, operator=operator, **kwargs)
+        combined_cube = _combine(input_slices)
+        result_slice = _compute(combined_cube, operator=operator, **kwargs)
         result_slices.append(result_slice)
 
     try:
@@ -498,10 +499,13 @@ def _compute_eager(
     return result_cube
 
 
-def _compute(cubes: list, *, operator: iris.analysis.Aggregator, **kwargs):
+def _compute(
+    cube: iris.cube.Cube,
+    *,
+    operator: iris.analysis.Aggregator,
+    **kwargs,
+):
     """Compute statistic."""
-    cube = _combine(cubes)
-
     with warnings.catch_warnings():
         warnings.filterwarnings(
             'ignore',
@@ -526,8 +530,6 @@ def _compute(cubes: list, *, operator: iris.analysis.Aggregator, **kwargs):
 
     # Remove concatenation dimension added by _combine
     result_cube.remove_coord(CONCAT_DIM)
-    for cube in cubes:
-        cube.remove_coord(CONCAT_DIM)
 
     # some iris aggregators modify dtype, see e.g.
     # https://numpy.org/doc/stable/reference/generated/numpy.ma.average.html
@@ -540,7 +542,7 @@ def _compute(cubes: list, *, operator: iris.analysis.Aggregator, **kwargs):
             method=cell_method.method,
             coords=cell_method.coord_names,
             intervals=cell_method.intervals,
-            comments=f'input_cubes: {len(cubes)}')
+        )
         result_cube.add_cell_method(updated_method)
 
     return result_cube
@@ -596,27 +598,26 @@ def _multicube_statistics(
     # Calculate statistics
     statistics_cubes = {}
     lazy_input = any(cube.has_lazy_data() for cube in cubes)
-    for stat in statistics:
-        (stat_id, result_cube) = _compute_statistic(cubes, lazy_input, stat)
+    combined_cube = None
+    for statistic in statistics:
+        stat_id = _get_stat_identifier(statistic)
+        logger.debug('Multicube statistics: computing: %s', stat_id)
+
+        (operator, kwargs) = _get_operator_and_kwargs(statistic)
+        (agg, agg_kwargs) = get_iris_aggregator(operator, **kwargs)
+        if lazy_input and agg.lazy_func is not None:
+            if combined_cube is None:
+                # Merge input cubes only once as this can be computationally
+                # expensive.
+                combined_cube = _combine(cubes)
+            result_cube = _compute(combined_cube, operator=agg, **agg_kwargs)
+        else:
+            result_cube = _compute_eager(cubes, operator=agg, **agg_kwargs)
         statistics_cubes[stat_id] = result_cube
 
     return statistics_cubes
 
 
-def _compute_statistic(cubes, lazy_input, statistic):
-    """Compute a single statistic."""
-    stat_id = _get_stat_identifier(statistic)
-    logger.debug('Multicube statistics: computing: %s', stat_id)
-
-    (operator, kwargs) = _get_operator_and_kwargs(statistic)
-    (agg, agg_kwargs) = get_iris_aggregator(operator, **kwargs)
-    if lazy_input and agg.lazy_func is not None:
-        result_cube = _compute(cubes, operator=agg, **agg_kwargs)
-    else:
-        result_cube = _compute_eager(cubes, operator=agg, **agg_kwargs)
-    return (stat_id, result_cube)
-
-
 def _multiproduct_statistics(
     products,
     statistics,

From f3d56aff2ffa915e1a89311817865ad3b94ba7dd Mon Sep 17 00:00:00 2001
From: Bouwe Andela
Date: Thu, 26 Sep 2024 21:43:05 +0200
Subject: [PATCH 2/2] Use ruff formatting

---
 esmvalcore/preprocessor/_multimodel.py | 131 +++++++++++++------------
 1 file changed, 68 insertions(+), 63 deletions(-)

diff --git a/esmvalcore/preprocessor/_multimodel.py b/esmvalcore/preprocessor/_multimodel.py
index 25c8e9df82..b790b45117 100644
--- a/esmvalcore/preprocessor/_multimodel.py
+++ b/esmvalcore/preprocessor/_multimodel.py
@@ -7,6 +7,7 @@
 generalized functions that operate on iris cubes. These wrappers support
 grouped execution by passing a groupby keyword.
 """
+
 from __future__ import annotations
 
 import logging
@@ -39,12 +40,12 @@
 
 logger = logging.getLogger(__name__)
 
-CONCAT_DIM = 'multi-model'
+CONCAT_DIM = "multi-model"
 
 
 def _get_consistent_time_unit(cubes):
     """Return cubes' time unit if consistent, standard calendar otherwise."""
-    t_units = [cube.coord('time').units for cube in cubes]
+    t_units = [cube.coord("time").units for cube in cubes]
     if len(set(t_units)) == 1:
         return t_units[0]
     return cf_units.Unit("days since 1850-01-01", calendar="standard")
@@ -68,7 +69,7 @@ def _unify_time_coordinates(cubes):
 
     for cube in cubes:
         # Extract date info from cube
-        coord = cube.coord('time')
+        coord = cube.coord("time")
         years = [p.year for p in coord.units.num2date(coord.points)]
         months = [p.month for p in coord.units.num2date(coord.points)]
         days = [p.day for p in coord.units.num2date(coord.points)]
@@ -93,36 +94,38 @@ def _unify_time_coordinates(cubes):
                 logger.warning(
                     "Multimodel encountered (sub)daily data and inconsistent "
                     "time units or calendars. Attempting to continue, but "
-                    "might produce unexpected results.")
+                    "might produce unexpected results."
+                )
         else:
             raise ValueError(
                 "Multimodel statistics preprocessor currently does not "
-                "support sub-daily data.")
+                "support sub-daily data."
+            )
 
         # Update the cubes' time coordinate (both point values and the units!)
-        cube.coord('time').points = date2num(dates, t_unit, coord.dtype)
-        cube.coord('time').units = t_unit
+        cube.coord("time").points = date2num(dates, t_unit, coord.dtype)
+        cube.coord("time").units = t_unit
 
         _guess_time_bounds(cube)
 
 
 def _guess_time_bounds(cube):
     """Guess time bounds if possible."""
-    cube.coord('time').bounds = None
-    if cube.coord('time').shape == (1,):
+    cube.coord("time").bounds = None
+    if cube.coord("time").shape == (1,):
         logger.debug(
             "Encountered scalar time coordinate in multi_model_statistics: "
             "cannot determine its bounds"
         )
     else:
-        cube.coord('time').guess_bounds()
+        cube.coord("time").guess_bounds()
 
 
 def _time_coords_are_aligned(cubes):
     """Return `True` if time coords are aligned."""
-    first_time_array = cubes[0].coord('time').points
+    first_time_array = cubes[0].coord("time").points
     for cube in cubes[1:]:
-        other_time_array = cube.coord('time').points
+        other_time_array = cube.coord("time").points
         if not np.array_equal(first_time_array, other_time_array):
             return False
 
@@ -135,20 +138,23 @@ def _map_to_new_time(cube, time_points):
     Missing data inside original bounds is filled with nearest neighbour
     Missing data outside original bounds is masked.
     """
-    time_coord = cube.coord('time')
+    time_coord = cube.coord("time")
 
     # Try if the required time points can be obtained by slicing the cube.
     time_slice = np.isin(time_coord.points, time_points)
-    if np.any(time_slice) and np.array_equal(time_coord.points[time_slice],
-                                             time_points):
-        time_idx, = cube.coord_dims('time')
-        indices = tuple(time_slice if i == time_idx else slice(None)
-                        for i in range(cube.ndim))
+    if np.any(time_slice) and np.array_equal(
+        time_coord.points[time_slice], time_points
+    ):
+        (time_idx,) = cube.coord_dims("time")
+        indices = tuple(
+            time_slice if i == time_idx else slice(None)
+            for i in range(cube.ndim)
+        )
         return cube[indices]
 
     time_points = time_coord.units.num2date(time_points)
-    sample_points = [('time', time_points)]
-    scheme = iris.analysis.Nearest(extrapolation_mode='mask')
+    sample_points = [("time", time_points)]
+    scheme = iris.analysis.Nearest(extrapolation_mode="mask")
 
     # Make sure that all integer time coordinates ('year', 'month',
     # 'day_of_year', etc.) are converted to floats, otherwise the
@@ -156,8 +162,9 @@ def _map_to_new_time(cube, time_points):
     # to integer". In addition, remove their bounds (this would be done by iris
     # anyway).
     int_time_coords = []
-    for coord in cube.coords(dimensions=cube.coord_dims('time'),
-                             dim_coords=False):
+    for coord in cube.coords(
+        dimensions=cube.coord_dims("time"), dim_coords=False
+    ):
         if np.issubdtype(coord.points.dtype, np.integer):
             int_time_coords.append(coord.name())
             coord.points = coord.points.astype(float)
@@ -168,7 +175,7 @@ def _map_to_new_time(cube, time_points):
         new_cube = cube.interpolate(sample_points, scheme)
     except Exception as excinfo:
         additional_info = ""
-        if cube.coords('time', dimensions=()):
+        if cube.coords("time", dimensions=()):
             additional_info = (
                 " Note: this alignment does not work for scalar time "
                 "coordinates. To ignore all scalar coordinates in the input "
@@ -182,9 +189,11 @@ def _map_to_new_time(cube, time_points):
 
     # Change the dtype of int_time_coords to their original values
     for coord_name in int_time_coords:
-        coord = new_cube.coord(coord_name,
-                               dimensions=new_cube.coord_dims('time'),
-                               dim_coords=False)
+        coord = new_cube.coord(
+            coord_name,
+            dimensions=new_cube.coord_dims("time"),
+            dim_coords=False,
+        )
         coord.points = coord.points.astype(int)
 
     return new_cube
@@ -197,15 +206,17 @@ def _align_time_coord(cubes, span):
     if _time_coords_are_aligned(cubes):
         return cubes
 
-    all_time_arrays = [cube.coord('time').points for cube in cubes]
+    all_time_arrays = [cube.coord("time").points for cube in cubes]
 
-    if span == 'overlap':
+    if span == "overlap":
         new_time_points = reduce(np.intersect1d, all_time_arrays)
-    elif span == 'full':
+    elif span == "full":
         new_time_points = reduce(np.union1d, all_time_arrays)
     else:
-        raise ValueError(f"Invalid argument for span: {span!r}"
-                         "Must be one of 'overlap', 'full'.")
+        raise ValueError(
+            f"Invalid argument for span: {span!r}"
+            "Must be one of 'overlap', 'full'."
+        )
 
     new_cubes = [_map_to_new_time(cube, new_time_points) for cube in cubes]
 
@@ -229,8 +240,8 @@ def _get_equal_coords_metadata(cubes):
     for coord in cubes[0].coords():
         for other_cube in cubes[1:]:
             other_cube_has_equal_coord = [
-                coord.metadata == other_coord.metadata for other_coord in
-                other_cube.coords(coord.name())
+                coord.metadata == other_coord.metadata
+                for other_coord in other_cube.coords(coord.name())
             ]
             if not any(other_cube_has_equal_coord):
                 break
@@ -261,7 +272,6 @@ def _get_equal_coord_names_metadata(cubes, equal_coords_metadata):
 
         # Check if coordinate names and units match across all cubes
         for other_cube in cubes[1:]:
-
             # Ignore names that do not exist in other cube/are not unique
             if len(other_cube.coords(coord_name)) != 1:
                 break
@@ -276,12 +286,8 @@ def _get_equal_coord_names_metadata(cubes, equal_coords_metadata):
             std_names = list(
                 {c.coord(coord_name).standard_name for c in cubes}
             )
-            long_names = list(
-                {c.coord(coord_name).long_name for c in cubes}
-            )
-            var_names = list(
-                {c.coord(coord_name).var_name for c in cubes}
-            )
+            long_names = list({c.coord(coord_name).long_name for c in cubes})
+            var_names = list({c.coord(coord_name).var_name for c in cubes})
             equal_names_metadata[coord_name] = dict(
                 standard_name=std_names[0] if len(std_names) == 1 else None,
                 long_name=long_names[0] if len(long_names) == 1 else None,
@@ -304,14 +310,12 @@ def _equalise_coordinate_metadata(cubes):
     # --> keep matching names of these coordinates
     # Note: ignores duplicate coordinates
     equal_names_metadata = _get_equal_coord_names_metadata(
-        cubes,
-        equal_coords_metadata
+        cubes, equal_coords_metadata
     )
 
     # Modify all coordinates of all cubes accordingly
     for cube in cubes:
         for coord in cube.coords():
-
             # Exactly matching coordinates --> do not modify
             if coord.metadata in equal_coords_metadata:
                 continue
@@ -325,9 +329,9 @@ def _equalise_coordinate_metadata(cubes):
             # Matching names and units --> set common names
             if coord.name() in equal_names_metadata:
                 equal_names = equal_names_metadata[coord.name()]
-                coord.standard_name = equal_names['standard_name']
-                coord.long_name = equal_names['long_name']
-                coord.var_name = equal_names['var_name']
+                coord.standard_name = equal_names["standard_name"]
+                coord.long_name = equal_names["long_name"]
+                coord.var_name = equal_names["var_name"]
                 continue
 
             # Remaining coordinates --> remove long_name
@@ -338,7 +342,7 @@ def _equalise_coordinate_metadata(cubes):
         # in the input cubes. Note: if `ignore_scalar_coords=True` is used for
         # `multi_model_statistics`, the cubes do not contain scalar coordinates
         # at this point anymore.
-        scalar_coords_to_always_remove = ['p0', 'ptop']
+        scalar_coords_to_always_remove = ["p0", "ptop"]
         for scalar_coord in cube.coords(dimensions=()):
             if scalar_coord.var_name in scalar_coords_to_always_remove:
                 cube.remove_coord(scalar_coord)
@@ -363,7 +367,7 @@ def _equalise_var_metadata(cubes):
     `standard_names`, `long_names`, and `var_names`.
 
     """
-    attrs = ['standard_name', 'long_name', 'var_name']
+    attrs = ["standard_name", "long_name", "var_name"]
     equal_names_metadata = {}
 
     # Collect all names from the different cubes, grouped by cube.name() and
@@ -424,7 +428,7 @@ def _combine(cubes):
     except MergeError as exc:
         # Note: str(exc) starts with "failed to merge into a single cube.\n"
         # --> remove this here for clear error message
-        msg = "\n".join(str(exc).split('\n')[1:])
+        msg = "\n".join(str(exc).split("\n")[1:])
         raise ValueError(
             f"Multi-model statistics failed to merge input cubes into a "
             f"single array:\n{cubes}\n{msg}"
@@ -492,7 +496,8 @@ def _compute_eager(
             f"single array. This happened for operator {operator} "
             f"with computed statistics {result_slices}. "
             f"This can happen e.g. if the calculation results in inconsistent "
-            f"dtypes") from excinfo
+            f"dtypes"
+        ) from excinfo
 
     result_cube.data = np.ma.array(result_cube.data)
 
@@ -508,22 +513,22 @@ def _compute(
     """Compute statistic."""
     with warnings.catch_warnings():
         warnings.filterwarnings(
-            'ignore',
+            "ignore",
             message=(
                 "Collapsing a non-contiguous coordinate. "
                 f"Metadata may not be fully descriptive for '{CONCAT_DIM}."
             ),
             category=UserWarning,
-            module='iris',
+            module="iris",
         )
         warnings.filterwarnings(
-            'ignore',
+            "ignore",
             message=(
                 f"Cannot check if coordinate is contiguous: Invalid "
                 f"operation for '{CONCAT_DIM}'"
             ),
             category=UserWarning,
-            module='iris',
+            module="iris",
         )
         # This will always return a masked array
         result_cube = cube.collapsed(CONCAT_DIM, operator, **kwargs)
@@ -583,7 +588,7 @@ def _multicube_statistics(
 
     # If all cubes contain a time coordinate, align them. If no cube contains a
     # time coordinate, do nothing. Else, raise an exception.
-    time_coords = [cube.coords('time') for cube in cubes]
+    time_coords = [cube.coords("time") for cube in cubes]
     if all(time_coords):
         cubes = _align_time_coord(cubes, span=span)
     elif not any(time_coords):
@@ -601,7 +606,7 @@ def _multicube_statistics(
     combined_cube = None
     for statistic in statistics:
         stat_id = _get_stat_identifier(statistic)
-        logger.debug('Multicube statistics: computing: %s', stat_id)
+        logger.debug("Multicube statistics: computing: %s", stat_id)
 
         (operator, kwargs) = _get_operator_and_kwargs(statistic)
         (agg, agg_kwargs) = get_iris_aggregator(operator, **kwargs)
@@ -659,12 +664,12 @@ def _get_operator_and_kwargs(statistic: str | dict) -> tuple[str, dict]:
     """Get operator and kwargs from a single statistic."""
     if isinstance(statistic, dict):
         statistic = dict(statistic)
-        if 'operator' not in statistic:
+        if "operator" not in statistic:
            raise ValueError(
                 f"`statistic` given as dictionary, but missing required key "
                 f"`operator`, got {statistic}"
             )
-        operator = statistic.pop('operator')
+        operator = statistic.pop("operator")
         kwargs = statistic
     else:
         operator = statistic
@@ -674,8 +679,8 @@ def _get_operator_and_kwargs(statistic: str | dict) -> tuple[str, dict]:
 
 def _get_stat_identifier(statistic: str | dict) -> str:
     (operator, kwargs) = _get_operator_and_kwargs(statistic)
-    if 'percent' in kwargs:
-        operator += str(kwargs['percent'])
+    if "percent" in kwargs:
+        operator += str(kwargs["percent"])
     return operator
 
 
@@ -800,7 +805,7 @@ def multi_model_statistics(
             span=span,
             ignore_scalar_coords=ignore_scalar_coords,
         )
-    if all(type(p).__name__ == 'PreprocessorFile' for p in products):
+    if all(type(p).__name__ == "PreprocessorFile" for p in products):
         # Avoid circular input: https://stackoverflow.com/q/16964467
         statistics_products = set()
         for group, input_prods in _group_products(products, by_key=groupby):
@@ -830,7 +835,7 @@ def ensemble_statistics(
     products: set[PreprocessorFile] | Iterable[Cube],
     statistics: list[str | dict],
     output_products,
-    span: str = 'overlap',
+    span: str = "overlap",
     ignore_scalar_coords: bool = False,
 ) -> dict | set:
     """Compute ensemble statistics.
@@ -877,7 +882,7 @@ def ensemble_statistics(
     :func:`esmvalcore.preprocessor.multi_model_statistics` for the full
     description of the core statistics function.
     """
-    ensemble_grouping = ('project', 'dataset', 'exp', 'sub_experiment')
+    ensemble_grouping = ("project", "dataset", "exp", "sub_experiment")
     return multi_model_statistics(
         products=products,
         span=span,
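
The optimization in PATCH 1/2 amounts to hoisting the expensive merge of the
input cubes out of the per-statistic loop and caching it. The sketch below
isolates that pattern; it is not ESMValCore code: `combine` and `compute` are
hypothetical stand-ins for the private helpers `_combine` and `_compute`, and
numpy arrays stand in for iris cubes.

import numpy as np


def compute_statistics(cubes, statistics, combine, compute):
    """Compute several statistics over the same inputs, merging only once."""
    combined = None
    results = {}
    for statistic in statistics:
        if combined is None:
            # The merged result is identical for every statistic, so pay the
            # (potentially expensive) merge cost only on the first iteration.
            combined = combine(cubes)
        results[statistic] = compute(combined, statistic)
    return results


# Toy usage: np.stack plays the role of _combine, and a reduction over the
# stacked "multi-model" axis plays the role of _compute.
stats = compute_statistics(
    cubes=[np.array([1.0, 2.0]), np.array([3.0, 4.0])],
    statistics=["mean", "max"],
    combine=np.stack,
    compute=lambda combined, stat: getattr(np, stat)(combined, axis=0),
)
print(stats)  # {'mean': array([2., 3.]), 'max': array([3., 4.])}

Note that in the patch the cache only applies to the lazy path (when
agg.lazy_func is available); the eager fallback in _compute_eager still calls
_combine once per chunk of slices.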