Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update EnsembleConnector with new pyop2.internal_comm implementation #188

Merged
merged 1 commit into from
May 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 10 additions & 16 deletions asQ/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from firedrake import COMM_WORLD, Ensemble
from pyop2.mpi import internal_comm, decref
from pyop2.mpi import internal_comm

__all__ = ['create_ensemble', 'split_ensemble', 'EnsembleConnector']

Expand All @@ -23,7 +23,7 @@ def create_ensemble(time_partition, comm=COMM_WORLD):
return Ensemble(comm, nspatial_domains)


def split_ensemble(ensemble, split_size):
def split_ensemble(ensemble, split_size, **kwargs):
"""
Split an Ensemble into multiple smaller Ensembles which share the same
spatial communicators `ensemble.comm`.
Expand All @@ -45,11 +45,11 @@ def split_ensemble(ensemble, split_size):
split_comm = ensemble.global_comm.Split(color=split_rank,
key=ensemble.global_comm.rank)

return EnsembleConnector(split_comm, ensemble.comm, split_size)
return EnsembleConnector(split_comm, ensemble.comm, split_size, **kwargs)


class EnsembleConnector(Ensemble):
def __init__(self, global_comm, local_comm, nmembers):
def __init__(self, global_comm, local_comm, nmembers, **kwargs):
"""
An Ensemble created from provided spatial communicators (ensemble.comm).

Expand All @@ -61,22 +61,16 @@ def __init__(self, global_comm, local_comm, nmembers):
msg = "The global ensemble must have the same number of ranks as the sum of the local comms"
raise ValueError(msg)

ensemble_name = kwargs.get("ensemble_name", "Ensemble")
self.global_comm = global_comm
self._global_comm = internal_comm(self.global_comm)
self._comm = internal_comm(self.global_comm, self)

self.comm = local_comm
self._comm = internal_comm(self.comm)
self.comm.name = f"{ensemble_name} spatial comm"
self._spatial_comm = internal_comm(self.comm, self)

self.ensemble_comm = self.global_comm.Split(color=self.comm.rank,
key=global_comm.rank)
self.ensemble_comm.name = f"{ensemble_name} ensemble comm"

self._ensemble_comm = internal_comm(self.ensemble_comm)

def __del__(self):
if hasattr(self, "ensemble_comm"):
self.ensemble_comm.Free()
del self.ensemble_comm
for comm_name in ["_global_comm", "_comm", "_ensemble_comm"]:
if hasattr(self, comm_name):
comm = getattr(self, comm_name)
decref(comm)
self._ensemble_comm = internal_comm(self.ensemble_comm, self)
Loading