diff --git a/ci/base.yml b/ci/base.yml
index 2b1732888c..24d9a01314 100644
--- a/ci/base.yml
+++ b/ci/base.yml
@@ -43,3 +43,4 @@ variables:
VIRTUALENV_SYSTEM_SITE_PACKAGES: 1
CSCS_NEEDED_DATA: icon4py
TEST_DATA_PATH: "/project/d121/icon4py/ci/testdata"
+ ICON_GRID_LOC: "/project/d121/icon4py/ci/testdata/grids/mch_ch_r04b09_dsl"
diff --git a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/cached.py b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/cached.py
new file mode 100644
index 0000000000..30834c291d
--- /dev/null
+++ b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/cached.py
@@ -0,0 +1,79 @@
+# ICON4Py - ICON inspired code in Python and GT4Py
+#
+# Copyright (c) 2022, ETH Zurich and MeteoSwiss
+# All rights reserved.
+#
+# This file is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from icon4py.model.atmosphere.diffusion.diffusion_utils import (
+ copy_field as copy_field_orig,
+ init_diffusion_local_fields_for_regular_timestep as init_diffusion_local_fields_for_regular_timestep_orig,
+ scale_k as scale_k_orig,
+ setup_fields_for_initial_step as setup_fields_for_initial_step_orig,
+)
+from icon4py.model.atmosphere.diffusion.stencils.apply_diffusion_to_vn import (
+ apply_diffusion_to_vn as apply_diffusion_to_vn_orig,
+)
+from icon4py.model.atmosphere.diffusion.stencils.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence import (
+ apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence as apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence_orig,
+)
+from icon4py.model.atmosphere.diffusion.stencils.calculate_diagnostic_quantities_for_turbulence import (
+ calculate_diagnostic_quantities_for_turbulence as calculate_diagnostic_quantities_for_turbulence_orig,
+)
+from icon4py.model.atmosphere.diffusion.stencils.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools import (
+ calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools as calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools_orig,
+)
+from icon4py.model.atmosphere.diffusion.stencils.calculate_nabla2_and_smag_coefficients_for_vn import (
+ calculate_nabla2_and_smag_coefficients_for_vn as calculate_nabla2_and_smag_coefficients_for_vn_orig,
+)
+from icon4py.model.atmosphere.diffusion.stencils.calculate_nabla2_for_theta import (
+ calculate_nabla2_for_theta as calculate_nabla2_for_theta_orig,
+)
+from icon4py.model.atmosphere.diffusion.stencils.truly_horizontal_diffusion_nabla_of_theta_over_steep_points import (
+ truly_horizontal_diffusion_nabla_of_theta_over_steep_points as truly_horizontal_diffusion_nabla_of_theta_over_steep_points_orig,
+)
+from icon4py.model.atmosphere.diffusion.stencils.update_theta_and_exner import (
+ update_theta_and_exner as update_theta_and_exner_orig,
+)
+from icon4py.model.common.caching import CachedProgram
+from icon4py.model.common.interpolation.stencils.mo_intp_rbf_rbf_vec_interpol_vertex import (
+ mo_intp_rbf_rbf_vec_interpol_vertex as mo_intp_rbf_rbf_vec_interpol_vertex_orig,
+)
+
+
+# model init stencils (wrapped with with_domain=False)
+scale_k = CachedProgram(scale_k_orig, with_domain=False)
+copy_field = CachedProgram(copy_field_orig, with_domain=False)
+setup_fields_for_initial_step = CachedProgram(setup_fields_for_initial_step_orig, with_domain=False)
+init_diffusion_local_fields_for_regular_timestep = CachedProgram(
+    init_diffusion_local_fields_for_regular_timestep_orig, with_domain=False
+)
+
+# interpolation stencil from the common package
+mo_intp_rbf_rbf_vec_interpol_vertex = CachedProgram(mo_intp_rbf_rbf_vec_interpol_vertex_orig)
+
+# diffusion run stencils
+update_theta_and_exner = CachedProgram(update_theta_and_exner_orig)
+truly_horizontal_diffusion_nabla_of_theta_over_steep_points = CachedProgram(
+    truly_horizontal_diffusion_nabla_of_theta_over_steep_points_orig
+)
+calculate_nabla2_for_theta = CachedProgram(calculate_nabla2_for_theta_orig)
+calculate_nabla2_and_smag_coefficients_for_vn = CachedProgram(
+    calculate_nabla2_and_smag_coefficients_for_vn_orig
+)
+calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools = CachedProgram(
+    calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools_orig
+)
+calculate_diagnostic_quantities_for_turbulence = CachedProgram(
+    calculate_diagnostic_quantities_for_turbulence_orig
+)
+apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence = CachedProgram(
+    apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence_orig
+)
+apply_diffusion_to_vn = CachedProgram(apply_diffusion_to_vn_orig)
diff --git a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py
index ac9a332aaa..e9c5f962d0 100644
--- a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py
+++ b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py
@@ -28,35 +28,27 @@
DiffusionMetricState,
)
from icon4py.model.atmosphere.diffusion.diffusion_utils import (
- copy_field,
- init_diffusion_local_fields_for_regular_timestep,
init_nabla2_factor_in_upper_damping_zone,
- scale_k,
- setup_fields_for_initial_step,
zero_field,
)
-from icon4py.model.atmosphere.diffusion.stencils.apply_diffusion_to_vn import apply_diffusion_to_vn
-from icon4py.model.atmosphere.diffusion.stencils.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence import (
- apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence,
-)
-from icon4py.model.atmosphere.diffusion.stencils.calculate_diagnostic_quantities_for_turbulence import (
- calculate_diagnostic_quantities_for_turbulence,
-)
-from icon4py.model.atmosphere.diffusion.stencils.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools import (
- calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools,
-)
-from icon4py.model.atmosphere.diffusion.stencils.calculate_nabla2_and_smag_coefficients_for_vn import (
+
+# cached program import
+from icon4py.model.atmosphere.diffusion.cached import (
+ init_diffusion_local_fields_for_regular_timestep,
+ setup_fields_for_initial_step,
+ scale_k,
calculate_nabla2_and_smag_coefficients_for_vn,
-)
-from icon4py.model.atmosphere.diffusion.stencils.calculate_nabla2_for_theta import (
calculate_nabla2_for_theta,
-)
-from icon4py.model.atmosphere.diffusion.stencils.truly_horizontal_diffusion_nabla_of_theta_over_steep_points import (
truly_horizontal_diffusion_nabla_of_theta_over_steep_points,
-)
-from icon4py.model.atmosphere.diffusion.stencils.update_theta_and_exner import (
+ apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence,
+ apply_diffusion_to_vn,
+ calculate_diagnostic_quantities_for_turbulence,
+ calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools,
update_theta_and_exner,
+ copy_field,
+ mo_intp_rbf_rbf_vec_interpol_vertex,
)
+
from icon4py.model.common.constants import (
CPD,
DEFAULT_PHYSICS_DYNAMICS_TIMESTEP_RATIO,
@@ -68,9 +60,6 @@
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams, HorizontalMarkerIndex
from icon4py.model.common.grid.icon import IconGrid
from icon4py.model.common.grid.vertical import VerticalModelParams
-from icon4py.model.common.interpolation.stencils.mo_intp_rbf_rbf_vec_interpol_vertex import (
- mo_intp_rbf_rbf_vec_interpol_vertex,
-)
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.settings import xp
@@ -752,6 +741,7 @@ def _do_diffusion_step(
)
# TODO (magdalena) get rid of this copying. So far passing an empty buffer instead did not verify?
copy_field(prognostic_state.w, self.w_tmp, offset_provider={})
+
apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence(
area=self.cell_params.area,
geofac_n2s=self.interpolation_state.geofac_n2s,
@@ -784,6 +774,7 @@ def _do_diffusion_step(
log.debug(
"running fused stencils 11 12 (calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools): start"
)
+
calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools(
theta_v=prognostic_state.theta_v,
theta_ref_mc=self.metric_state.theta_ref_mc,
@@ -799,6 +790,7 @@ def _do_diffusion_step(
log.debug(
"running stencils 11 12 (calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools): end"
)
+
log.debug("running stencils 13 14 (calculate_nabla2_for_theta): start")
calculate_nabla2_for_theta(
kh_smag_e=self.kh_smag_e,
diff --git a/model/atmosphere/diffusion/tests/diffusion_stencil_tests/test_temporary_fields_for_turbulence_diagnostics.py b/model/atmosphere/diffusion/tests/diffusion_stencil_tests/test_temporary_fields_for_turbulence_diagnostics.py
index e45018c211..bcfd669274 100644
--- a/model/atmosphere/diffusion/tests/diffusion_stencil_tests/test_temporary_fields_for_turbulence_diagnostics.py
+++ b/model/atmosphere/diffusion/tests/diffusion_stencil_tests/test_temporary_fields_for_turbulence_diagnostics.py
@@ -43,12 +43,15 @@ def reference(
**kwargs,
) -> dict:
c2e = grid.connectivities[C2EDim]
+ c2ce = grid.get_offset_provider("C2CE").table
+
geofac_div = np.expand_dims(geofac_div, axis=-1)
- vn_geofac = vn[c2e] * geofac_div[grid.get_offset_provider("C2CE").table]
- div = np.sum(vn_geofac, axis=1)
e_bln_c_s = np.expand_dims(e_bln_c_s, axis=-1)
diff_multfac_smag = np.expand_dims(diff_multfac_smag, axis=0)
- mul = kh_smag_ec[c2e] * e_bln_c_s[grid.get_offset_provider("C2CE").table]
+
+ vn_geofac = vn[c2e] * geofac_div[c2ce]
+ div = np.sum(vn_geofac, axis=1)
+ mul = kh_smag_ec[c2e] * e_bln_c_s[c2ce]
summed = np.sum(mul, axis=1)
kh_c = summed / diff_multfac_smag
diff --git a/model/atmosphere/dycore/tests/dycore_stencil_tests/test_add_interpolated_horizontal_advection_of_w.py b/model/atmosphere/dycore/tests/dycore_stencil_tests/test_add_interpolated_horizontal_advection_of_w.py
index ce8cf650ba..e53ca35c58 100644
--- a/model/atmosphere/dycore/tests/dycore_stencil_tests/test_add_interpolated_horizontal_advection_of_w.py
+++ b/model/atmosphere/dycore/tests/dycore_stencil_tests/test_add_interpolated_horizontal_advection_of_w.py
@@ -27,8 +27,10 @@ def add_interpolated_horizontal_advection_of_w_numpy(
grid, e_bln_c_s: np.array, z_v_grad_w: np.array, ddt_w_adv: np.array, **kwargs
) -> np.array:
e_bln_c_s = np.expand_dims(e_bln_c_s, axis=-1)
+ c2ce = grid.get_offset_provider("C2CE").table
+
ddt_w_adv = ddt_w_adv + np.sum(
- z_v_grad_w[grid.connectivities[C2EDim]] * e_bln_c_s[grid.get_offset_provider("C2CE").table],
+ z_v_grad_w[grid.connectivities[C2EDim]] * e_bln_c_s[c2ce],
axis=1,
)
return ddt_w_adv
diff --git a/model/atmosphere/dycore/tests/dycore_stencil_tests/test_interpolate_to_cell_center.py b/model/atmosphere/dycore/tests/dycore_stencil_tests/test_interpolate_to_cell_center.py
index 19ce7bbac2..1e24efcf7e 100644
--- a/model/atmosphere/dycore/tests/dycore_stencil_tests/test_interpolate_to_cell_center.py
+++ b/model/atmosphere/dycore/tests/dycore_stencil_tests/test_interpolate_to_cell_center.py
@@ -30,9 +30,10 @@ def interpolate_to_cell_center_numpy(
grid, interpolant: np.array, e_bln_c_s: np.array, **kwargs
) -> np.array:
e_bln_c_s = np.expand_dims(e_bln_c_s, axis=-1)
+ c2ce = grid.get_offset_provider("C2CE").table
+
interpolation = np.sum(
- interpolant[grid.connectivities[C2EDim]]
- * e_bln_c_s[grid.get_offset_provider("C2CE").table],
+ interpolant[grid.connectivities[C2EDim]] * e_bln_c_s[c2ce],
axis=1,
)
return interpolation
diff --git a/model/atmosphere/dycore/tests/dycore_stencil_tests/test_mcompute_divergence_of_fluxes_of_rho_and_theta.py b/model/atmosphere/dycore/tests/dycore_stencil_tests/test_mcompute_divergence_of_fluxes_of_rho_and_theta.py
index f1fff490a3..7e6f2f8cd4 100644
--- a/model/atmosphere/dycore/tests/dycore_stencil_tests/test_mcompute_divergence_of_fluxes_of_rho_and_theta.py
+++ b/model/atmosphere/dycore/tests/dycore_stencil_tests/test_mcompute_divergence_of_fluxes_of_rho_and_theta.py
@@ -28,7 +28,7 @@
from icon4py.model.common.type_alias import vpfloat, wpfloat
-class TestComputeDivergenceOfFluxesOfRhoAndTheta(StencilTest):
+class TestComputeDivergenceOfFluxesOfRhoAndTheta(StencilTest):
PROGRAM = compute_divergence_of_fluxes_of_rho_and_theta
OUTPUTS = ("z_flxdiv_mass", "z_flxdiv_theta")
@@ -42,12 +42,14 @@ def reference(
) -> tuple[np.array]:
c2e = grid.connectivities[C2EDim]
geofac_div = np.expand_dims(geofac_div, axis=-1)
+ c2ce = grid.get_offset_provider("C2CE").table
+
z_flxdiv_mass = np.sum(
- geofac_div[grid.get_offset_provider("C2CE").table] * mass_fl_e[c2e],
+ geofac_div[c2ce] * mass_fl_e[c2e],
axis=1,
)
z_flxdiv_theta = np.sum(
- geofac_div[grid.get_offset_provider("C2CE").table] * z_theta_v_fl_e[c2e],
+ geofac_div[c2ce] * z_theta_v_fl_e[c2e],
axis=1,
)
return dict(z_flxdiv_mass=z_flxdiv_mass, z_flxdiv_theta=z_flxdiv_theta)
diff --git a/model/common/src/icon4py/model/common/caching.py b/model/common/src/icon4py/model/common/caching.py
new file mode 100644
index 0000000000..a10e36f420
--- /dev/null
+++ b/model/common/src/icon4py/model/common/caching.py
@@ -0,0 +1,164 @@
+# ICON4Py - ICON inspired code in Python and GT4Py
+#
+# Copyright (c) 2022, ETH Zurich and MeteoSwiss
+# All rights reserved.
+#
+# This file is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import dataclasses
+import functools
+from typing import Any, Callable, Optional
+
+import numpy as np
+from gt4py import next as gtx
+from gt4py.next.otf import workflow
+from gt4py.next.program_processors.runners.gtfn import extract_connectivity_args
+
+from icon4py.model.common.settings import device
+
+
+try:
+    import cupy as cp
+    from gt4py.next.embedded.nd_array_field import CuPyArrayField
+except ImportError:
+    cp: Optional = None  # type:ignore[no-redef]
+
+from gt4py.next.embedded.nd_array_field import NumPyArrayField
+
+
+def handle_numpy_integer(value):
+    """Convert a NumPy integer scalar to a builtin ``int``."""
+    return int(value)
+
+
+def handle_common_field(value, sizes):
+    """Return the field unchanged and append its shape to ``sizes`` (mutated in place)."""
+    sizes.extend(value.shape)
+    return value
+
+
+def handle_default(value):
+    """Return ``value`` unchanged."""
+    return value
+
+
+# Maps an argument type to the handler applied to it before the compiled program is called.
+type_handlers = {
+    np.integer: handle_numpy_integer,
+    NumPyArrayField: handle_common_field,
+}
+if cp:
+    type_handlers[CuPyArrayField] = handle_common_field
+
+
+@functools.cache
+def _resolve_handler(value_type: type) -> Callable:
+    """Return the handler registered for ``value_type`` or one of its base classes.
+
+    ``np.integer`` is an abstract base class: concrete scalars such as ``np.int32`` are
+    subclasses of it, so an exact ``type(value)`` dictionary lookup would never match.
+    The result is cached per type; ``type_handlers`` must therefore not be modified
+    after the first call.
+    """
+    for handled_type, handler in type_handlers.items():
+        if issubclass(value_type, handled_type):
+            return handler
+    return handle_default
+
+
+def process_arg(value, sizes):
+    """Prepare a single argument for the call of the compiled program.
+
+    Args:
+        value: The argument as passed by the caller.
+        sizes: List collecting the shape of every field argument; extended in place.
+
+    Returns:
+        NumPy integers converted to ``int``; every other value unchanged.
+    """
+    handler = _resolve_handler(type(value))
+    if handler is handle_common_field:
+        return handler(value, sizes)
+    return handler(value)
+
+
+@dataclasses.dataclass
+class CachedProgram:
+    """Compile a GT4Py program on first use and reuse the compiled program afterwards.
+
+    The program is compiled with the arguments of the first call. Every call forwards
+    its arguments positionally to the compiled program; if ``with_domain`` is false the
+    shapes of all field arguments are appended as additional arguments.
+
+    Attributes:
+        program (gtx.ffront.decorator.Program): The GT4Py program to be cached and compiled.
+        with_domain (bool): If false, the field shapes are appended to the call arguments.
+            Defaults to True.
+        _compiled_program (Optional[Callable]): The compiled GT4Py program.
+        _conn_args (Any): Connectivity arguments extracted from the offset provider.
+        _compiled_args (tuple): Arguments used during the compilation of the program.
+
+    Properties:
+        compiled_program (Optional[Callable]): The compiled program, None before the first call.
+        conn_args (Any): The connectivity arguments, None before the first call.
+
+    Note:
+        This functionality will be provided by GT4Py in the future.
+    """
+
+    program: gtx.ffront.decorator.Program
+    with_domain: bool = True
+    _compiled_program: Optional[Callable] = None
+    _conn_args: Any = None
+    _compiled_args: tuple = dataclasses.field(default_factory=tuple)
+
+    @property
+    def compiled_program(self) -> Optional[Callable]:
+        return self._compiled_program
+
+    @property
+    def conn_args(self) -> Any:
+        return self._conn_args
+
+    def compile_the_program(
+        self, *args, offset_provider: dict[str, Any], **kwargs: Any
+    ) -> Callable:
+        """Run the backend transformations and return the executable program.
+
+        The arguments of the transformed program call are stored in ``_compiled_args``.
+        """
+        backend = self.program.backend
+        program_call = backend.transforms_prog(
+            workflow.InputWithArgs(
+                data=self.program.definition_stage,
+                args=args,
+                kwargs=kwargs | {"offset_provider": offset_provider},
+            )
+        )
+        self._compiled_args = program_call.args
+        return backend.executor.otf_workflow(program_call)
+
+    def __call__(self, *args, offset_provider: dict[str, Any], **kwargs: Any) -> Any:
+        """Compile the program on the first call, then execute the compiled program."""
+        if self.compiled_program is None:
+            self._compiled_program = self.compile_the_program(
+                *args, offset_provider=offset_provider, **kwargs
+            )
+            self._conn_args = extract_connectivity_args(offset_provider, device)
+
+        # Keyword arguments are forwarded positionally in the order given by the caller.
+        # NOTE(review): presumably this order has to match the program signature - confirm.
+        sizes: list[int] = []
+        program_args = [process_arg(arg, sizes) for arg in (*args, *kwargs.values())]
+
+        if not self.with_domain:
+            program_args.extend(sizes)
+
+        # todo(samkellerhals): if we merge gt4py PR we can also pass connectivity args here conn_args=self.conn_args
+        return self.compiled_program(*program_args, offset_provider=offset_provider)
diff --git a/model/common/src/icon4py/model/common/config.py b/model/common/src/icon4py/model/common/config.py
index f48034bf07..f7f4389dcc 100644
--- a/model/common/src/icon4py/model/common/config.py
+++ b/model/common/src/icon4py/model/common/config.py
@@ -69,3 +69,15 @@ def device(self):
         }
         device = device_map[self.icon4py_backend]
         return device
+
+    @cached_property
+    def limited_area(self) -> bool:
+        """Return whether limited area mode is enabled via ``ICON4PY_LAM``.
+
+        An unset or empty variable and the values ``0``, ``false``, ``no`` and ``off``
+        (case-insensitive) disable the mode; any other value enables it. Environment
+        variables are always strings, so the raw value must not be used for truthiness:
+        ``ICON4PY_LAM=0`` would otherwise enable the mode.
+        """
+        value = os.environ.get("ICON4PY_LAM", "")
+        return value.strip().lower() not in ("", "0", "false", "no", "off")
diff --git a/model/common/src/icon4py/model/common/grid/base.py b/model/common/src/icon4py/model/common/grid/base.py
index ede8df26d3..fc65c00742 100644
--- a/model/common/src/icon4py/model/common/grid/base.py
+++ b/model/common/src/icon4py/model/common/grid/base.py
@@ -24,6 +24,7 @@
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, VertexDim
from icon4py.model.common.grid.utils import neighbortable_offset_provider_for_1d_sparse_fields
from icon4py.model.common.grid.vertical import VerticalGridSize
+from icon4py.model.common.settings import xp
from icon4py.model.common.utils import builder
@@ -128,11 +129,6 @@ def _get_offset_provider(self, dim, from_dim, to_dim):
if dim not in self.connectivities:
raise MissingConnectivity()
- if self.config.on_gpu:
- import cupy as xp
- else:
- xp = np
-
return NeighborTableOffsetProvider(
xp.asarray(self.connectivities[dim]),
from_dim,
diff --git a/model/common/src/icon4py/model/common/grid/horizontal.py b/model/common/src/icon4py/model/common/grid/horizontal.py
index 484974f2c5..7ac2050f3b 100644
--- a/model/common/src/icon4py/model/common/grid/horizontal.py
+++ b/model/common/src/icon4py/model/common/grid/horizontal.py
@@ -348,13 +348,13 @@ def __init__(
@dataclass(frozen=True)
class CellParams:
#: Latitude at the cell center. The cell center is defined to be the circumcenter of a triangle.
- cell_center_lat: Field[[CellDim], float]
+ cell_center_lat: Field[[CellDim], float] = None
#: Longitude at the cell center. The cell center is defined to be the circumcenter of a triangle.
- cell_center_lon: Field[[CellDim], float]
+ cell_center_lon: Field[[CellDim], float] = None
#: Area of a cell, defined in ICON in mo_model_domain.f90:t_grid_cells%area
- area: Field[[CellDim], float]
+ area: Field[[CellDim], float] = None
#: Mean area of a cell [m^2] = total surface area / numer of cells defined in ICON in in mo_model_domimp_patches.f90
- mean_cell_area: float
+ mean_cell_area: float = None
length_rescale_factor: float = 1.0
@classmethod
diff --git a/model/common/src/icon4py/model/common/grid/utils.py b/model/common/src/icon4py/model/common/grid/utils.py
index 840d60f32e..e8a02984e8 100644
--- a/model/common/src/icon4py/model/common/grid/utils.py
+++ b/model/common/src/icon4py/model/common/grid/utils.py
@@ -14,6 +14,8 @@
import numpy as np
from gt4py.next import Dimension, NeighborTableOffsetProvider
+from icon4py.model.common.settings import xp
+
def neighbortable_offset_provider_for_1d_sparse_fields(
old_shape: tuple[int, int],
@@ -21,7 +23,7 @@ def neighbortable_offset_provider_for_1d_sparse_fields(
neighbor_axis: Dimension,
has_skip_values: bool,
):
- table = np.arange(old_shape[0] * old_shape[1]).reshape(old_shape)
+ table = xp.asarray(np.arange(old_shape[0] * old_shape[1]).reshape(old_shape))
return NeighborTableOffsetProvider(
table,
origin_axis,
diff --git a/model/common/src/icon4py/model/common/settings.py b/model/common/src/icon4py/model/common/settings.py
index 0f5d1a82c0..b6d482eee5 100644
--- a/model/common/src/icon4py/model/common/settings.py
+++ b/model/common/src/icon4py/model/common/settings.py
@@ -17,3 +17,4 @@
backend = config.gt4py_runner
xp = config.array_ns
device = config.device
+limited_area = config.limited_area
diff --git a/model/common/src/icon4py/model/common/test_utils/grid_utils.py b/model/common/src/icon4py/model/common/test_utils/grid_utils.py
index b3ff01df84..57e3bdd439 100644
--- a/model/common/src/icon4py/model/common/test_utils/grid_utils.py
+++ b/model/common/src/icon4py/model/common/test_utils/grid_utils.py
@@ -17,6 +17,7 @@
from icon4py.model.common.grid.grid_manager import GridManager, ToGt4PyTransformation
from icon4py.model.common.grid.icon import IconGrid
from icon4py.model.common.grid.vertical import VerticalGridSize
+from icon4py.model.common.test_utils.data_handling import download_and_extract
from icon4py.model.common.test_utils.datatest_utils import (
GLOBAL_EXPERIMENT,
GRID_URIS,
@@ -34,7 +35,7 @@
@functools.cache
def get_icon_grid_from_gridfile(experiment: str, on_gpu: bool = False) -> IconGrid:
if experiment == GLOBAL_EXPERIMENT:
- return _load_from_gridfile(
+ return _download_and_load_from_gridfile(
R02B04_GLOBAL,
"icon_grid_0013_R02B04_R.nc",
num_levels=GLOBAL_NUM_LEVELS,
@@ -42,7 +43,7 @@ def get_icon_grid_from_gridfile(experiment: str, on_gpu: bool = False) -> IconGr
limited_area=False,
)
elif experiment == REGIONAL_EXPERIMENT:
- return _load_from_gridfile(
+ return _download_and_load_from_gridfile(
REGIONAL_EXPERIMENT,
"grid.nc",
num_levels=MCH_CH_R04B09_LEVELS,
@@ -53,18 +54,20 @@ def get_icon_grid_from_gridfile(experiment: str, on_gpu: bool = False) -> IconGr
raise ValueError(f"Unknown experiment: {experiment}")
-def _load_from_gridfile(
- file_path: str, filename: str, num_levels: int, on_gpu: bool, limited_area: bool
-) -> IconGrid:
+def download_grid_file(file_path: str, filename: str):
grid_file = GRIDS_PATH.joinpath(file_path, filename)
if not grid_file.exists():
- from icon4py.model.common.test_utils.data_handling import download_and_extract
-
download_and_extract(
GRID_URIS[file_path],
grid_file.parent,
grid_file.parent,
)
+ return grid_file
+
+
+def load_grid_from_file(
+ grid_file: str, num_levels: int, on_gpu: bool, limited_area: bool
+) -> IconGrid:
gm = GridManager(
ToGt4PyTransformation(),
str(grid_file),
@@ -74,6 +77,13 @@ def _load_from_gridfile(
return gm.get_grid()
+def _download_and_load_from_gridfile(
+ file_path: str, filename: str, num_levels: int, on_gpu: bool, limited_area: bool
+) -> IconGrid:
+ grid_file = download_grid_file(file_path, filename)
+ return load_grid_from_file(grid_file, num_levels, on_gpu, limited_area)
+
+
@pytest.fixture
def grid(request):
return request.param
diff --git a/model/requirements.txt b/model/requirements.txt
index 3330309e2e..8da0b03ad2 100644
--- a/model/requirements.txt
+++ b/model/requirements.txt
@@ -2,5 +2,5 @@
./atmosphere/dycore
./atmosphere/diffusion
./atmosphere/advection
-./common
+./common[netcdf]
./driver
diff --git a/tools/README.md b/tools/README.md
index 72132757df..f4d8b7b398 100644
--- a/tools/README.md
+++ b/tools/README.md
@@ -344,7 +344,7 @@ Options:
## Important Environment Variables
When embedding granules it may be necessary to use an ICON grid file, as is the case in the diffusion granule wrapper.
-The granule expects an `ICON_GRID_LOC` environment variable with the path to the folder holding the grid netcdf file.
+The granule expects an `ICON_GRID_LOC` environment variable with the path to the folder containing the different grids.
### Example
diff --git a/tools/requirements-dev.txt b/tools/requirements-dev.txt
index ab8d302ab9..8aee68ae06 100644
--- a/tools/requirements-dev.txt
+++ b/tools/requirements-dev.txt
@@ -2,5 +2,5 @@
-e ../model/atmosphere/dycore
-e ../model/atmosphere/diffusion
-e ../model/atmosphere/advection
--e ../model/common
+-e ../model/common[netcdf]
-e .[all]
diff --git a/tools/requirements.txt b/tools/requirements.txt
index 3fe8f73e35..f5357c839c 100644
--- a/tools/requirements.txt
+++ b/tools/requirements.txt
@@ -2,5 +2,5 @@
../model/atmosphere/dycore
../model/atmosphere/advection
../model/atmosphere/diffusion
-../model/common
+../model/common[netcdf]
.
diff --git a/tools/src/icon4pytools/icon4pygen/bindings/codegen/type_conversion.py b/tools/src/icon4pytools/icon4pygen/bindings/codegen/type_conversion.py
index 752763439b..cb86c05a65 100644
--- a/tools/src/icon4pytools/icon4pygen/bindings/codegen/type_conversion.py
+++ b/tools/src/icon4pytools/icon4pygen/bindings/codegen/type_conversion.py
@@ -28,3 +28,10 @@
ts.ScalarKind.INT32: "int",
ts.ScalarKind.INT64: "long",
}
+BUILTIN_TO_NUMPY_TYPE: dict[ts.ScalarKind, str] = {
+ ts.ScalarKind.FLOAT64: "xp.float64",
+ ts.ScalarKind.FLOAT32: "xp.float32",
+ ts.ScalarKind.BOOL: "xp.int32",
+ ts.ScalarKind.INT32: "xp.int32",
+ ts.ScalarKind.INT64: "xp.int64",
+}
diff --git a/tools/src/icon4pytools/py2fgen/cli.py b/tools/src/icon4pytools/py2fgen/cli.py
index 99db23a5c8..0a80ae9333 100644
--- a/tools/src/icon4pytools/py2fgen/cli.py
+++ b/tools/src/icon4pytools/py2fgen/cli.py
@@ -57,6 +57,7 @@ def parse_comma_separated_list(ctx, param, value) -> list[str]:
is_flag=True,
help="Enable debug mode to log additional Python runtime information.",
)
+@click.option("--limited-area", is_flag=True, help="Enable limited area mode.")
def main(
module_import_path: str,
functions: list[str],
@@ -64,6 +65,7 @@ def main(
output_path: pathlib.Path,
debug_mode: bool,
backend: str,
+    limited_area: bool,
) -> None:
"""Generate C and F90 wrappers and C library for embedding a Python module in C and Fortran."""
output_path.mkdir(exist_ok=True, parents=True)
@@ -71,8 +73,8 @@ def main(
plugin = parse(module_import_path, functions, plugin_name)
c_header = generate_c_header(plugin)
- python_wrapper = generate_python_wrapper(plugin, backend, debug_mode)
- f90_interface = generate_f90_interface(plugin)
+ python_wrapper = generate_python_wrapper(plugin, backend, debug_mode, limited_area)
+ f90_interface = generate_f90_interface(plugin, limited_area)
generate_and_compile_cffi_plugin(plugin.plugin_name, c_header, python_wrapper, output_path)
write_string(f90_interface, output_path, f"{plugin.plugin_name}.f90")
diff --git a/tools/src/icon4pytools/py2fgen/generate.py b/tools/src/icon4pytools/py2fgen/generate.py
index 9812d8a3f8..14a92890a1 100644
--- a/tools/src/icon4pytools/py2fgen/generate.py
+++ b/tools/src/icon4pytools/py2fgen/generate.py
@@ -45,7 +45,9 @@ def generate_c_header(plugin: CffiPlugin) -> str:
return codegen.format_source("cpp", generated_code, style="LLVM")
-def generate_python_wrapper(plugin: CffiPlugin, backend: Optional[str], debug_mode: bool) -> str:
+def generate_python_wrapper(
+    plugin: CffiPlugin, backend: Optional[str], debug_mode: bool, limited_area: bool
+) -> str:
"""
Generate Python wrapper code.
@@ -65,13 +67,14 @@ def generate_python_wrapper(plugin: CffiPlugin, backend: Optional[str], debug_mo
imports=plugin.imports,
backend=backend,
debug_mode=debug_mode,
+ limited_area=limited_area,
)
generated_code = PythonWrapperGenerator.apply(python_wrapper)
return codegen.format_source("python", generated_code)
-def generate_f90_interface(plugin: CffiPlugin) -> str:
+def generate_f90_interface(plugin: CffiPlugin, limited_area: bool) -> str:
"""
Generate Fortran 90 interface code.
@@ -79,5 +82,7 @@ def generate_f90_interface(plugin: CffiPlugin) -> str:
plugin: The CffiPlugin instance containing information for code generation.
"""
logger.info("Generating Fortran interface...")
- generated_code = F90InterfaceGenerator.apply(F90Interface(cffi_plugin=plugin))
+ generated_code = F90InterfaceGenerator.apply(
+ F90Interface(cffi_plugin=plugin, limited_area=limited_area)
+ )
return format_fortran_code(generated_code)
diff --git a/tools/src/icon4pytools/py2fgen/plugin.py b/tools/src/icon4pytools/py2fgen/plugin.py
index 8a85c90594..fb49f08b62 100644
--- a/tools/src/icon4pytools/py2fgen/plugin.py
+++ b/tools/src/icon4pytools/py2fgen/plugin.py
@@ -11,6 +11,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
+import math
import typing
from pathlib import Path
@@ -48,7 +49,7 @@ def unpack(ptr, *sizes: int) -> NDArray:
This array shares the underlying data with the original Fortran code, allowing
modifications made through the array to affect the original data.
"""
- length = np.prod(sizes)
+ length = math.prod(sizes)
c_type = ffi.getctype(ffi.typeof(ptr).item)
# Map C data types to NumPy dtypes
@@ -88,7 +89,7 @@ def unpack_gpu(ptr, *sizes: int):
if not sizes:
raise ValueError("Sizes must be provided to determine the array shape.")
- length = np.prod(sizes)
+ length = math.prod(sizes)
c_type = ffi.getctype(ffi.typeof(ptr).item)
dtype_map = {
@@ -108,7 +109,6 @@ def unpack_gpu(ptr, *sizes: int):
mem = cp.cuda.UnownedMemory(ptr_val, total_size, owner=ptr, device_id=current_device.id)
memptr = cp.cuda.MemoryPointer(mem, 0)
arr = cp.ndarray(shape=sizes, dtype=dtype, memptr=memptr, order="F")
-
return arr
diff --git a/tools/src/icon4pytools/py2fgen/template.py b/tools/src/icon4pytools/py2fgen/template.py
index 406c3d03d3..f258ac5a55 100644
--- a/tools/src/icon4pytools/py2fgen/template.py
+++ b/tools/src/icon4pytools/py2fgen/template.py
@@ -22,9 +22,11 @@
from icon4pytools.icon4pygen.bindings.codegen.type_conversion import (
BUILTIN_TO_CPP_TYPE,
BUILTIN_TO_ISO_C_TYPE,
+ BUILTIN_TO_NUMPY_TYPE,
)
from icon4pytools.py2fgen.plugin import int_array_to_bool_array, unpack, unpack_gpu
from icon4pytools.py2fgen.utils import flatten_and_get_unique_elts
+from icon4pytools.py2fgen.wrappers.experiments import UNINITIALISED_ARRAYS
CFFI_DECORATOR = "@ffi.def_extern()"
@@ -44,16 +46,20 @@ class FuncParameter(Node):
size_args: list[str] = datamodels.field(init=False)
is_array: bool = datamodels.field(init=False)
gtdims: list[str] = datamodels.field(init=False)
+ size_args_len: int = datamodels.field(init=False)
+ np_type: str = datamodels.field(init=False)
def __post_init__(self):
self.size_args = dims_to_size_strings(self.dimensions)
+ self.size_args_len = len(self.size_args)
self.is_array = True if len(self.dimensions) >= 1 else False
# We need some fields to have nlevp1 levels on the fortran wrapper side, which we make
# happen by using KHalfDim as a type hint. However, this is not yet supported on the icon4py
- # side. So before generating the python wrapper code, we replace occurences of KHalfDim with KDim
+ # side. So before generating the python wrapper code, we replace occurrences of KHalfDim with KDim
self.gtdims = [
dimension.value.replace("KHalf", "K") + "Dim" for dimension in self.dimensions
]
+ self.np_type = to_np_type(self.d_type)
class Func(Node):
@@ -78,6 +84,7 @@ class CffiPlugin(Node):
class PythonWrapper(CffiPlugin):
backend: str
debug_mode: bool
+ limited_area: bool
cffi_decorator: str = CFFI_DECORATOR
cffi_unpack: str = inspect.getsource(unpack)
cffi_unpack_gpu: str = inspect.getsource(unpack_gpu)
@@ -88,6 +95,11 @@ class PythonWrapper(CffiPlugin):
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
self.gt4py_backend = GT4PyBackend[self.backend].value
self.is_gt4py_program_present = any(func.is_gt4py_program for func in self.functions)
+ self.uninitialised_arrays = get_uninitialised_arrays(self.limited_area)
+
+
+def get_uninitialised_arrays(limited_area: bool) -> list[str]:
+ """Return the names of the arrays that are treated as uninitialised.
+
+ For limited-area runs the list is empty; otherwise it is UNINITIALISED_ARRAYS.
+ """
+ return UNINITIALISED_ARRAYS if not limited_area else []
def build_array_size_args() -> dict[str, str]:
@@ -109,6 +121,11 @@ def to_c_type(scalar_type: ScalarKind) -> str:
return BUILTIN_TO_CPP_TYPE[scalar_type]
+def to_np_type(scalar_type: ScalarKind) -> str:
+ """Convert a scalar type to its corresponding numpy type."""
+ return BUILTIN_TO_NUMPY_TYPE[scalar_type]
+
+
def to_iso_c_type(scalar_type: ScalarKind) -> str:
"""Convert a scalar type to its corresponding ISO C type."""
return BUILTIN_TO_ISO_C_TYPE[scalar_type]
@@ -196,11 +213,13 @@ class PythonWrapperGenerator(TemplatedGenerator):
"""\
# imports for generated wrapper code
import logging
+import math
from {{ plugin_name }} import ffi
import numpy as np
{% if _this_node.backend == 'GPU' %}import cupy as cp {% endif %}
from numpy.typing import NDArray
from gt4py.next.iterator.embedded import np_as_located_field
+from icon4py.model.common.settings import xp
{% if _this_node.is_gt4py_program_present %}
# necessary imports when embedding a gt4py program directly
@@ -260,7 +279,12 @@ def {{ func.name }}_wrapper(
msg = '{{ arg.name }} before unpacking: %s' % str({{ arg.name}})
logging.debug(msg)
{% endif %}
+
+ {%- if arg.name in _this_node.uninitialised_arrays -%}
+ {{ arg.name }} = xp.ones((1,) * {{ arg.size_args_len }}, dtype={{arg.np_type}}, order="F")
+ {%- else -%}
{{ arg.name }} = unpack{%- if _this_node.backend == 'GPU' -%}_gpu{%- endif -%}({{ arg.name }}, {{ ", ".join(arg.size_args) }})
+ {%- endif -%}
{%- if arg.d_type.name == "BOOL" %}
{{ arg.name }} = int_array_to_bool_array({{ arg.name }})
@@ -354,23 +378,33 @@ class DimensionPosition(Node):
class F90FunctionDefinition(Func):
- dimension_size_declarations: Sequence[DimensionPosition] = datamodels.field(init=False)
+ limited_area: bool
+ dimension_positions: Sequence[DimensionPosition] = datamodels.field(init=False)
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
super().__post_init__() # call Func __post_init__
+ self.dimension_positions = self.extract_dimension_positions()
+ self.uninitialised_arrays = get_uninitialised_arrays(self.limited_area)
- dim_positions = []
+ def extract_dimension_positions(self) -> Sequence[DimensionPosition]:
+ """Extract a unique set of dimension positions which are used to infer dimension sizes at runtime."""
+ dim_positions: list[DimensionPosition] = []
+ unique_size_args: set[str] = set()
for arg in self.args:
for index, size_arg in enumerate(arg.size_args):
- dim_positions.append(
- DimensionPosition(variable=str(arg.name), size_arg=size_arg, index=index + 1)
- ) # Use Fortran indexing
-
- self.dimension_size_declarations = dim_positions
+ if size_arg not in unique_size_args:
+ dim_positions.append(
+ DimensionPosition(
+ variable=str(arg.name), size_arg=size_arg, index=index + 1
+ )
+ ) # Use Fortran indexing
+ unique_size_args.add(size_arg)
+ return dim_positions
class F90Interface(Node):
cffi_plugin: CffiPlugin
+ limited_area: bool
function_declaration: list[F90FunctionDeclaration] = datamodels.field(init=False)
function_definition: list[F90FunctionDefinition] = datamodels.field(init=False)
@@ -381,7 +415,12 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
for f in functions
]
self.function_definition = [
- F90FunctionDefinition(name=f.name, args=f.args, is_gt4py_program=f.is_gt4py_program)
+ F90FunctionDefinition(
+ name=f.name,
+ args=f.args,
+ is_gt4py_program=f.is_gt4py_program,
+ limited_area=self.limited_area,
+ )
for f in functions
]
@@ -435,18 +474,22 @@ def visit_F90FunctionDefinition(self, func: F90FunctionDefinition, **kwargs):
arg_names = ", &\n ".join(map(lambda x: x.name, func.args))
param_names_with_size_args = arg_names + ",&\n" + ", &\n".join(func.global_size_args)
+ return_code_param = ",&\nrc" if len(func.args) >= 1 else "rc"
+
return self.generic_visit(
func,
assumed_size_array=False,
param_names=arg_names,
param_names_with_size_args=param_names_with_size_args,
- arrays=[arg for arg in func.args if arg.is_array],
+ # Build an ordered list rather than a set: set iteration order depends on string
+ # hashing, which would make the emitted `!$ACC host_data` clause vary between runs.
+ arrays=[
+ arg.name
+ for arg in func.args
+ if arg.is_array and arg.name not in func.uninitialised_arrays
+ ],
+ return_code_param=return_code_param,
)
- # todo(samkellerhals): Consider using unique SIZE args
F90FunctionDefinition = as_jinja(
"""
-subroutine {{name}}({{param_names}}, &\nrc)
+subroutine {{name}}({{param_names}} {{ return_code_param }})
use, intrinsic :: iso_c_binding
{% for size_arg in global_size_args %}
integer(c_int) :: {{ size_arg }}
@@ -456,19 +499,23 @@ def visit_F90FunctionDefinition(self, func: F90FunctionDefinition, **kwargs):
{% endfor %}
integer(c_int) :: rc ! Stores the return code
+ {% if arrays | length >= 1 %}
!$ACC host_data use_device( &
{%- for arr in arrays %}
- !$ACC {{ arr.name }}{% if not loop.last %}, &{% else %} &{% endif %}
+ !$ACC {{ arr }}{% if not loop.last %}, &{% else %} &{% endif %}
{%- endfor %}
!$ACC )
+ {% endif %}
- {% for d in _this_node.dimension_size_declarations %}
+ {% for d in _this_node.dimension_positions %}
{{ d.size_arg }} = SIZE({{ d.variable }}, {{ d.index }})
{% endfor %}
rc = {{ name }}_wrapper({{ param_names_with_size_args }})
+ {% if arrays | length >= 1 %}
!$acc end host_data
+ {% endif %}
end subroutine {{name}}
"""
)
diff --git a/tools/src/icon4pytools/py2fgen/utils.py b/tools/src/icon4pytools/py2fgen/utils.py
index 0313c04aa6..3fad1c34b8 100644
--- a/tools/src/icon4pytools/py2fgen/utils.py
+++ b/tools/src/icon4pytools/py2fgen/utils.py
@@ -31,14 +31,14 @@ def flatten_and_get_unique_elts(list_of_lists: list[list[str]]) -> list[str]:
return sorted(set(item for sublist in list_of_lists for item in sublist))
-def get_local_test_grid():
+def get_local_test_grid(grid_folder: str):
test_folder = "testdata"
module_spec = importlib.util.find_spec("icon4pytools")
if module_spec and module_spec.origin:
# following namespace package conventions the root is three levels down
repo_root = Path(module_spec.origin).parents[3]
- return os.path.join(repo_root, test_folder)
+ return os.path.join(repo_root, test_folder, "grids", grid_folder)
else:
raise FileNotFoundError(
"The `icon4pytools` package could not be found. Ensure the package is installed "
@@ -52,7 +52,9 @@ def get_icon_grid_loc():
if env_path is not None:
return env_path
else:
- return get_local_test_grid()
+ raise ValueError(
+ "Need to define ICON_GRID_LOC environment variable specifying absolute path to folder containing grid."
+ )
def get_grid_filename():
diff --git a/tools/src/icon4pytools/py2fgen/wrappers/diffusion.py b/tools/src/icon4pytools/py2fgen/wrappers/diffusion.py
index b4c7fd7391..d264b17b31 100644
--- a/tools/src/icon4pytools/py2fgen/wrappers/diffusion.py
+++ b/tools/src/icon4pytools/py2fgen/wrappers/diffusion.py
@@ -20,6 +20,9 @@
- all arguments needed from external sources are passed.
- passing of scalar types or fields of simple types
"""
+import cProfile
+import os
+import pstats
from gt4py.next.common import Field
from gt4py.next.ffront.fbuiltins import float64, int32
@@ -52,9 +55,9 @@
)
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams
from icon4py.model.common.grid.vertical import VerticalModelParams
-from icon4py.model.common.settings import device
+from icon4py.model.common.settings import device, limited_area
from icon4py.model.common.states.prognostic_state import PrognosticState
-from icon4py.model.common.test_utils.grid_utils import _load_from_gridfile
+from icon4py.model.common.test_utils.grid_utils import load_grid_from_file
from icon4py.model.common.test_utils.helpers import as_1D_sparse_field, flatten_first_two_dims
from icon4pytools.common.logger import setup_logger
@@ -66,6 +69,19 @@
# global diffusion object
diffusion_granule: Diffusion = Diffusion()
+# global profiler object
+profiler = cProfile.Profile()
+
+
+def profile_enable():
+ """Start collecting profiling data with the module-level cProfile profiler."""
+ profiler.enable()
+
+
+def profile_disable():
+ """Stop profiling and dump the collected stats to `<__name__>.profile`."""
+ profiler.disable()
+ stats = pstats.Stats(profiler)
+ # The file name is relative, so it is written to the current working directory.
+ stats.dump_stats(f"{__name__}.profile")
+
def diffusion_init(
vct_a: Field[[KHalfDim], float64],
@@ -98,6 +114,11 @@ def diffusion_init(
hdiff_efdt_ratio: float64,
smagorinski_scaling_factor: float64,
hdiff_temp: bool,
+ thslp_zdiffu: float,
+ thhgtd_zdiffu: float,
+ denom_diffu_v: float,
+ nudge_max_coeff: float,
+ itype_sher: int32,
tangent_orientation: Field[[EdgeDim], float64],
inverse_primal_edge_lengths: Field[[EdgeDim], float64],
inv_dual_edge_length: Field[[EdgeDim], float64],
@@ -122,12 +143,13 @@ def diffusion_init(
else:
on_gpu = False
- icon_grid = _load_from_gridfile(
- file_path=get_icon_grid_loc(),
- filename=get_grid_filename(),
+ grid_file_path = os.path.join(get_icon_grid_loc(), get_grid_filename())
+
+ icon_grid = load_grid_from_file(
+ grid_file=grid_file_path,
num_levels=num_levels,
on_gpu=on_gpu,
- limited_area=True,
+ limited_area=True if limited_area else False,
)
# Edge geometry
@@ -163,11 +185,11 @@ def diffusion_init(
smagorinski_scaling_factor=smagorinski_scaling_factor,
hdiff_temp=hdiff_temp,
n_substeps=ndyn_substeps,
- thslp_zdiffu=0.02,
- thhgtd_zdiffu=125.0,
- velocity_boundary_diffusion_denom=150.0,
- max_nudging_coeff=0.075,
- shear_type=TurbulenceShearForcingType.VERTICAL_HORIZONTAL_OF_HORIZONTAL_VERTICAL_WIND,
+ thslp_zdiffu=thslp_zdiffu,
+ thhgtd_zdiffu=thhgtd_zdiffu,
+ velocity_boundary_diffusion_denom=denom_diffu_v,
+ max_nudging_coeff=nudge_max_coeff,
+ shear_type=TurbulenceShearForcingType(itype_sher),
)
diffusion_params = DiffusionParams(config)
diff --git a/tools/src/icon4pytools/py2fgen/wrappers/experiments.py b/tools/src/icon4pytools/py2fgen/wrappers/experiments.py
new file mode 100644
index 0000000000..ab78122b62
--- /dev/null
+++ b/tools/src/icon4pytools/py2fgen/wrappers/experiments.py
@@ -0,0 +1,25 @@
+# ICON4Py - ICON inspired code in Python and GT4Py
+#
+# Copyright (c) 2022, ETH Zurich and MeteoSwiss
+# All rights reserved.
+#
+# This file is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# These arrays are neither initialised nor used in global experiments (e.g. ape_r02b04),
+# so unpacking them must be skipped; attempting to unpack them would trigger an error.
+UNINITIALISED_ARRAYS = [
+ "mask_hdiff",
+ "zd_diffcoef",
+ "zd_vertoffset",
+ "zd_intcoef",
+ "hdef_ic",
+ "div_ic",
+ "dwdx",
+ "dwdy",
+]
diff --git a/tools/src/icon4pytools/py2fgen/wrappers/simple.py b/tools/src/icon4pytools/py2fgen/wrappers/simple.py
index f7e072ff02..afe53e5fd3 100644
--- a/tools/src/icon4pytools/py2fgen/wrappers/simple.py
+++ b/tools/src/icon4pytools/py2fgen/wrappers/simple.py
@@ -11,13 +11,35 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
# mypy: ignore-errors
+import cProfile
+import pstats
+
from gt4py.next.common import GridType
from gt4py.next.ffront.decorator import field_operator, program
-from gt4py.next.ffront.fbuiltins import Field, float64, int32
-from icon4py.model.common.dimension import CellDim, EdgeDim, KDim
+from gt4py.next.ffront.fbuiltins import Field, float64, int32, neighbor_sum
+from icon4py.model.common.caching import CachedProgram
+from icon4py.model.common.dimension import C2CE, C2E, C2EDim, CEDim, CellDim, EdgeDim, KDim
+from icon4py.model.common.grid.simple import SimpleGrid
+from icon4py.model.common.settings import backend
from icon4py.model.common.type_alias import wpfloat
+# global profiler object
+profiler = cProfile.Profile()
+
+grid = SimpleGrid()
+
+
+def profile_enable():
+ """Start collecting profiling data with the module-level cProfile profiler."""
+ profiler.enable()
+
+
+def profile_disable():
+ """Stop profiling and dump the collected stats to `<__name__>.profile`."""
+ profiler.disable()
+ stats = pstats.Stats(profiler)
+ # The file name is relative, so it is written to the current working directory.
+ stats.dump_stats(f"{__name__}.profile")
+
+
@field_operator
def _square(
inp: Field[[CellDim, KDim], float64],
@@ -25,7 +47,7 @@ def _square(
return inp**2
-@program(grid_type=GridType.UNSTRUCTURED)
+@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
def square(
inp: Field[[CellDim, KDim], float64],
result: Field[[CellDim, KDim], float64],
@@ -46,20 +68,25 @@ def _multi_return(
mass_fl_e: Field[[EdgeDim, KDim], wpfloat],
vn_traj: Field[[EdgeDim, KDim], wpfloat],
mass_flx_me: Field[[EdgeDim, KDim], wpfloat],
+ geofac_div: Field[[CEDim], wpfloat],
+ z_nabla2_e: Field[[EdgeDim, KDim], wpfloat],
r_nsubsteps: wpfloat,
) -> tuple[Field[[EdgeDim, KDim], wpfloat], Field[[EdgeDim, KDim], wpfloat]]:
"""accumulate_prep_adv_fields stencil formerly known as _mo_solve_nonhydro_stencil_34."""
vn_traj_wp = vn_traj + r_nsubsteps * z_vn_avg
mass_flx_me_wp = mass_flx_me + r_nsubsteps * mass_fl_e
+ z_temp_wp = neighbor_sum(z_nabla2_e(C2E) * geofac_div(C2CE), axis=C2EDim) # noqa: F841
return vn_traj_wp, mass_flx_me_wp
-@program(grid_type=GridType.UNSTRUCTURED)
+@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
def multi_return(
z_vn_avg: Field[[EdgeDim, KDim], wpfloat],
mass_fl_e: Field[[EdgeDim, KDim], wpfloat],
vn_traj: Field[[EdgeDim, KDim], wpfloat],
mass_flx_me: Field[[EdgeDim, KDim], wpfloat],
+ geofac_div: Field[[CEDim], wpfloat],
+ z_nabla2_e: Field[[EdgeDim, KDim], wpfloat],
r_nsubsteps: wpfloat,
horizontal_start: int32,
horizontal_end: int32,
@@ -71,6 +98,8 @@ def multi_return(
mass_fl_e,
vn_traj,
mass_flx_me,
+ geofac_div,
+ z_nabla2_e,
r_nsubsteps,
out=(vn_traj, mass_flx_me),
domain={
@@ -85,3 +114,35 @@ def square_error(
result: Field[[CellDim, KDim], float64],
):
raise Exception("Exception foo occurred")
+
+
+multi_return_cached = CachedProgram(multi_return)
+
+
+def multi_return_from_function(
+ z_vn_avg: Field[[EdgeDim, KDim], wpfloat],
+ mass_fl_e: Field[[EdgeDim, KDim], wpfloat],
+ vn_traj: Field[[EdgeDim, KDim], wpfloat],
+ mass_flx_me: Field[[EdgeDim, KDim], wpfloat],
+ geofac_div: Field[[CEDim], wpfloat],
+ z_nabla2_e: Field[[EdgeDim, KDim], wpfloat],
+ r_nsubsteps: wpfloat,
+ horizontal_start: int32,
+ horizontal_end: int32,
+ vertical_start: int32,
+ vertical_end: int32,
+):
+ """Run the cached `multi_return` program with the module-level grid's offset providers.
+
+ All arguments are forwarded unchanged to `multi_return_cached`; nothing is returned.
+ """
+ multi_return_cached(
+ z_vn_avg,
+ mass_fl_e,
+ vn_traj,
+ mass_flx_me,
+ geofac_div,
+ z_nabla2_e,
+ r_nsubsteps,
+ horizontal_start,
+ horizontal_end,
+ vertical_start,
+ vertical_end,
+ offset_provider=grid.offset_providers,
+ )
diff --git a/tools/tests/py2fgen/fortran_samples/test_diffusion.f90 b/tools/tests/py2fgen/fortran_samples/test_diffusion.f90
index 460de43512..fbb3685375 100644
--- a/tools/tests/py2fgen/fortran_samples/test_diffusion.f90
+++ b/tools/tests/py2fgen/fortran_samples/test_diffusion.f90
@@ -98,6 +98,7 @@ program diffusion_simulation
implicit none
integer(c_int) :: rc
+ integer(c_int) :: n
! Constants and types
integer(c_int), parameter :: num_cells = 20480
@@ -125,7 +126,7 @@ program diffusion_simulation
real(c_double), parameter :: hdiff_efdt_ratio = 24.0
real(c_double), parameter :: smagorinski_scaling_factor = 0.025
logical(c_int), parameter :: hdiff_temp = .true.
- logical(c_int), parameter :: linit = .true.
+ logical(c_int), parameter :: linit = .false.
! Declaring arrays for diffusion_init and diffusion_run
real(c_double), dimension(:), allocatable :: vct_a
@@ -298,8 +299,12 @@ program diffusion_simulation
call exit(1)
end if
+ do n = 1, 60
! Call diffusion_run
+ call profile_enable(rc)
call diffusion_run(w, vn, exner, theta_v, rho, hdef_ic, div_ic, dwdx, dwdy, dtime, linit, rc)
+ call profile_disable(rc)
+ end do
print *, "Python exit code = ", rc
if (rc /= 0) then
diff --git a/tools/tests/py2fgen/fortran_samples/test_multi_return.f90 b/tools/tests/py2fgen/fortran_samples/test_multi_return.f90
index 86c13b7495..f0ae430bac 100644
--- a/tools/tests/py2fgen/fortran_samples/test_multi_return.f90
+++ b/tools/tests/py2fgen/fortran_samples/test_multi_return.f90
@@ -1,27 +1,33 @@
program call_multi_return_cffi_plugin
use, intrinsic :: iso_c_binding
- use multi_return_plugin
+ use multi_return_from_function_plugin
implicit none
- integer(c_int) :: edim, kdim, i, j, horizontal_start, horizontal_end, vertical_start, vertical_end, rc
+ integer(c_int) :: cdim, edim, kdim, cedim, i, j, horizontal_start, horizontal_end, vertical_start, vertical_end, rc, n
logical :: computation_correct
character(len=100) :: str_buffer
real(c_double) :: r_nsubsteps
- real(c_double), dimension(:, :), allocatable :: z_vn_avg, mass_fl_e, vn_traj, mass_flx_me
+ real(c_double), dimension(:, :), allocatable :: z_vn_avg, mass_fl_e, vn_traj, mass_flx_me, z_nabla2_e
+ real(c_double), dimension(:), allocatable :: geofac_div
! array dimensions
+ cdim = 18
edim = 27
kdim = 10
+ cedim = cdim * edim
!$ACC enter data create(z_vn_avg, mass_fl_e, vn_traj, mass_flx_me)
! allocate arrays (allocate in column-major order)
+ allocate(geofac_div(cedim))
+ allocate(z_nabla2_e(edim, kdim))
allocate (z_vn_avg(edim, kdim))
allocate (mass_fl_e(edim, kdim))
allocate (vn_traj(edim, kdim))
allocate (mass_flx_me(edim, kdim))
! initialize arrays and variables
+ geofac_div = 3.5d0
+ ! z_nabla2_e is passed to the plugin on every call, so it must hold defined values
+ z_nabla2_e = 0.5d0
z_vn_avg = 1.0d0
mass_fl_e = 2.0d0
vn_traj = 3.0d0
@@ -46,9 +52,32 @@ program call_multi_return_cffi_plugin
print *, trim(str_buffer)
print *
+ ! call once just so that we compile the code once (profiling becomes easier later)
+ call multi_return_from_function(z_vn_avg, mass_fl_e, vn_traj, mass_flx_me, geofac_div, z_nabla2_e, r_nsubsteps, &
+ horizontal_start, horizontal_end, vertical_start, vertical_end, rc)
+
! call the cffi plugin
- call multi_return(z_vn_avg, mass_fl_e, vn_traj, mass_flx_me, r_nsubsteps, &
+ call profile_enable(rc)
+ do n = 1, 1000
+ call multi_return_from_function(z_vn_avg, mass_fl_e, vn_traj, mass_flx_me, geofac_div, z_nabla2_e, r_nsubsteps, &
horizontal_start, horizontal_end, vertical_start, vertical_end, rc)
+
+ ! print array shapes and values
+ print *, "Arrays after computation:"
+ print *, "First value of vn_traj:", vn_traj(1, 1)
+ print *, "First value of mass_flx_me:", mass_flx_me(1, 1)
+ write (str_buffer, '("Shape of z_vn_avg = ", I2, ",", I2)') size(z_vn_avg, 1), size(z_vn_avg, 2)
+ print *, trim(str_buffer)
+ write (str_buffer, '("Shape of mass_fl_e = ", I2, ",", I2)') size(mass_fl_e, 1), size(mass_fl_e, 2)
+ print *, trim(str_buffer)
+ write (str_buffer, '("Shape of vn_traj = ", I2, ",", I2)') size(vn_traj, 1), size(vn_traj, 2)
+ print *, trim(str_buffer)
+ write (str_buffer, '("Shape of mass_flx_me = ", I2, ",", I2)') size(mass_flx_me, 1), size(mass_flx_me, 2)
+ print *, trim(str_buffer)
+ print *, "passed"
+
+ end do
+ call profile_disable(rc)
print *, "Python exit code = ", rc
if (rc /= 0) then
call exit(1)
@@ -56,29 +85,6 @@ program call_multi_return_cffi_plugin
!$ACC update host(z_vn_avg, mass_fl_e, vn_traj, mass_flx_me)
- ! print array shapes and values before computation
- print *, "Arrays after computation:"
- write (str_buffer, '("Shape of z_vn_avg = ", I2, ",", I2)') size(z_vn_avg, 1), size(z_vn_avg, 2)
- print *, trim(str_buffer)
- write (str_buffer, '("Shape of mass_fl_e = ", I2, ",", I2)') size(mass_fl_e, 1), size(mass_fl_e, 2)
- print *, trim(str_buffer)
- write (str_buffer, '("Shape of vn_traj = ", I2, ",", I2)') size(vn_traj, 1), size(vn_traj, 2)
- print *, trim(str_buffer)
- write (str_buffer, '("Shape of mass_flx_me = ", I2, ",", I2)') size(mass_flx_me, 1), size(mass_flx_me, 2)
- print *, trim(str_buffer)
- print *
-
- ! Assert vn_traj == 12 and mass_flx_me == 22
- computation_correct = .true.
- do i = 1, edim
- do j = 1, kdim
- if (vn_traj(i, j) /= 12.0d0 .or. mass_flx_me(i, j) /= 22.0d0) then
- computation_correct = .false.
- exit
- end if
- end do
- if (.not. computation_correct) exit
- end do
!$ACC end data
!$ACC exit data delete(z_vn_avg, mass_fl_e, vn_traj, mass_flx_me)
@@ -86,11 +92,4 @@ program call_multi_return_cffi_plugin
! deallocate arrays
deallocate (z_vn_avg, mass_fl_e, vn_traj, mass_flx_me)
- ! Check and print the result of the assertion
- if (computation_correct) then
- print *, "passed: vn_traj and mass_flx_me have expected values."
- else
- print *, "failed: vn_traj or mass_flx_me does not have the expected values."
- stop 1
- end if
end program call_multi_return_cffi_plugin
diff --git a/tools/tests/py2fgen/fortran_samples/test_square.f90 b/tools/tests/py2fgen/fortran_samples/test_square.f90
index 24bebd0ea0..29098bcdf4 100644
--- a/tools/tests/py2fgen/fortran_samples/test_square.f90
+++ b/tools/tests/py2fgen/fortran_samples/test_square.f90
@@ -39,6 +39,10 @@ program call_square_wrapper_cffi_plugin
call square_from_function(input, result, rc)
#elif USE_SQUARE_ERROR
call square_error(input, result, rc)
+#elif PROFILE_SQUARE_FROM_FUNCTION
+ call profile_enable(rc)
+ call square_from_function(input, result, rc)
+ call profile_disable(rc)
#else
call square(input, result, rc)
#endif
diff --git a/tools/tests/py2fgen/test_cli.py b/tools/tests/py2fgen/test_cli.py
index ce51e8eb30..2111fd083e 100644
--- a/tools/tests/py2fgen/test_cli.py
+++ b/tools/tests/py2fgen/test_cli.py
@@ -39,7 +39,7 @@ def run_test_case(
module: str,
function: str,
plugin_name: str,
- backend: str,
+ py2fbackend: str,
samples_path: Path,
fortran_driver: str,
compiler: str = "gfortran",
@@ -47,7 +47,7 @@ def run_test_case(
expected_error_code: int = 0,
):
with cli.isolated_filesystem():
- result = cli.invoke(main, [module, function, plugin_name, "-b", backend, "-d"])
+ result = cli.invoke(main, [module, function, plugin_name, "-b", py2fbackend])
assert result.exit_code == 0, "CLI execution failed"
try:
@@ -102,7 +102,7 @@ def run_fortran_executable(plugin_name: str):
@pytest.mark.parametrize(
- "backend, extra_flags",
+ "py2fbackend, extra_flags",
[
("CPU", ("-DUSE_SQUARE_FROM_FUNCTION",)),
("ROUNDTRIP", ""),
@@ -111,7 +111,7 @@ def run_fortran_executable(plugin_name: str):
],
)
def test_py2fgen_compilation_and_execution_square_cpu(
- cli_runner, backend, samples_path, wrapper_module, extra_flags
+ cli_runner, py2fbackend, samples_path, wrapper_module, extra_flags
):
"""Tests embedding Python functions, and GT4Py program directly.
Also tests embedding multiple functions in one shared library.
@@ -121,7 +121,7 @@ def test_py2fgen_compilation_and_execution_square_cpu(
wrapper_module,
"square,square_from_function",
"square_plugin",
- backend,
+ py2fbackend,
samples_path,
"test_square",
extra_compiler_flags=extra_flags,
@@ -143,17 +143,17 @@ def test_py2fgen_python_error_propagation_to_fortran(cli_runner, samples_path, w
)
-@pytest.mark.parametrize("backend", ("CPU", "ROUNDTRIP"))
-def test_py2fgen_compilation_and_execution_multi_return(
- cli_runner, backend, samples_path, wrapper_module
+@pytest.mark.parametrize("py2fbackend", ("CPU",))
+def test_py2fgen_compilation_and_execution_multi_return_profile(
+ cli_runner, py2fbackend, samples_path, wrapper_module
):
"""Tests embedding multi return gt4py program."""
run_test_case(
cli_runner,
wrapper_module,
- "multi_return",
- "multi_return_plugin",
- backend,
+ "multi_return_from_function,profile_enable,profile_disable",
+ "multi_return_from_function_plugin",
+ py2fbackend,
samples_path,
"test_multi_return",
)
@@ -165,7 +165,7 @@ def test_py2fgen_compilation_and_execution_diffusion(cli_runner, samples_path):
run_test_case(
cli_runner,
"icon4pytools.py2fgen.wrappers.diffusion",
- "diffusion_init,diffusion_run",
+ "diffusion_init,diffusion_run,profile_enable,profile_disable",
"diffusion_plugin",
"CPU",
samples_path,
@@ -176,7 +176,7 @@ def test_py2fgen_compilation_and_execution_diffusion(cli_runner, samples_path):
# todo: enable on CI
@pytest.mark.skip("Requires setting various environment variables.")
@pytest.mark.parametrize(
- "function_name, plugin_name, test_name, backend, extra_flags",
+ "function_name, plugin_name, test_name, py2fbackend, extra_flags",
[
("square", "square_plugin", "test_square", "GPU", ("-acc", "-Minfo=acc")),
("multi_return", "multi_return_plugin", "test_multi_return", "GPU", ("-acc", "-Minfo=acc")),
@@ -187,7 +187,7 @@ def test_py2fgen_compilation_and_execution_gpu(
function_name,
plugin_name,
test_name,
- backend,
+ py2fbackend,
samples_path,
wrapper_module,
extra_flags,
@@ -197,7 +197,7 @@ def test_py2fgen_compilation_and_execution_gpu(
wrapper_module,
function_name,
plugin_name,
- backend,
+ py2fbackend,
samples_path,
test_name,
"/opt/nvidia/hpc_sdk/Linux_x86_64/2024/compilers/bin/nvfortran", # Ensure NVFORTRAN_COMPILER is set in your environment variables
@@ -209,11 +209,11 @@ def test_py2fgen_compilation_and_execution_gpu(
# Need to compile using nvfortran, and set CUDACXX path to nvcc cuda compiler. Also need to set ICON_GRID_LOC for path to gridfile, and ICON4PY_BACKEND to determine device at runtime.
@pytest.mark.skip("Requires setting various environment variables.")
@pytest.mark.parametrize(
- "backend, extra_flags",
+ "py2fbackend, extra_flags",
[("GPU", ("-acc", "-Minfo=acc"))],
)
def test_py2fgen_compilation_and_execution_diffusion_gpu(
- cli_runner, samples_path, backend, extra_flags
+ cli_runner, samples_path, py2fbackend, extra_flags
):
# todo: requires setting ICON_GRID_LOC
run_test_case(
@@ -221,9 +221,31 @@ def test_py2fgen_compilation_and_execution_diffusion_gpu(
"icon4pytools.py2fgen.wrappers.diffusion",
"diffusion_init,diffusion_run",
"diffusion_plugin",
- backend,
+ py2fbackend,
samples_path,
"test_diffusion",
"/opt/nvidia/hpc_sdk/Linux_x86_64/2024/compilers/bin/nvfortran", # todo: set nvfortran location in base.yml file.
extra_flags,
)
+
+
+@pytest.mark.parametrize(
+ "py2fbackend, extra_flags",
+ [
+ ("CPU", ("-DPROFILE_SQUARE_FROM_FUNCTION",)),
+ ],
+)
+def test_py2fgen_compilation_and_profiling(
+ cli_runner, py2fbackend, samples_path, wrapper_module, extra_flags
+):
+ """Test profiling using cProfile of the generated wrapper."""
+ run_test_case(
+ cli_runner,
+ wrapper_module,
+ "square_from_function,profile_enable,profile_disable",
+ "square_plugin",
+ py2fbackend,
+ samples_path,
+ "test_square",
+ extra_compiler_flags=extra_flags,
+ )
diff --git a/tools/tests/py2fgen/test_codegen.py b/tools/tests/py2fgen/test_codegen.py
index 1ff57f9410..319048a62a 100644
--- a/tools/tests/py2fgen/test_codegen.py
+++ b/tools/tests/py2fgen/test_codegen.py
@@ -118,7 +118,7 @@ def dummy_plugin():
def test_fortran_interface(dummy_plugin):
- interface = generate_f90_interface(dummy_plugin)
+ interface = generate_f90_interface(dummy_plugin, limited_area=True)
expected = """
module libtest_plugin
use, intrinsic :: iso_c_binding
@@ -238,15 +238,17 @@ def test_fortran_interface(dummy_plugin):
def test_python_wrapper(dummy_plugin):
- interface = generate_python_wrapper(dummy_plugin, "GPU", False)
+ interface = generate_python_wrapper(dummy_plugin, "GPU", False, limited_area=True)
expected = '''
# imports for generated wrapper code
import logging
+import math
from libtest_plugin import ffi
import numpy as np
import cupy as cp
from numpy.typing import NDArray
from gt4py.next.iterator.embedded import np_as_located_field
+from icon4py.model.common.settings import xp
# logger setup
log_format = '%(asctime)s.%(msecs)03d - %(levelname)s - %(message)s'
@@ -282,11 +284,10 @@ def unpack_gpu(ptr, *sizes: int):
This array shares the underlying data with the original Fortran code, allowing
modifications made through the array to affect the original data.
"""
-
if not sizes:
raise ValueError("Sizes must be provided to determine the array shape.")
- length = np.prod(sizes)
+ length = math.prod(sizes)
c_type = ffi.getctype(ffi.typeof(ptr).item)
dtype_map = {
@@ -306,7 +307,6 @@ def unpack_gpu(ptr, *sizes: int):
mem = cp.cuda.UnownedMemory(ptr_val, total_size, owner=ptr, device_id=current_device.id)
memptr = cp.cuda.MemoryPointer(mem, 0)
arr = cp.ndarray(shape=sizes, dtype=dtype, memptr=memptr, order="F")
-
return arr
def int_array_to_bool_array(int_array: NDArray) -> NDArray:
diff --git a/tools/tests/py2fgen/test_diffusion_wrapper.py b/tools/tests/py2fgen/test_diffusion_wrapper.py
index cb49e594c1..36b1705385 100644
--- a/tools/tests/py2fgen/test_diffusion_wrapper.py
+++ b/tools/tests/py2fgen/test_diffusion_wrapper.py
@@ -11,7 +11,6 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
# type: ignore
-import pytest
from gt4py.next import as_field
from icon4py.model.atmosphere.diffusion.diffusion import DiffusionType
from icon4py.model.common.dimension import (
@@ -27,30 +26,29 @@
VertexDim,
)
from icon4py.model.common.settings import xp
+from icon4py.model.common.test_utils.grid_utils import MCH_CH_R04B09_LEVELS
from icon4pytools.py2fgen.wrappers.diffusion import diffusion_init, diffusion_run
-# todo(samkellerhals): turn on and off using a marker/option (required ICON_GRID_LOC)
-@pytest.mark.skip
-def test_diffusion_wrapper_py():
+def test_diffusion_wrapper_interface():
# grid parameters
- num_cells = 20480
- num_edges = 30720
- num_vertices = 10242
- num_levels = 60
+ num_cells = 20896
+ num_edges = 31558
+ num_vertices = 10663
+ num_levels = MCH_CH_R04B09_LEVELS
num_c2ec2o = 4
num_v2e = 6
- num_c2e = 2
+ num_c2e = 3
num_e2c2v = 4
num_c2e2c = 3
- num_e2c = 3
+ num_e2c = 2
mean_cell_area = 24907282236.708576
# other configuration parameters
ndyn_substeps = 2
- dtime = 2.0
- rayleigh_damping_height = 50000
+ dtime = 10.0
+ rayleigh_damping_height = 12500.0
nflatlev = 30
nflat_gradp = 59
@@ -64,6 +62,11 @@ def test_diffusion_wrapper_py():
hdiff_efdt_ratio = 24.0
smagorinski_scaling_factor = 0.025
hdiff_temp = True
+ thslp_zdiffu = 0.02
+ thhgtd_zdiffu = 125.0
+ denom_diffu_v = 150.0
+ nudge_max_coeff = 0.075
+ itype_sher = 2 # TurbulenceShearForcingType.VERTICAL_HORIZONTAL_OF_HORIZONTAL_VERTICAL_WIND
# input data - numpy
rng = xp.random.default_rng()
@@ -183,6 +186,11 @@ def test_diffusion_wrapper_py():
hdiff_efdt_ratio=hdiff_efdt_ratio,
smagorinski_scaling_factor=smagorinski_scaling_factor,
hdiff_temp=hdiff_temp,
+ thslp_zdiffu=thslp_zdiffu,
+ thhgtd_zdiffu=thhgtd_zdiffu,
+ denom_diffu_v=denom_diffu_v,
+ nudge_max_coeff=nudge_max_coeff,
+ itype_sher=itype_sher,
tangent_orientation=tangent_orientation,
inverse_primal_edge_lengths=inverse_primal_edge_lengths,
inv_dual_edge_length=inv_dual_edge_length,
diff --git a/tools/tox.ini b/tools/tox.ini
index 0c8b7eb934..ad79aebbda 100644
--- a/tools/tox.ini
+++ b/tools/tox.ini
@@ -11,6 +11,7 @@ skipsdist = true
passenv =
PIP_USER
PYTHONUSERBASE
+ ICON_GRID_LOC
deps =
-r {toxinidir}/requirements-dev.txt
commands =
@@ -25,7 +26,6 @@ allowlist_externals =
rm
[testenv:dev]
-basepython = python3.10
setenv =
PIP_SRC = _external_src
skip_install = true