ManualEnsemble class for specifying all comms in an Ensemble #189

Open
wants to merge 6 commits into
base: master
88 changes: 68 additions & 20 deletions asQ/ensemble.py
@@ -1,7 +1,8 @@
import weakref
from firedrake import COMM_WORLD, Ensemble
from pyop2.mpi import internal_comm
from pyop2.mpi import MPI, internal_comm, is_pyop2_comm, PyOP2CommError

__all__ = ['create_ensemble', 'split_ensemble', 'EnsembleConnector']
__all__ = ['create_ensemble', 'split_ensemble']


def create_ensemble(time_partition, comm=COMM_WORLD):
@@ -42,35 +43,82 @@ def split_ensemble(ensemble, split_size, **kwargs):
split_rank = ensemble.ensemble_comm.rank // split_size

# create split_ensemble.global_comm
split_comm = ensemble.global_comm.Split(color=split_rank,
key=ensemble.global_comm.rank)
split_global_comm = ensemble.global_comm.Split(color=split_rank,
key=ensemble.global_comm.rank)

return EnsembleConnector(split_comm, ensemble.comm, split_size, **kwargs)
# create split_ensemble.ensemble_comm
split_ensemble_comm = ensemble.ensemble_comm.Split(color=split_rank,
key=ensemble.ensemble_comm.rank)

new_ensemble = ManualEnsemble(split_global_comm, ensemble.comm, split_ensemble_comm, **kwargs)

class EnsembleConnector(Ensemble):
def __init__(self, global_comm, local_comm, nmembers, **kwargs):
# make sure the new comms are cleaned up when the split ensemble goes out of scope
weakref.finalize(new_ensemble, split_global_comm.Free)
weakref.finalize(new_ensemble, split_ensemble_comm.Free)
Comment on lines +56 to +57
Member

I think that the finalizer should be set in the __init__ method of the ManualEnsemble. It's the pattern used elsewhere and it prevents someone (user or developer) forgetting to add the finalizers.

Member Author

This comes back to the question of what ManualEnsemble should be responsible for. In this case the ensemble_comm and the global_comm need finalising but the spatial_comm doesn't.

Currently I've gone with "the user of ManualEnsemble is totally responsible for the comms they pass in", but we could have optional arguments to ManualEnsemble.__init__ for which comms to set finalizers for, I'd be ok with that.

Just to note, I'm saying "user" here but ManualEnsemble isn't exposed publicly, it's meant for internal use so I'd expect it to always be wrapped in something like the split_ensemble function which has more knowledge about comm lifetime.

Member Author

@JHopeCollins JHopeCollins May 16, 2024

e.g.

class ManualEnsemble(Ensemble):
    def __init__(self, global_comm, spatial_comm, ensemble_comm,
                 finalize_global_comm=False, finalize_spatial_comm=False,
                 finalize_ensemble_comm=False):

        # only register a finalizer for the comms the caller asks us to free
        if finalize_global_comm:
            weakref.finalize(self, global_comm.Free)
        if finalize_spatial_comm:
            weakref.finalize(self, spatial_comm.Free)
        if finalize_ensemble_comm:
            weakref.finalize(self, ensemble_comm.Free)
        ...
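
For reference, the opt-in finalizer idea can be tried without MPI. This is only a sketch: `FakeComm` and `OwnerSketch` are hypothetical stand-ins invented here (not part of asQ, Firedrake, or mpi4py), and `FakeComm.Free` just records that it was called. It shows that `weakref.finalize` frees only the comms flagged as owned once the owning object is collected.

```python
import gc
import weakref


class FakeComm:
    """Hypothetical stand-in for an MPI communicator; Free() only sets a flag."""
    def __init__(self, name):
        self.name = name
        self.freed = False

    def Free(self):
        self.freed = True


class OwnerSketch:
    """Hypothetical owner that frees only the comms it is told to own."""
    def __init__(self, global_comm, spatial_comm, ensemble_comm,
                 finalize_global_comm=False,
                 finalize_spatial_comm=False,
                 finalize_ensemble_comm=False):
        self.global_comm = global_comm
        self.spatial_comm = spatial_comm
        self.ensemble_comm = ensemble_comm

        owned = ((global_comm, finalize_global_comm),
                 (spatial_comm, finalize_spatial_comm),
                 (ensemble_comm, finalize_ensemble_comm))
        for comm, owns in owned:
            if owns:
                # the callback references the comm, not self, so it does not
                # keep the owner alive
                weakref.finalize(self, comm.Free)


g, s, e = FakeComm("global"), FakeComm("spatial"), FakeComm("ensemble")
owner = OwnerSketch(g, s, e,
                    finalize_global_comm=True,
                    finalize_ensemble_comm=True)

del owner      # drop the last reference to the owner
gc.collect()   # make sure finalizers have run

print(g.freed, s.freed, e.freed)
```

With flags like these, the split_ensemble case above would pass True for the global and ensemble comms and leave the spatial comm alone.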


return new_ensemble


class ManualEnsemble(Ensemble):
def __init__(self, global_comm, spatial_comm, ensemble_comm, **kwargs):
"""
An Ensemble created from provided spatial communicators (ensemble.comm).
An Ensemble created from provided comms.

:arg global_comm: global communicator the Ensemble is defined over.
:arg local_comm: communicator to use for the Ensemble.comm member.
:arg nmembers: number of Ensemble members (ensemble.ensemble_comm.size).
:arg spatial_comm: communicator to use for the Ensemble.comm member.
:arg ensemble_comm: communicator to use for the Ensemble.ensemble_comm member.

The global_comm, spatial_comm, and ensemble_comm must have the same logical meaning
as they do in firedrake.Ensemble. i.e. the global_comm is the union of a cartesian
product of multiple spatial_comms and ensemble_comms.
- ManualEnsemble is logically defined over all ranks in global_comm.
- Each rank in global_comm belongs to only one spatial_comm and one ensemble_comm.
- The size of the intersection of any (spatial_comm, ensemble_comm) pair is 1.

WARNING: Not meeting these requirements may result in errors, hangs, and nonsensical results.

ManualEnsemble will not Free any of the comms. This is the responsibility of the user.
"""
if nmembers*local_comm.size != global_comm.size:
msg = "The global ensemble must have the same number of ranks as the sum of the local comms"
raise ValueError(msg)
# are we handed user comms?

for comm in (global_comm, spatial_comm, ensemble_comm):
if is_pyop2_comm(comm):
raise PyOP2CommError("Cannot construct Ensemble from PyOP2 internal comm")

# check cartesian product consistency

if spatial_comm.size*ensemble_comm.size != global_comm.size:
msg = "The global comm must have the same number of ranks as the product of spatial and ensemble comms"
raise PyOP2CommError(msg)

global_group = global_comm.Get_group()
spatial_group = spatial_comm.Get_group()
ensemble_group = ensemble_comm.Get_group()

if MPI.Group.Intersection(spatial_group, ensemble_group).size != 1:
raise PyOP2CommError("spatial and ensemble comms must be cartesian product in global_comm")

is_subgroup = lambda sub, group: MPI.Group.Compare(sub, MPI.Group.Intersection(sub, group)) in {MPI.IDENT, MPI.CONGRUENT}

if not is_subgroup(spatial_group, global_group):
raise PyOP2CommError("spatial_comm must be subgroup of global_comm")
if not is_subgroup(ensemble_group, global_group):
raise PyOP2CommError("ensemble_comm must be subgroup of global_comm")
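
As a rough illustration of what the checks above test, the same conditions can be written with plain Python sets of global rank ids. This is a sketch under assumptions: `check_layout` is a hypothetical helper, it raises `ValueError` rather than `PyOP2CommError`, and sets ignore rank ordering, which MPI groups do not.

```python
def check_layout(global_ranks, spatial_ranks, ensemble_ranks):
    """Set-based analogue of the consistency checks, as seen from one rank.

    Each argument is an iterable of global rank ids (a modelling assumption;
    the real code compares MPI groups).
    """
    g = set(global_ranks)
    s = set(spatial_ranks)
    e = set(ensemble_ranks)

    # global size must be the product of spatial and ensemble sizes
    if len(s) * len(e) != len(g):
        raise ValueError("global size must equal spatial size times ensemble size")

    # a spatial comm and an ensemble comm must share exactly one rank
    if len(s & e) != 1:
        raise ValueError("spatial and ensemble comms must share exactly one rank")

    # both comms must be drawn from the global comm
    if not s <= g:
        raise ValueError("spatial_comm must be a subgroup of global_comm")
    if not e <= g:
        raise ValueError("ensemble_comm must be a subgroup of global_comm")
    return True


# 6 ranks as 3 ensemble members of 2 spatial ranks each, viewed from rank 2
ok = check_layout(range(6), {2, 3}, {0, 2, 4})

# an ensemble comm of the wrong size fails the product check
try:
    check_layout(range(6), {2, 3}, {0, 2})
    rejected = False
except ValueError:
    rejected = True
```

These are per-rank checks, so they are necessary conditions rather than a proof that the whole layout is a valid cartesian product.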

Member

This logic isn't completely exhaustive, it doesn't currently check whether you have the same communicator. For example:

if __name__ == "__main__":
    r = COMM_WORLD.rank
    s = COMM_WORLD.size
    half = s // 2  # integer, so it is usable as a Split color

    # two Split calls giving the same rank groupings with swapped colors
    ensemble_color1 = int(r < half)
    ensemble1 = COMM_WORLD.Split(color=ensemble_color1, key=r)

    ensemble_color2 = int(r >= half)
    ensemble2 = COMM_WORLD.Split(color=ensemble_color2, key=r)

    spatial_color = r % half
    spatial = COMM_WORLD.Split(color=spatial_color, key=r)

    correct = ManualEnsemble(COMM_WORLD, spatial, ensemble1)
    # each half of the ranks passes a comm from a different Split call
    if r < half:
        broken = ManualEnsemble(COMM_WORLD, spatial, ensemble1)
    else:
        broken = ManualEnsemble(COMM_WORLD, spatial, ensemble2)

    print("ALL PASSED")

Will run just fine, but the broken ensemble uses two different communicators.

This is broken in a very subtle way as the mismatched comm will be destroyed when the ensemble is destroyed.

Member Author

@JHopeCollins JHopeCollins May 16, 2024

I don't see what is broken about this.

Assuming 8 ranks, ranks 0-3 use the ensemble comm from the first Split, and ranks 4-7 use the ensemble comm from the second Split call, but this doesn't matter. All ranks in each ensemble comm use the same one, which is what matters.
They don't "know", and don't need to know, what the other half is doing so long as every rank has an ensemble comm that connects the same part of all spatial comms.

As written so far, ManualEnsemble doesn't destroy any of the comms it's given (the docstring explicitly says that the caller is responsible for this).
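
To make the 8-rank case concrete, here is a small pure-Python model of the three Split calls. It is a sketch only: `split` and `comm_of` are hypothetical helpers that model which ranks end up grouped together, not communicator identity or MPI contexts, and the colors follow the example above with integer division.

```python
def split(size, color_of):
    """Group ranks 0..size-1 by color, mimicking the membership from Comm.Split."""
    groups = {}
    for rank in range(size):
        groups.setdefault(color_of(rank), []).append(rank)
    return groups


def comm_of(groups, rank):
    """Return the set of ranks that share a comm with `rank`."""
    return next(frozenset(members) for members in groups.values()
                if rank in members)


size = 8
half = size // 2

ensemble1 = split(size, lambda r: int(r < half))
ensemble2 = split(size, lambda r: int(r >= half))
spatial = split(size, lambda r: r % half)

# which ensemble comm each rank would pass in the mixed ("broken") case
pairs = {}
for rank in range(size):
    chosen = ensemble1 if rank < half else ensemble2
    pairs[rank] = (comm_of(spatial, rank), comm_of(chosen, rank))

# every rank's spatial and ensemble comms meet only at that rank
for rank, (spa, ens) in pairs.items():
    assert spa & ens == {rank}

print(sorted(sorted(m) for m in ensemble1.values()))
print(sorted(sorted(m) for m in ensemble2.values()))
```

In this model both Split calls give the same two rank groupings, only with the colors swapped, so every rank in a given ensemble comm takes its comm from the same call.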

# create internal duplicates and name comms for debugging
ensemble_name = kwargs.get("name", "Ensemble")

ensemble_name = kwargs.get("ensemble_name", "Ensemble")
self.global_comm = global_comm
if not hasattr(self.global_comm, "name"):
self.global_comm.name = f"{ensemble_name} global comm"
self._comm = internal_comm(self.global_comm, self)

self.comm = local_comm
self.comm.name = f"{ensemble_name} spatial comm"
self.comm = spatial_comm
if not hasattr(self.comm, "name"):
self.comm.name = f"{ensemble_name} spatial comm"
self._spatial_comm = internal_comm(self.comm, self)

self.ensemble_comm = self.global_comm.Split(color=self.comm.rank,
key=global_comm.rank)
self.ensemble_comm.name = f"{ensemble_name} ensemble comm"

self.ensemble_comm = ensemble_comm
if not hasattr(self.ensemble_comm, "name"):
self.ensemble_comm.name = f"{ensemble_name} ensemble comm"
self._ensemble_comm = internal_comm(self.ensemble_comm, self)