diff --git a/esmvalcore/preprocessor/_multimodel.py b/esmvalcore/preprocessor/_multimodel.py
index dcce65ebd3..b790b45117 100644
--- a/esmvalcore/preprocessor/_multimodel.py
+++ b/esmvalcore/preprocessor/_multimodel.py
@@ -484,7 +484,8 @@ def _compute_eager(
             input_slices = cubes  # scalar cubes
         else:
             input_slices = [cube[chunk] for cube in cubes]
-        result_slice = _compute(input_slices, operator=operator, **kwargs)
+        combined_cube = _combine(input_slices)
+        result_slice = _compute(combined_cube, operator=operator, **kwargs)
         result_slices.append(result_slice)
 
     try:
@@ -503,10 +504,13 @@ def _compute_eager(
     return result_cube
 
 
-def _compute(cubes: list, *, operator: iris.analysis.Aggregator, **kwargs):
+def _compute(
+    cube: iris.cube.Cube,
+    *,
+    operator: iris.analysis.Aggregator,
+    **kwargs,
+):
     """Compute statistic."""
-    cube = _combine(cubes)
-
     with warnings.catch_warnings():
         warnings.filterwarnings(
             "ignore",
@@ -531,8 +535,6 @@ def _compute(cubes: list, *, operator: iris.analysis.Aggregator, **kwargs):
 
     # Remove concatenation dimension added by _combine
     result_cube.remove_coord(CONCAT_DIM)
-    for cube in cubes:
-        cube.remove_coord(CONCAT_DIM)
 
     # some iris aggregators modify dtype, see e.g.
     # https://numpy.org/doc/stable/reference/generated/numpy.ma.average.html
@@ -545,7 +547,6 @@ def _compute(cubes: list, *, operator: iris.analysis.Aggregator, **kwargs):
             method=cell_method.method,
             coords=cell_method.coord_names,
             intervals=cell_method.intervals,
-            comments=f"input_cubes: {len(cubes)}",
         )
         result_cube.add_cell_method(updated_method)
     return result_cube
@@ -602,27 +603,26 @@ def _multicube_statistics(
     # Calculate statistics
     statistics_cubes = {}
     lazy_input = any(cube.has_lazy_data() for cube in cubes)
-    for stat in statistics:
-        (stat_id, result_cube) = _compute_statistic(cubes, lazy_input, stat)
+    combined_cube = None
+    for statistic in statistics:
+        stat_id = _get_stat_identifier(statistic)
+        logger.debug("Multicube statistics: computing: %s", stat_id)
+
+        (operator, kwargs) = _get_operator_and_kwargs(statistic)
+        (agg, agg_kwargs) = get_iris_aggregator(operator, **kwargs)
+        if lazy_input and agg.lazy_func is not None:
+            if combined_cube is None:
+                # Merge input cubes only once as this can be computationally
+                # expensive.
+                combined_cube = _combine(cubes)
+            result_cube = _compute(combined_cube, operator=agg, **agg_kwargs)
+        else:
+            result_cube = _compute_eager(cubes, operator=agg, **agg_kwargs)
         statistics_cubes[stat_id] = result_cube
 
     return statistics_cubes
 
 
-def _compute_statistic(cubes, lazy_input, statistic):
-    """Compute a single statistic."""
-    stat_id = _get_stat_identifier(statistic)
-    logger.debug("Multicube statistics: computing: %s", stat_id)
-
-    (operator, kwargs) = _get_operator_and_kwargs(statistic)
-    (agg, agg_kwargs) = get_iris_aggregator(operator, **kwargs)
-    if lazy_input and agg.lazy_func is not None:
-        result_cube = _compute(cubes, operator=agg, **agg_kwargs)
-    else:
-        result_cube = _compute_eager(cubes, operator=agg, **agg_kwargs)
-    return (stat_id, result_cube)
-
-
 def _multiproduct_statistics(
     products,
     statistics,
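
Note: the diff above moves the potentially expensive `_combine` call out of `_compute`, so that `_multicube_statistics` can merge the input cubes once and reuse the combined cube for every lazily computed statistic, while eager statistics still go through the slice-by-slice `_compute_eager` path. Below is a minimal, self-contained sketch of that "combine once, reuse per statistic" pattern; it uses plain numpy stand-ins (`combine`, `compute`, `multicube_statistics` are hypothetical names, not the ESMValCore or iris API).

```python
# Hypothetical, simplified illustration of the caching pattern in the diff
# above. Masked numpy arrays stand in for iris cubes; none of these names are
# the real ESMValCore API.
import numpy as np


def combine(arrays):
    """Stack the inputs along a new leading 'model' axis (the expensive step)."""
    return np.ma.stack(arrays)


def compute(combined, operator):
    """Collapse the leading 'model' axis with the given operator."""
    return operator(combined, axis=0)


def multicube_statistics(arrays, operators):
    """Compute several statistics while merging the inputs only once."""
    combined = None
    results = {}
    for name, operator in operators.items():
        if combined is None:
            # Merge the inputs only on first use and reuse the result for
            # every subsequent statistic.
            combined = combine(arrays)
        results[name] = compute(combined, operator)
    return results


if __name__ == "__main__":
    data = [np.ma.masked_invalid(np.random.rand(3, 4)) for _ in range(5)]
    stats = multicube_statistics(data, {"mean": np.ma.mean, "std": np.ma.std})
    print({name: result.shape for name, result in stats.items()})
```

The diff follows the same shape: `combined_cube` starts as `None` and `_combine(cubes)` is called at most once per `_multicube_statistics` invocation, instead of once per statistic inside `_compute`.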