From f95d4c7e8b844cf95a6f6fb1cf4bcb051eebf214 Mon Sep 17 00:00:00 2001 From: Hao Gao Date: Sun, 4 Aug 2024 17:17:28 -0700 Subject: [PATCH] Refactor reorder_and_finalize_potentials into wranglers --- pytential/qbx/distributed.py | 32 ++++----------------------- pytential/qbx/fmm.py | 43 ++++++++++++++++++++---------------- pytential/qbx/fmmlib.py | 27 ++++++++++++++++++++++ 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/pytential/qbx/distributed.py b/pytential/qbx/distributed.py index d5a541f0e..aebf25d0a 100644 --- a/pytential/qbx/distributed.py +++ b/pytential/qbx/distributed.py @@ -787,7 +787,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext, # Execute global QBX. timing_data: Dict[str, Any] = {} all_potentials_on_every_target = drive_dfmm( - self.comm, flat_strengths, wrangler, timing_data) + flat_strengths, wrangler, timing_data) if self.comm.Get_rank() == 0: assert global_geo_data_device is not None @@ -816,11 +816,9 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext, return results, timing_data -def drive_dfmm(comm, src_weight_vecs, wrangler, timing_data=None): +def drive_dfmm(src_weight_vecs, wrangler, timing_data=None): # TODO: Integrate the distributed functionality with `qbx.fmm.drive_fmm`, # similar to that in `boxtree`. - - current_rank = comm.Get_rank() local_traversal = wrangler.traversal # {{{ Distribute source weights @@ -993,30 +991,8 @@ def drive_dfmm(comm, src_weight_vecs, wrangler, timing_data=None): non_qbx_potentials = wrangler.gather_non_qbx_potentials(non_qbx_potentials) qbx_potentials = wrangler.gather_qbx_potentials(qbx_potentials) - if current_rank != 0: # worker process - result = None - - else: # master process - - all_potentials_in_tree_order = wrangler.full_output_zeros(template_ary) - - nqbtl = wrangler.global_geo_data.non_qbx_box_target_lists - - for ap_i, nqp_i in zip( - all_potentials_in_tree_order, non_qbx_potentials): - ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i - - all_potentials_in_tree_order += qbx_potentials - - def reorder_and_finalize_potentials(x): - # "finalize" gives host FMMs (like FMMlib) a chance to turn the - # potential back into a CL array. - return wrangler.finalize_potentials( - x[wrangler.global_traversal.tree.sorted_target_ids], template_ary) - - from pytools.obj_array import with_object_array_or_scalar - result = with_object_array_or_scalar( - reorder_and_finalize_potentials, all_potentials_in_tree_order) + result = wrangler.reorder_and_finalize_potentials( + non_qbx_potentials, qbx_potentials, template_ary) if timing_data is not None: timing_data.update(recorder.summarize()) diff --git a/pytential/qbx/fmm.py b/pytential/qbx/fmm.py index 4779a6000..fe5f20a12 100644 --- a/pytential/qbx/fmm.py +++ b/pytential/qbx/fmm.py @@ -400,6 +400,28 @@ def make_container(): # {{{ FMM top-level +def _reorder_and_finalize_potentials( + wrangler, non_qbx_potentials, qbx_potentials, template_ary): + nqbtl = wrangler.geo_data.non_qbx_box_target_lists() + + all_potentials_in_tree_order = wrangler.full_output_zeros(template_ary) + + for ap_i, nqp_i in zip(all_potentials_in_tree_order, non_qbx_potentials): + ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i + + all_potentials_in_tree_order += qbx_potentials + + def reorder_and_finalize_potentials(x): + # "finalize" gives host FMMs (like FMMlib) a chance to turn the + # potential back into a CL array. + return wrangler.finalize_potentials(x[ + wrangler.geo_data.traversal().tree.sorted_target_ids], template_ary) + + from pytools.obj_array import obj_array_vectorize + return obj_array_vectorize( + reorder_and_finalize_potentials, all_potentials_in_tree_order) + + def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None, traversal=None): """Top-level driver routine for the QBX fast multipole calculation. @@ -423,8 +445,6 @@ def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None, if traversal is None: traversal = geo_data.traversal() - tree = traversal.tree - template_ary = src_weight_vecs[0] recorder = TimingRecorder() @@ -585,23 +605,8 @@ def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None, # {{{ reorder potentials - nqbtl = geo_data.non_qbx_box_target_lists() - - all_potentials_in_tree_order = wrangler.full_output_zeros(template_ary) - - for ap_i, nqp_i in zip(all_potentials_in_tree_order, non_qbx_potentials): - ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i - - all_potentials_in_tree_order += qbx_potentials - - def reorder_and_finalize_potentials(x): - # "finalize" gives host FMMs (like FMMlib) a chance to turn the - # potential back into a CL array. - return wrangler.finalize_potentials(x[tree.sorted_target_ids], template_ary) - - from pytools.obj_array import obj_array_vectorize - result = obj_array_vectorize( - reorder_and_finalize_potentials, all_potentials_in_tree_order) + result = _reorder_and_finalize_potentials( + wrangler, non_qbx_potentials, qbx_potentials, template_ary) # }}} diff --git a/pytential/qbx/fmmlib.py b/pytential/qbx/fmmlib.py index d75c770bb..d51c3589e 100644 --- a/pytential/qbx/fmmlib.py +++ b/pytential/qbx/fmmlib.py @@ -752,4 +752,31 @@ def gather_qbx_potentials(self, qbx_potentials): ntargets, qbx_potentials, self.geo_data.qbx_target_mask, self.MPITags["qbx_potentials"]) + def reorder_and_finalize_potentials( + self, non_qbx_potentials, qbx_potentials, template_ary): + mpi_rank = self.comm.Get_rank() + + if mpi_rank == 0: + all_potentials_in_tree_order = self.full_output_zeros(template_ary) + + nqbtl = self.global_geo_data.non_qbx_box_target_lists + + for ap_i, nqp_i in zip( + all_potentials_in_tree_order, non_qbx_potentials): + ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i + + all_potentials_in_tree_order += qbx_potentials + + def _reorder_and_finalize_potentials(x): + # "finalize" gives host FMMs (like FMMlib) a chance to turn the + # potential back into a CL array. + return self.finalize_potentials( + x[self.global_traversal.tree.sorted_target_ids], template_ary) + + from pytools.obj_array import with_object_array_or_scalar + return with_object_array_or_scalar( + _reorder_and_finalize_potentials, all_potentials_in_tree_order) + else: + return None + # vim: foldmethod=marker