-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Py2F]: add profiling support & optimisations (#449)
Adds CachedProgram to diffusion stencils, as well as other optimisations for py2f and changes so that APE experiments can be run with py2f.
- Loading branch information
1 parent
c918d52
commit 1752f87
Showing
34 changed files
with
602 additions
and
165 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
79 changes: 79 additions & 0 deletions
79
model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/cached.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# ICON4Py - ICON inspired code in Python and GT4Py | ||
# | ||
# Copyright (c) 2022, ETH Zurich and MeteoSwiss | ||
# All rights reserved. | ||
# | ||
# This file is free software: you can redistribute it and/or modify it under | ||
# the terms of the GNU General Public License as published by the | ||
# Free Software Foundation, either version 3 of the License, or any later | ||
# version. See the LICENSE.txt file at the top-level directory of this | ||
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
|
||
from icon4py.model.atmosphere.diffusion.diffusion_utils import ( | ||
copy_field as copy_field_orig, | ||
init_diffusion_local_fields_for_regular_timestep as init_diffusion_local_fields_for_regular_timestep_orig, | ||
scale_k as scale_k_orig, | ||
setup_fields_for_initial_step as setup_fields_for_initial_step_orig, | ||
) | ||
from icon4py.model.atmosphere.diffusion.stencils.apply_diffusion_to_vn import ( | ||
apply_diffusion_to_vn as apply_diffusion_to_vn_orig, | ||
) | ||
from icon4py.model.atmosphere.diffusion.stencils.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence import ( | ||
apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence as apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence_orig, | ||
) | ||
from icon4py.model.atmosphere.diffusion.stencils.calculate_diagnostic_quantities_for_turbulence import ( | ||
calculate_diagnostic_quantities_for_turbulence as calculate_diagnostic_quantities_for_turbulence_orig, | ||
) | ||
from icon4py.model.atmosphere.diffusion.stencils.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools import ( | ||
calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools as calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools_orig, | ||
) | ||
from icon4py.model.atmosphere.diffusion.stencils.calculate_nabla2_and_smag_coefficients_for_vn import ( | ||
calculate_nabla2_and_smag_coefficients_for_vn as calculate_nabla2_and_smag_coefficients_for_vn_orig, | ||
) | ||
from icon4py.model.atmosphere.diffusion.stencils.calculate_nabla2_for_theta import ( | ||
calculate_nabla2_for_theta as calculate_nabla2_for_theta_orig, | ||
) | ||
from icon4py.model.atmosphere.diffusion.stencils.truly_horizontal_diffusion_nabla_of_theta_over_steep_points import ( | ||
truly_horizontal_diffusion_nabla_of_theta_over_steep_points as truly_horizontal_diffusion_nabla_of_theta_over_steep_points_orig, | ||
) | ||
from icon4py.model.atmosphere.diffusion.stencils.update_theta_and_exner import ( | ||
update_theta_and_exner as update_theta_and_exner_orig, | ||
) | ||
from icon4py.model.common.caching import CachedProgram | ||
from icon4py.model.common.interpolation.stencils.mo_intp_rbf_rbf_vec_interpol_vertex import ( | ||
mo_intp_rbf_rbf_vec_interpol_vertex as mo_intp_rbf_rbf_vec_interpol_vertex_orig, | ||
) | ||
|
||
|
||
# Every stencil imported above under an ``*_orig`` alias is wrapped in a
# ``CachedProgram`` and re-exported here under its original name.

# diffusion run stencils
apply_diffusion_to_vn = CachedProgram(apply_diffusion_to_vn_orig)
apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence = CachedProgram(
    apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence_orig
)
calculate_diagnostic_quantities_for_turbulence = CachedProgram(
    calculate_diagnostic_quantities_for_turbulence_orig
)
calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools = CachedProgram(
    calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools_orig
)
calculate_nabla2_and_smag_coefficients_for_vn = CachedProgram(
    calculate_nabla2_and_smag_coefficients_for_vn_orig
)
calculate_nabla2_for_theta = CachedProgram(calculate_nabla2_for_theta_orig)
truly_horizontal_diffusion_nabla_of_theta_over_steep_points = CachedProgram(
    truly_horizontal_diffusion_nabla_of_theta_over_steep_points_orig
)
update_theta_and_exner = CachedProgram(update_theta_and_exner_orig)

# interpolation stencil (imported from the common interpolation package)
mo_intp_rbf_rbf_vec_interpol_vertex = CachedProgram(mo_intp_rbf_rbf_vec_interpol_vertex_orig)


# model init stencils
# These are the only programs constructed with ``with_domain=False``; see
# ``CachedProgram.__call__`` for how that flag changes the argument list.
setup_fields_for_initial_step = CachedProgram(setup_fields_for_initial_step_orig, with_domain=False)
copy_field = CachedProgram(copy_field_orig, with_domain=False)
init_diffusion_local_fields_for_regular_timestep = CachedProgram(
    init_diffusion_local_fields_for_regular_timestep_orig, with_domain=False
)
scale_k = CachedProgram(scale_k_orig, with_domain=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# ICON4Py - ICON inspired code in Python and GT4Py | ||
# | ||
# Copyright (c) 2022, ETH Zurich and MeteoSwiss | ||
# All rights reserved. | ||
# | ||
# This file is free software: you can redistribute it and/or modify it under | ||
# the terms of the GNU General Public License as published by the | ||
# Free Software Foundation, either version 3 of the License, or any later | ||
# version. See the LICENSE.txt file at the top-level directory of this | ||
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. | ||
# | ||
# SPDX-License-Identifier: GPL-3.0-or-later | ||
|
||
import dataclasses | ||
from typing import Any, Callable, Optional | ||
|
||
import numpy as np | ||
from gt4py import next as gtx | ||
from gt4py.next.otf import workflow | ||
from gt4py.next.program_processors.runners.gtfn import extract_connectivity_args | ||
|
||
from icon4py.model.common.settings import device | ||
|
||
|
||
try: | ||
import cupy as cp | ||
from gt4py.next.embedded.nd_array_field import CuPyArrayField | ||
except ImportError: | ||
cp: Optional = None # type:ignore[no-redef] | ||
|
||
from gt4py.next.embedded.nd_array_field import NumPyArrayField | ||
|
||
|
||
def handle_numpy_integer(value):
    """Convert a NumPy integer scalar to a built-in ``int``.

    Args:
        value: The value to convert; anything accepted by ``int()`` works.

    Returns:
        ``int(value)``.
    """
    return int(value)
|
||
|
||
def handle_common_field(value, sizes):
    """Record a field's shape and pass the field through untouched.

    Args:
        value: Object exposing a ``shape`` attribute (an iterable of extents).
        sizes: List that is extended **in place** with every entry of
            ``value.shape``.

    Returns:
        ``value`` itself, unmodified.
    """
    # The only effect is the side effect on ``sizes``; the caller owns that list.
    sizes.extend(value.shape)
    return value
|
||
|
||
def handle_default(value):
    """Fallback handler: return ``value`` unchanged.

    Used for every argument whose type has no dedicated handler.
    """
    return value
|
||
|
||
# Dispatch table mapping an argument type to the handler that processes it.
type_handlers = {
    np.integer: handle_numpy_integer,
    NumPyArrayField: handle_common_field,
}
if cp:
    # ``CuPyArrayField`` is only imported when the optional ``cupy`` import
    # succeeded (``cp`` is ``None`` otherwise), so register it conditionally.
    type_handlers[CuPyArrayField] = handle_common_field
|
||
|
||
def process_arg(value, sizes):
    """Run one program argument through the handler registered for its type.

    The handler is looked up in the module-level ``type_handlers`` table. An
    exact match on ``type(value)`` is tried first. If there is none, the table
    is scanned with ``isinstance`` so that subclasses of a registered type are
    dispatched as well. This is required for the ``np.integer`` entry:
    ``np.integer`` is an abstract base class, so concrete scalars such as
    ``np.int32`` never match it by exact type and would otherwise fall through
    to ``handle_default`` without being converted.

    Args:
        value: The argument to process.
        sizes: List collecting field shapes. It is forwarded only to
            ``handle_common_field``, which extends it in place.

    Returns:
        Whatever the selected handler returns for ``value``.
    """
    handler = type_handlers.get(type(value))
    if handler is None:
        # No exact-type entry: fall back to a subclass-aware scan.
        handler = next(
            (
                candidate
                for registered_type, candidate in type_handlers.items()
                if isinstance(value, registered_type)
            ),
            handle_default,
        )
    if handler is handle_common_field:
        return handler(value, sizes)
    return handler(value)
|
||
|
||
@dataclasses.dataclass
class CachedProgram:
    """Class to handle caching and compilation of GT4Py programs.

    The wrapped program is compiled on the first call. The compiled callable
    and the connectivity arguments extracted from the offset provider are
    stored on the instance and reused by every later call.

    Attributes:
        program: The GT4Py program to be cached and compiled.
        with_domain: Flag to indicate if the program should be compiled with
            domain information. When ``False``, the shapes of all field
            arguments are appended to the positional arguments on each call.
            Defaults to ``True``.
        _compiled_program: The compiled GT4Py program, ``None`` until the
            first call.
        _conn_args: Connectivity arguments extracted from the offset provider.
        _compiled_args: Arguments used during the compilation of the program.

    Properties:
        compiled_program: Returns the compiled GT4Py program (or ``None``).
        conn_args: Returns the connectivity arguments.

    Note:
        This functionality will be provided by GT4Py in the future.
    """

    program: gtx.ffront.decorator.Program
    with_domain: bool = True
    _compiled_program: Optional[Callable] = None
    _conn_args: Any = None
    _compiled_args: tuple = dataclasses.field(default_factory=tuple)

    @property
    def compiled_program(self) -> Optional[Callable]:
        """Return the compiled program, or ``None`` if nothing was compiled yet."""
        return self._compiled_program

    @property
    def conn_args(self) -> Any:
        """Return the connectivity arguments stored by the first call."""
        return self._conn_args

    def compile_the_program(
        self, *args, offset_provider: dict[str, gtx.Dimension], **kwargs: Any
    ) -> Callable:
        """Compile ``self.program`` for the given call arguments.

        The arguments are run through the backend's program transformation and
        the transformed arguments are stored in ``_compiled_args``.

        Args:
            *args: Positional arguments of the program call.
            offset_provider: Offset provider, passed to the backend as the
                ``offset_provider`` keyword argument.
            **kwargs: Keyword arguments of the program call.

        Returns:
            The result of the backend executor's OTF workflow for the
            transformed program call.
        """
        backend = self.program.backend
        program_call = backend.transforms_prog(
            workflow.InputWithArgs(
                data=self.program.definition_stage,
                args=args,
                kwargs=kwargs | {"offset_provider": offset_provider},
            )
        )
        self._compiled_args = program_call.args
        return backend.executor.otf_workflow(program_call)

    def __call__(self, *args, offset_provider: dict[str, gtx.Dimension], **kwargs: Any) -> None:
        """Execute the program, compiling it first if this is the first call.

        Keyword arguments are passed to the compiled program positionally,
        after ``args`` and in the order they were given.

        Args:
            *args: Positional program arguments.
            offset_provider: Offset provider forwarded to the compiled program.
            **kwargs: Further program arguments; only their values are used.

        Returns:
            Whatever the compiled program returns.
        """
        if self._compiled_program is None:
            # Compilation happens once; later calls reuse the cached callable
            # even when they are made with different arguments.
            self._compiled_program = self.compile_the_program(
                *args, offset_provider=offset_provider, **kwargs
            )
            self._conn_args = extract_connectivity_args(offset_provider, device)

        # Convert numpy integers in args to int and handle gtx.common.Field;
        # ``process_arg`` appends the shape of every field to ``sizes``.
        sizes = []
        program_args = [process_arg(arg, sizes) for arg in (*args, *kwargs.values())]

        if not self.with_domain:
            program_args.extend(sizes)

        # todo(samkellerhals): if we merge gt4py PR we can also pass connectivity args here conn_args=self.conn_args
        return self.compiled_program(*program_args, offset_provider=offset_provider)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.