Skip to content

Commit

Permalink
Load balancing L0
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipDeegan committed Apr 25, 2024
1 parent 73a722d commit 5ad4e2b
Show file tree
Hide file tree
Showing 38 changed files with 1,400 additions and 526 deletions.
35 changes: 29 additions & 6 deletions pyphare/pyphare/pharein/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
serialize as serialize_sim,
deserialize as deserialize_sim,
)
from .load_balancer import LoadBalancer


def NO_GUI():
Expand Down Expand Up @@ -86,9 +87,7 @@ def __init__(self, fn):
self.fn = fn

def __call__(self, *xyz):
args = []
for i, arg in enumerate(xyz):
args.append(np.asarray(arg))
args = [np.asarray(arg) for arg in xyz]
ret = self.fn(*args)
if isinstance(ret, list):
ret = np.asarray(ret)
Expand Down Expand Up @@ -128,6 +127,9 @@ def populateDict():
def add_int(path, val):
pp.add_int(path, int(val))

def add_bool(path, val):
pp.add_bool(path, bool(val))

def add_double(path, val):
pp.add_double(path, float(val))

Expand Down Expand Up @@ -182,8 +184,6 @@ def add_vector_int(path, val):

add_int("simulation/AMR/tag_buffer", simulation.tag_buffer)

add_string("simulation/AMR/loadbalancing", simulation.loadbalancing)

refinement_boxes = simulation.refinement_boxes

def as_paths(rb):
Expand Down Expand Up @@ -223,6 +223,27 @@ def as_paths(rb):
add_double("simulation/algo/ohm/resistivity", simulation.resistivity)
add_double("simulation/algo/ohm/hyper_resistivity", simulation.hyper_resistivity)

# load balancer block start
lb = simulation.load_balancer or LoadBalancer(_register=False)
base = "simulation/AMR/loadbalancing"
add_bool(f"{base}/active", lb.active)
add_string(f"{base}/mode", lb.mode)
add_double(f"{base}/tolerance", lb.tol)

# if mode==nppc, imbalance allowed
add_bool(f"{base}/auto", lb.auto)
add_size_t(f"{base}/next_rebalance", lb.next_rebalance)
add_size_t(f"{base}/max_next_rebalance", lb.max_next_rebalance)
add_size_t(
f"{base}/next_rebalance_backoff_multiplier",
lb.next_rebalance_backoff_multiplier,
)

# cadence based values
add_size_t(f"{base}/every", lb.every)
add_bool(f"{base}/on_init", lb.on_init)
# load balancer block end

init_model = simulation.model
modelDict = init_model.model_dict

Expand All @@ -246,12 +267,14 @@ def as_paths(rb):
addInitFunction(partinit_path + "thermal_velocity_x", fn_wrapper(d["vthx"]))
addInitFunction(partinit_path + "thermal_velocity_y", fn_wrapper(d["vthy"]))
addInitFunction(partinit_path + "thermal_velocity_z", fn_wrapper(d["vthz"]))
add_int(partinit_path + "nbr_part_per_cell", d["nbrParticlesPerCell"])
add_double(partinit_path + "charge", d["charge"])
add_string(partinit_path + "basis", "cartesian")
if "init" in d and "seed" in d["init"]:
pp.add_optional_size_t(partinit_path + "init/seed", d["init"]["seed"])

add_int(partinit_path + "nbr_part_per_cell", d["nbrParticlesPerCell"])
add_double(partinit_path + "density_cut_off", d["density_cut_off"])

add_string("simulation/electromag/name", "EM")
add_string("simulation/electromag/electric/name", "E")

Expand Down
50 changes: 50 additions & 0 deletions pyphare/pyphare/pharein/load_balancer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#
#

from dataclasses import dataclass, field
from . import global_vars as gv


@dataclass
class LoadBalancer:
# whether or not load balancing is performed
active: bool = field(default_factory=lambda: False)

# which way load is assessed
mode: str = field(default_factory=lambda: "nppc")

# acceptable imbalance essentially
tol: float = field(default_factory=lambda: 0.05)

# whether to rebalance/check imbalance on init
on_init: bool = field(default_factory=lambda: True)

# if auto, other values are not used if active
auto: bool = field(default_factory=lambda: True)
next_rebalance_backoff_multiplier: int = field(default_factory=lambda: 2)
next_rebalance: int = field(default_factory=lambda: 200)
max_next_rebalance: int = field(default_factory=lambda: 1000)

# if !auto these values are used if active
every: int = field(default_factory=lambda: 1)

# internal, allows not registering object for default init
_register: bool = field(default_factory=lambda: True)

def __post_init__(self):
allowed_modes = [
"nppc", # count particles per rank
"homogeneous", # count cells per rank
]

if self.mode not in allowed_modes:
raise RuntimeError(f"LoadBalancer mode '{self.mode}' is not valid")

if self._register:
if not gv.sim:
raise RuntimeError(
f"LoadBalancer cannot be registered as no simulation exists"
)
if gv.sim.load_balancer:
raise RuntimeError(f"LoadBalancer is already registered to simulation")
gv.sim.load_balancer = self
2 changes: 2 additions & 0 deletions pyphare/pyphare/pharein/maxwellian_fluid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def add_population(
vthy=None,
vthz=None,
init={},
density_cut_off=1e-16,
):
"""
add a particle population to the current model
Expand Down Expand Up @@ -122,6 +123,7 @@ def add_population(
"vthz": vthz,
"nbrParticlesPerCell": nbr_part_per_cell,
"init": init,
"density_cut_off": density_cut_off,
}
}

Expand Down
12 changes: 8 additions & 4 deletions pyphare/pyphare/pharein/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def check_time(**kwargs):
and "time_step" not in kwargs
)

start_time = kwargs.get("restart_options", {}).get("restart_time", 0)

if final_and_dt:
time_step_nbr = int(kwargs["final_time"] / kwargs["time_step"])
time_step = kwargs["final_time"] / time_step_nbr
Expand All @@ -139,7 +141,11 @@ def check_time(**kwargs):
+ " or 'final_time' and 'time_step_nbr'"
)

return time_step_nbr, time_step, kwargs.get("final_time", time_step * time_step_nbr)
return (
time_step_nbr,
time_step,
kwargs.get("final_time", start_time + time_step * time_step_nbr),
)


# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -831,6 +837,7 @@ def __init__(self, **kwargs):
self.diagnostics = {}
self.model = None
self.electrons = None
self.load_balancer = None

# hard coded in C++ MultiPhysicsIntegrator::getMaxFinerLevelDt
self.nSubcycles = 4
Expand All @@ -846,9 +853,6 @@ def __init__(self, **kwargs):
]
validate_restart_options(self)

def final_time(self):
return self.time_step * self.time_step_nbr

def simulation_domain(self):
return [dl * n + ori for dl, n, ori in zip(self.dl, self.cells, self.origin)]

Expand Down
67 changes: 67 additions & 0 deletions pyphare/pyphare/pharesee/hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from .particles import Particles
from ..core.phare_utilities import listify

from dataclasses import dataclass


class PatchData:
"""
Expand Down Expand Up @@ -86,6 +88,9 @@ def __str__(self):
def __repr__(self):
return self.__str__()

def __eq__(self, that):
return self.field_name == that.field_name and self.dataset[:] == that.dataset[:]

def select(self, box):
"""
return view of internal data based on overlap of input box
Expand Down Expand Up @@ -221,6 +226,9 @@ def __getitem__(self, box):
def size(self):
return self.dataset.size()

def __eq__(self, that):
return self.name == that.name and self.dataset == that.dataset


class Patch:
"""
Expand All @@ -247,6 +255,9 @@ def __str__(self):
def __repr__(self):
return self.__str__()

def __getitem__(self, key):
return self.patch_datas[key]

def copy(self):
"""does not copy patchdatas.datasets (see class PatchData)"""
from copy import deepcopy
Expand Down Expand Up @@ -1755,6 +1766,7 @@ def merge_particles(hierarchy):
popname = domain_pdata[0].split("_")[0]
pdatas[popname + "_particles"] = pdatas[domain_pdata[0]]
del pdatas[domain_pdata[0]]
return hierarchy


def h5_filename_from(diagInfo):
Expand Down Expand Up @@ -1792,3 +1804,58 @@ def getPatch(hier, point):
print("error : ", k, v)
raise RuntimeError("more than one patch found for point")
return patches


@dataclass
class EqualityReport:
ok: bool
reason: str

def __bool__(self):
return self.ok


def hierarchy_compare(this, that):
if not isinstance(this, PatchHierarchy) or not isinstance(that, PatchHierarchy):
return EqualityReport(False, "class type mismatch")

if this.ndim != that.ndim or this.domain_box != that.domain_box:
return EqualityReport(False, "dimensional mismatch")

if this.time_hier.keys() != that.time_hier.keys():
return EqualityReport(False, "timesteps mismatch")

for tidx in this.times():
patch_levels_ref = this.time_hier[tidx]
patch_levels_cmp = that.time_hier[tidx]

if patch_levels_ref.keys() != patch_levels_cmp.keys():
return EqualityReport(False, "levels mismatch")

for level_idx in patch_levels_cmp.keys():
patch_level_ref = patch_levels_ref[level_idx]
patch_level_cmp = patch_levels_cmp[level_idx]

for patch_idx in range(len(patch_level_cmp.patches)):
patch_ref = patch_level_ref.patches[patch_idx]
patch_cmp = patch_level_cmp.patches[patch_idx]

if patch_ref.patch_datas.keys() != patch_cmp.patch_datas.keys():
print(list(patch_ref.patch_datas.keys()))
print(list(patch_cmp.patch_datas.keys()))
return EqualityReport(False, "data keys mismatch")

for patch_data_key in patch_ref.patch_datas.keys():
patch_data_ref = patch_ref.patch_datas[patch_data_key]
patch_data_cmp = patch_cmp.patch_datas[patch_data_key]

if patch_data_cmp != patch_data_ref:
return EqualityReport(
False,
"data mismatch: "
+ type(patch_data_cmp).__name__
+ " "
+ type(patch_data_ref).__name__,
)

return EqualityReport(True, "OK")
29 changes: 29 additions & 0 deletions pyphare/pyphare/pharesee/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,32 @@ def _arg_sort(particles):
if particles.ndim == 3:
z1 = particles.iCells[:, 2] + particles.deltas[:, 2]
return np.argsort(np.sqrt((x1**2 + y1**2 + z1**2)) / (x1 / y1 / z1))


def single_patch_per_level_per_pop_from(hier, only_keep_L0=True): # dragons
for tidx in hier.times():
if only_keep_L0:
hier.time_hier[tidx] = {0: hier.time_hier[tidx][0]}

patch_levels = hier.time_hier[tidx]

for level_idx in patch_levels.keys():
patch_level = patch_levels[level_idx]

patch0 = patch_level.patches[0]
particles = {} # str:[]

for key in patch0.patch_datas.keys():
if isinstance(patch0[key].dataset, Particles):
particles[key] = []

for patch in patch_level.patches:
for patch_data_key in patch.patch_datas.keys():
particles[key] += [patch[patch_data_key].dataset]

for key in particles.keys():
patch0[key].dataset = aggregate(particles[key])

patch_levels[level_idx].patches = [patch0] # just one patch

return hier
35 changes: 30 additions & 5 deletions pyphare/pyphare/simulator/simulator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import atexit
import time as timem
import numpy as np
Expand Down Expand Up @@ -29,6 +30,28 @@ def startMPI():
life_cycles["samrai"] = cpp_lib().SamraiLifeCycle()


def print_rank0(*args, **kwargs):
from pyphare.cpp import cpp_lib

if cpp_lib().mpi_rank() == 0:
print(*args, **kwargs)


def plot_timestep_time(timestep_times):
from pyphare.cpp import cpp_lib

if cpp_lib().mpi_rank() == 0:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(timestep_times)
plt.ylabel("timestep time")
plt.xlabel("timestep")
fig.savefig("timestep_times.png")

cpp_lib().mpi_barrier()


class Simulator:
def __init__(self, simulation, auto_dump=True, **kwargs):
import pyphare.pharein as ph
Expand Down Expand Up @@ -107,8 +130,7 @@ def _throw(self, e):
import sys
from pyphare.cpp import cpp_lib

if cpp_lib().mpi_rank() == 0:
print(e)
print_rank0(e)
sys.exit(1)

def advance(self, dt=None):
Expand All @@ -134,7 +156,7 @@ def times(self):
self.timeStep(),
)

def run(self):
def run(self, plot_times=False):
from pyphare.cpp import cpp_lib

self._check_init()
Expand All @@ -153,8 +175,11 @@ def run(self):
out = f"t = {t:8.5f} - {ticktock:6.5f}sec - total {np.sum(perf):7.4}sec"
print(out, end=self.print_eol)

print("mean advance time = {}".format(np.mean(perf)))
print("total advance time = {}".format(np.sum(perf)))
print_rank0(f"mean advance time = {np.mean(perf)}")
print_rank0(f"total advance time = {datetime.timedelta(seconds=np.sum(perf))}")

if plot_times:
plot_timestep_time(perf)

return self.reset()

Expand Down
1 change: 1 addition & 0 deletions pyphare/pyphare_tests/test_pharesee/test_geometry_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
test_geometry_3d.py
Loading

0 comments on commit 5ad4e2b

Please sign in to comment.