Skip to content

Commit

Permalink
Refactor reorder_and_finalize_potentials into wranglers
Browse files Browse the repository at this point in the history
  • Loading branch information
gaohao95 committed Aug 25, 2024
1 parent 625cb3e commit e725b3d
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 48 deletions.
32 changes: 4 additions & 28 deletions pytential/qbx/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
# Execute global QBX.
timing_data: Dict[str, Any] = {}
all_potentials_on_every_target = drive_dfmm(
self.comm, flat_strengths, wrangler, timing_data)
flat_strengths, wrangler, timing_data)

if self.comm.Get_rank() == 0:
assert global_geo_data_device is not None
Expand Down Expand Up @@ -816,11 +816,9 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
return results, timing_data


def drive_dfmm(comm, src_weight_vecs, wrangler, timing_data=None):
def drive_dfmm(src_weight_vecs, wrangler, timing_data=None):
# TODO: Integrate the distributed functionality with `qbx.fmm.drive_fmm`,
# similar to that in `boxtree`.

current_rank = comm.Get_rank()
local_traversal = wrangler.traversal

# {{{ Distribute source weights
Expand Down Expand Up @@ -993,30 +991,8 @@ def drive_dfmm(comm, src_weight_vecs, wrangler, timing_data=None):
non_qbx_potentials = wrangler.gather_non_qbx_potentials(non_qbx_potentials)
qbx_potentials = wrangler.gather_qbx_potentials(qbx_potentials)

if current_rank != 0: # worker process
result = None

else: # master process

all_potentials_in_tree_order = wrangler.full_output_zeros(template_ary)

nqbtl = wrangler.global_geo_data.non_qbx_box_target_lists

for ap_i, nqp_i in zip(
all_potentials_in_tree_order, non_qbx_potentials):
ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i

all_potentials_in_tree_order += qbx_potentials

def reorder_and_finalize_potentials(x):
# "finalize" gives host FMMs (like FMMlib) a chance to turn the
# potential back into a CL array.
return wrangler.finalize_potentials(
x[wrangler.global_traversal.tree.sorted_target_ids], template_ary)

from pytools.obj_array import with_object_array_or_scalar
result = with_object_array_or_scalar(
reorder_and_finalize_potentials, all_potentials_in_tree_order)
result = wrangler.reorder_and_finalize_potentials(
non_qbx_potentials, qbx_potentials, template_ary)

if timing_data is not None:
timing_data.update(recorder.summarize())
Expand Down
44 changes: 24 additions & 20 deletions pytential/qbx/fmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,28 @@ def make_container():

# {{{ FMM top-level

def _reorder_and_finalize_potentials(
wrangler, non_qbx_potentials, qbx_potentials, template_ary):
nqbtl = wrangler.geo_data.non_qbx_box_target_lists()

all_potentials_in_tree_order = wrangler.full_output_zeros(template_ary)

for ap_i, nqp_i in zip(all_potentials_in_tree_order, non_qbx_potentials):
ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i

all_potentials_in_tree_order += qbx_potentials

def reorder_and_finalize_potentials(x):
# "finalize" gives host FMMs (like FMMlib) a chance to turn the
# potential back into a CL array.
return wrangler.finalize_potentials(x[
wrangler.geo_data.traversal().tree.sorted_target_ids], template_ary)

from pytools.obj_array import obj_array_vectorize
return obj_array_vectorize(
reorder_and_finalize_potentials, all_potentials_in_tree_order)


def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None):
"""Top-level driver routine for the QBX fast multipole calculation.
Expand All @@ -416,10 +438,7 @@ def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None):
See also :func:`boxtree.fmm.drive_fmm`.
"""
wrangler = expansion_wrangler

geo_data = wrangler.geo_data
traversal = wrangler.traversal
tree = traversal.tree

template_ary = src_weight_vecs[0]

Expand Down Expand Up @@ -581,23 +600,8 @@ def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None):

# {{{ reorder potentials

nqbtl = geo_data.non_qbx_box_target_lists()

all_potentials_in_tree_order = wrangler.full_output_zeros(template_ary)

for ap_i, nqp_i in zip(all_potentials_in_tree_order, non_qbx_potentials):
ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i

all_potentials_in_tree_order += qbx_potentials

def reorder_and_finalize_potentials(x):
# "finalize" gives host FMMs (like FMMlib) a chance to turn the
# potential back into a CL array.
return wrangler.finalize_potentials(x[tree.sorted_target_ids], template_ary)

from pytools.obj_array import obj_array_vectorize
result = obj_array_vectorize(
reorder_and_finalize_potentials, all_potentials_in_tree_order)
result = _reorder_and_finalize_potentials(
wrangler, non_qbx_potentials, qbx_potentials, template_ary)

# }}}

Expand Down
27 changes: 27 additions & 0 deletions pytential/qbx/fmmlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,4 +752,31 @@ def gather_qbx_potentials(self, qbx_potentials):
ntargets, qbx_potentials,
self.geo_data.qbx_target_mask, self.MPITags["qbx_potentials"])

def reorder_and_finalize_potentials(
self, non_qbx_potentials, qbx_potentials, template_ary):
mpi_rank = self.comm.Get_rank()

if mpi_rank == 0:
all_potentials_in_tree_order = self.full_output_zeros(template_ary)

nqbtl = self.global_geo_data.non_qbx_box_target_lists

for ap_i, nqp_i in zip(
all_potentials_in_tree_order, non_qbx_potentials):
ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i

all_potentials_in_tree_order += qbx_potentials

def _reorder_and_finalize_potentials(x):
# "finalize" gives host FMMs (like FMMlib) a chance to turn the
# potential back into a CL array.
return self.finalize_potentials(
x[self.global_traversal.tree.sorted_target_ids], template_ary)

from pytools.obj_array import with_object_array_or_scalar
return with_object_array_or_scalar(
_reorder_and_finalize_potentials, all_potentials_in_tree_order)
else:
return None

# vim: foldmethod=marker

0 comments on commit e725b3d

Please sign in to comment.