Skip to content

Commit

Permalink
Cache the distributed geometry
Browse files Browse the repository at this point in the history
  • Loading branch information
gaohao95 committed Feb 6, 2024
1 parent c270674 commit 58562f6
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 89 deletions.
183 changes: 94 additions & 89 deletions pytential/qbx/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,98 @@ def compute_local_geometry_data(
qbx_center_to_target_box_source_level))


def distribute_geo_data(comm, actx, insn, bound_expr, evaluate,
                        global_geo_data_device):
    """Return the ``(global_geo_data, local_geo_data)`` pair for *insn*,
    computing and caching it on first use.

    The result is memoized in ``bound_expr._geo_data_cache`` keyed by *insn*;
    a cache hit returns immediately without touching *comm*.

    On a cache miss this function performs collective operations on *comm*
    (``bcast``/``Bcast`` rooted at rank 0, plus whatever
    ``broadcast_global_geometry_data`` and ``compute_local_geometry_data``
    do with the communicator).
    NOTE(review): this presumably requires every rank to call it with the
    same cache state for *insn* -- confirm against callers.

    :arg comm: communicator providing ``Get_rank``, ``bcast`` and ``Bcast``
        (presumably an mpi4py communicator -- confirm).
    :arg actx: array context; its ``queue`` and ``context`` attributes are
        used.
    :arg insn: the instruction being executed. Used as the cache key, and on
        rank 0 its ``kernel_arguments`` and ``target_kernels`` are read.
    :arg bound_expr: object carrying the ``_geo_data_cache`` dictionary.
    :arg evaluate: callable applied to each kernel argument expression of
        *insn*; only invoked on rank 0.
    :arg global_geo_data_device: geometry data passed to the cost model and
        wrapped in ``ToHostTransferredGeoDataWrapper``; only read on rank 0.
    :returns: a tuple ``(global_geo_data, local_geo_data)``.
    """
    geo_data_cache = bound_expr._geo_data_cache

    if insn in geo_data_cache:
        return geo_data_cache[insn]

    # Queried only after the cache check so that a cache hit does not need
    # the communicator at all.
    is_root = comm.Get_rank() == 0

    boxes_time = None
    global_geo_data = None

    if is_root:
        # Use the cost model to estimate execution time for partitioning
        from pytential.qbx.cost import AbstractQBXCostModel, QBXCostModel

        # FIXME: If the expansion wrangler is not FMMLib, the argument
        # 'uses_pde_expansions' might be different
        cost_model = QBXCostModel()

        import warnings
        # The trailing space after "when" matters: adjacent string literals
        # are concatenated verbatim.
        warnings.warn(
            "Kernel-specific calibration parameters are not supplied when "
            "using distributed FMM.")
        # TODO: supply better default calibration parameters
        calibration_params = AbstractQBXCostModel.get_unit_calibration_params()

        kernel_args = {
            arg_name: evaluate(arg_expr)
            for arg_name, arg_expr in insn.kernel_arguments.items()}

        boxes_time, _ = cost_model.qbx_cost_per_box(
            actx.queue, global_geo_data_device, insn.target_kernels[0],
            kernel_args, calibration_params)
        boxes_time = boxes_time.get()

        from pytential.qbx.utils import ToHostTransferredGeoDataWrapper
        global_geo_data = ToHostTransferredGeoDataWrapper(global_geo_data_device)

    # {{{ Construct a traversal builder

    # NOTE: The distributed implementation relies on building the same traversal
    # objects as the one on the root rank. This means here the traversal builder
    # should use the same parameters as `QBXFMMGeometryData.traversal`. To make
    # it consistent across ranks, we broadcast the parameters here.

    trav_param = None
    if is_root:
        root_geo_data = global_geo_data.geo_data
        build_traversal = root_geo_data.code_getter.build_traversal
        trav_param = {
            "well_sep_is_n_away": build_traversal.well_sep_is_n_away,
            "from_sep_smaller_crit": build_traversal.from_sep_smaller_crit,
            "_from_sep_smaller_min_nsources_cumul":
                root_geo_data.lpot_source._from_sep_smaller_min_nsources_cumul}
    trav_param = comm.bcast(trav_param, root=0)

    traversal_builder = QBXFMMGeometryDataTraversalBuilder(
        actx.context,
        well_sep_is_n_away=trav_param["well_sep_is_n_away"],
        from_sep_smaller_crit=trav_param["from_sep_smaller_crit"],
        _from_sep_smaller_min_nsources_cumul=trav_param[
            "_from_sep_smaller_min_nsources_cumul"])

    # }}}

    # {{{ Broadcast the subset of the global geometry data to worker ranks

    global_geo_data = broadcast_global_geometry_data(
        comm, actx, traversal_builder, global_geo_data)

    # }}}

    # {{{ Compute the local geometry data from the global geometry data

    if not is_root:
        # Receive buffer for the per-box cost estimate computed on rank 0.
        boxes_time = np.empty(
            global_geo_data.global_traversal.tree.nboxes, dtype=np.float64)

    comm.Bcast(boxes_time, root=0)

    local_geo_data = compute_local_geometry_data(
        actx, comm, global_geo_data, boxes_time, traversal_builder)

    # }}}

    geo_data_cache[insn] = (global_geo_data, local_geo_data)

    return global_geo_data, local_geo_data


class DistributedQBXLayerPotentialSource(QBXLayerPotentialSource):
def __init__(self, comm, cl_context, *args,
_use_target_specific_qbx: Optional[bool] = None,
Expand Down Expand Up @@ -613,12 +705,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
from pytential.qbx import get_flat_strengths_from_densities
from meshmode.discretization import Discretization

target_name_and_side_to_number = None
target_discrs_and_qbx_sides = None
global_geo_data_device = None
global_geo_data = None
local_geo_data = None
boxes_time = None
output_and_expansion_dtype = None
flat_strengths = None

Expand All @@ -631,90 +718,8 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
insn.source.geometry,
target_discrs_and_qbx_sides)

# Use the cost model to estimate execution time for partitioning
from pytential.qbx.cost import AbstractQBXCostModel, QBXCostModel

# FIXME: If the expansion wrangler is not FMMLib, the argument
# 'uses_pde_expansions' might be different
cost_model = QBXCostModel()

import warnings
warnings.warn(
"Kernel-specific calibration parameters are not supplied when"
"using distributed FMM.")
# TODO: supply better default calibration parameters
calibration_params = AbstractQBXCostModel.get_unit_calibration_params()

kernel_args = {}
for arg_name, arg_expr in insn.kernel_arguments.items():
kernel_args[arg_name] = evaluate(arg_expr)

boxes_time, _ = cost_model.qbx_cost_per_box(
actx.queue, global_geo_data_device, insn.target_kernels[0],
kernel_args, calibration_params)
boxes_time = boxes_time.get()

from pytential.qbx.utils import ToHostTransferredGeoDataWrapper
global_geo_data = ToHostTransferredGeoDataWrapper(global_geo_data_device)

# FIXME Exert more positive control over geo_data attribute lifetimes using
# geo_data.<method>.clear_cache(geo_data).

# FIXME Synthesize "bad centers" around corners and edges that have
# inadequate QBX coverage.

# FIXME don't compute *all* output kernels on all targets--respect that
# some target discretizations may only be asking for derivatives (e.g.)

# {{{ Construct a traversal builder

# NOTE: The distributed implementation relies on building the same traversal
# objects as the one on the root rank. This means here the traversal builder
# should use the same parameters as `QBXFMMGeometryData.traversal`. To make
# it consistent across ranks, we broadcast the parameters here.

trav_param = None
if self.comm.Get_rank() == 0:
trav_param = {
"well_sep_is_n_away":
global_geo_data.geo_data.code_getter.build_traversal
.well_sep_is_n_away,
"from_sep_smaller_crit":
global_geo_data.geo_data.code_getter.build_traversal.
from_sep_smaller_crit,
"_from_sep_smaller_min_nsources_cumul":
global_geo_data.geo_data.lpot_source.
_from_sep_smaller_min_nsources_cumul}
trav_param = self.comm.bcast(trav_param, root=0)

traversal_builder = QBXFMMGeometryDataTraversalBuilder(
actx.context,
well_sep_is_n_away=trav_param["well_sep_is_n_away"],
from_sep_smaller_crit=trav_param["from_sep_smaller_crit"],
_from_sep_smaller_min_nsources_cumul=trav_param[
"_from_sep_smaller_min_nsources_cumul"])

# }}}

# {{{ Broadcast the subset of the global geometry data to worker ranks

global_geo_data = broadcast_global_geometry_data(
self.comm, actx, traversal_builder, global_geo_data)

# }}}

# {{{ Compute the local geometry data from the global geometry data

if self.comm.Get_rank() != 0:
boxes_time = np.empty(
global_geo_data.global_traversal.tree.nboxes, dtype=np.float64)

self.comm.Bcast(boxes_time, root=0)

local_geo_data = compute_local_geometry_data(
actx, self.comm, global_geo_data, boxes_time, traversal_builder)

# }}}
global_geo_data, local_geo_data = distribute_geo_data(
self.comm, actx, insn, bound_expr, evaluate, global_geo_data_device)

tree_indep = self._tree_indep_data_for_wrangler(
target_kernels=insn.target_kernels,
Expand Down
1 change: 1 addition & 0 deletions pytential/symbolic/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,7 @@ class DistributedBoundExpression(BoundExpression):
def __init__(self, comm, places, sym_op_expr):
self.comm = comm
self._code = None
self._geo_data_cache = {}

if self.comm.Get_rank() == 0:
super().__init__(places, sym_op_expr)
Expand Down

0 comments on commit 58562f6

Please sign in to comment.