Skip to content

Commit

Permalink
Push index resolution to init (#539)
Browse files Browse the repository at this point in the history
* Diffusion: move start/end index lookup to init

* Velocity advection: move start/end index lookup to init

* Solve non hydro: move start/end index lookup to init

* pre-commit
  • Loading branch information
halungge authored Sep 5, 2024
1 parent 2e47d34 commit c33682e
Show file tree
Hide file tree
Showing 4 changed files with 349 additions and 403 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -378,14 +378,6 @@ def init(

self._allocate_temporary_fields()

def _get_start_index_for_w_diffusion() -> gtx.int32:
cell_domain = h_grid.domain(dims.CellDim)
return (
self.grid.start_index(cell_domain(h_grid.Zone.NUDGING))
if self.grid.limited_area
else self.grid.start_index(cell_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_4))
)

self.nudgezone_diff: float = 0.04 / (params.scaled_nudge_max_coeff + sys.float_info.epsilon)
self.bdy_diff: float = 0.015 / (params.scaled_nudge_max_coeff + sys.float_info.epsilon)
self.fac_bdydiff_v: float = (
Expand Down Expand Up @@ -414,7 +406,8 @@ def _get_start_index_for_w_diffusion() -> gtx.int32:
physical_heights=self.vertical_grid.interface_physical_height,
nrdmax=self.vertical_grid.end_index_of_damping_layer,
)
self._horizontal_start_index_w_diffusion = _get_start_index_for_w_diffusion()
self._determine_horizontal_domains()

self._initialized = True

@property
Expand Down Expand Up @@ -444,6 +437,41 @@ def _allocate_temporary_fields(self):
xp.zeros((self.grid.num_cells, self.grid.num_levels + 1), dtype=float),
)

def _determine_horizontal_domains(self):
cell_domain = h_grid.domain(dims.CellDim)
edge_domain = h_grid.domain(dims.EdgeDim)
vertex_domain = h_grid.domain(dims.VertexDim)

def _get_start_index_for_w_diffusion() -> gtx.int32:
return (
self.grid.start_index(cell_domain(h_grid.Zone.NUDGING))
if self.grid.limited_area
else self.grid.start_index(cell_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_4))
)

self._cell_start_interior = self.grid.start_index(cell_domain(h_grid.Zone.INTERIOR))
self._cell_start_nudging = self.grid.start_index(cell_domain(h_grid.Zone.NUDGING))
self._cell_end_local = self.grid.end_index(cell_domain(h_grid.Zone.LOCAL))
self._cell_end_halo = self.grid.end_index(cell_domain(h_grid.Zone.HALO))

self._edge_start_lateral_boundary_level_5 = self.grid.start_index(
edge_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_5)
)
self._edge_start_nudging = self.grid.start_index(edge_domain(h_grid.Zone.NUDGING))
self._edge_start_nudging_level_2 = self.grid.start_index(
edge_domain(h_grid.Zone.NUDGING_LEVEL_2)
)
self._edge_end_local = self.grid.end_index(edge_domain(h_grid.Zone.LOCAL))
self._edge_end_halo = self.grid.end_index(edge_domain(h_grid.Zone.HALO))
self._edge_end_halo_level_2 = self.grid.end_index(edge_domain(h_grid.Zone.HALO_LEVEL_2))

self._vertex_start_lateral_boundary_level_2 = self.grid.start_index(
vertex_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_2)
)
self._vertex_end_local = self.grid.end_index(vertex_domain(h_grid.Zone.LOCAL))

self._horizontal_start_index_w_diffusion = _get_start_index_for_w_diffusion()

def initial_run(
self,
diagnostic_state: diffusion_states.DiffusionDiagnosticState,
Expand Down Expand Up @@ -540,30 +568,6 @@ def _do_diffusion_step(
"""
num_levels = self.grid.num_levels
cell_domain = h_grid.domain(dims.CellDim)
cell_start_interior = self.grid.start_index(cell_domain(h_grid.Zone.INTERIOR))

cell_start_nudging = self.grid.start_index(cell_domain(h_grid.Zone.NUDGING))
cell_end_local = self.grid.end_index(cell_domain(h_grid.Zone.LOCAL))
cell_end_halo = self.grid.end_index(cell_domain(h_grid.Zone.HALO))

edge_domain = h_grid.domain(dims.EdgeDim)
edge_start_nudging_level_2 = self.grid.start_index(edge_domain(h_grid.Zone.NUDGING_LEVEL_2))
edge_start_nudging = self.grid.start_index(edge_domain(h_grid.Zone.NUDGING))

edge_start_lateral_boundary_level_5 = self.grid.start_index(
edge_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_5)
)
edge_end_local = self.grid.end_index(edge_domain(h_grid.Zone.LOCAL))
edge_end_halo_level_2 = self.grid.end_index(edge_domain(h_grid.Zone.HALO_LEVEL_2))
edge_end_halo = self.grid.end_index(edge_domain(h_grid.Zone.HALO))

vertex_domain = h_grid.domain(dims.VertexDim)
vertex_start_lateral_boundary_level_2 = self.grid.start_index(
vertex_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_2)
)
vertex_end_local = self.grid.end_index(vertex_domain(h_grid.Zone.LOCAL))

# dtime dependent: enh_smag_factor,
cached.scale_k(self.enh_smag_fac, dtime, self.diff_multfac_smag, offset_provider={})

Expand All @@ -574,8 +578,8 @@ def _do_diffusion_step(
ptr_coeff_2=self.interpolation_state.rbf_coeff_2,
p_u_out=self.u_vert,
p_v_out=self.v_vert,
horizontal_start=vertex_start_lateral_boundary_level_2,
horizontal_end=vertex_end_local,
horizontal_start=self._vertex_start_lateral_boundary_level_2,
horizontal_end=self._vertex_end_local,
vertical_start=0,
vertical_end=num_levels,
offset_provider=self.grid.offset_providers,
Expand Down Expand Up @@ -605,8 +609,8 @@ def _do_diffusion_step(
kh_smag_ec=self.kh_smag_ec,
z_nabla2_e=self.z_nabla2_e,
smag_offset=smag_offset,
horizontal_start=edge_start_lateral_boundary_level_5,
horizontal_end=edge_end_halo_level_2,
horizontal_start=self._edge_start_lateral_boundary_level_5,
horizontal_end=self._edge_end_halo_level_2,
vertical_start=0,
vertical_end=num_levels,
offset_provider=self.grid.offset_providers,
Expand All @@ -628,8 +632,8 @@ def _do_diffusion_step(
wgtfac_c=self.metric_state.wgtfac_c,
div_ic=diagnostic_state.div_ic,
hdef_ic=diagnostic_state.hdef_ic,
horizontal_start=cell_start_nudging,
horizontal_end=cell_end_local,
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=1,
vertical_end=num_levels,
offset_provider=self.grid.offset_providers,
Expand All @@ -652,8 +656,8 @@ def _do_diffusion_step(
ptr_coeff_2=self.interpolation_state.rbf_coeff_2,
p_u_out=self.u_vert,
p_v_out=self.v_vert,
horizontal_start=vertex_start_lateral_boundary_level_2,
horizontal_end=vertex_end_local,
horizontal_start=self._vertex_start_lateral_boundary_level_2,
horizontal_end=self._vertex_end_local,
vertical_start=0,
vertical_end=num_levels,
offset_provider=self.grid.offset_providers,
Expand Down Expand Up @@ -682,10 +686,10 @@ def _do_diffusion_step(
edge=self.horizontal_edge_index,
nudgezone_diff=self.nudgezone_diff,
fac_bdydiff_v=self.fac_bdydiff_v,
start_2nd_nudge_line_idx_e=gtx.int32(edge_start_nudging_level_2),
start_2nd_nudge_line_idx_e=self._edge_start_nudging_level_2,
limited_area=self.grid.limited_area,
horizontal_start=edge_start_lateral_boundary_level_5,
horizontal_end=edge_end_local,
horizontal_start=self._edge_start_lateral_boundary_level_5,
horizontal_end=self._edge_end_local,
vertical_start=0,
vertical_end=num_levels,
offset_provider=self.grid.offset_providers,
Expand Down Expand Up @@ -717,10 +721,10 @@ def _do_diffusion_step(
nrdmax=gtx.int32(
self.vertical_grid.end_index_of_damping_layer + 1
), # +1 since Fortran includes boundaries
interior_idx=gtx.int32(cell_start_interior),
halo_idx=gtx.int32(cell_end_local),
interior_idx=self._cell_start_interior,
halo_idx=self._cell_end_local,
horizontal_start=self._horizontal_start_index_w_diffusion,
horizontal_end=cell_end_halo,
horizontal_end=self._cell_end_halo,
vertical_start=0,
vertical_end=num_levels,
offset_provider=self.grid.offset_providers,
Expand All @@ -739,8 +743,8 @@ def _do_diffusion_step(
thresh_tdiff=self.thresh_tdiff,
smallest_vpfloat=constants.DBL_EPS,
kh_smag_e=self.kh_smag_e,
horizontal_start=edge_start_nudging,
horizontal_end=edge_end_halo,
horizontal_start=self._edge_start_nudging,
horizontal_end=self._edge_end_halo,
vertical_start=(num_levels - 2),
vertical_end=num_levels,
offset_provider=self.grid.offset_providers,
Expand All @@ -756,8 +760,8 @@ def _do_diffusion_step(
theta_v=prognostic_state.theta_v,
geofac_div=self.interpolation_state.geofac_div,
z_temp=self.z_temp,
horizontal_start=cell_start_nudging,
horizontal_end=cell_end_local,
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=0,
vertical_end=num_levels,
offset_provider=self.grid.offset_providers,
Expand All @@ -776,8 +780,8 @@ def _do_diffusion_step(
vcoef=self.metric_state.zd_intcoef,
theta_v=prognostic_state.theta_v,
z_temp=self.z_temp,
horizontal_start=cell_start_nudging,
horizontal_end=cell_end_local,
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=0,
vertical_end=num_levels,
offset_provider=self.grid.offset_providers,
Expand All @@ -793,8 +797,8 @@ def _do_diffusion_step(
theta_v=prognostic_state.theta_v,
exner=prognostic_state.exner,
rd_o_cvd=self.rd_o_cvd,
horizontal_start=cell_start_nudging,
horizontal_end=cell_end_local,
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=0,
vertical_end=num_levels,
offset_provider={},
Expand Down
Loading

0 comments on commit c33682e

Please sign in to comment.