From 415e6e5d41e7710155c5735289368f141afab9aa Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Mon, 19 Aug 2024 17:34:54 -0400 Subject: [PATCH] Merge. --- src/simsopt/_core/finite_difference.py | 28 +++++++++---- src/simsopt/objectives/utilities.py | 55 +++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 9 deletions(-) diff --git a/src/simsopt/_core/finite_difference.py b/src/simsopt/_core/finite_difference.py index 6ce780d8a..ab22a1a62 100644 --- a/src/simsopt/_core/finite_difference.py +++ b/src/simsopt/_core/finite_difference.py @@ -1,6 +1,6 @@ # coding: utf-8 # Copyright (c) HiddenSymmetries Development Team. -# Distributed under the terms of the LGPL License +# Distributed under the terms of the MIT License """ This module provides Jacobian evaluated with finite difference scheme @@ -13,7 +13,7 @@ import collections from time import time from datetime import datetime -from typing import Callable, Sequence +from typing import Callable, Union, IO from numbers import Real import numpy as np @@ -132,7 +132,7 @@ def __init__(self, func: Callable, abs_step: Real = 1.0e-7, rel_step: Real = 0.0, diff_method: str = "forward", - log_file: Union[str, typing.IO] = "jac_log") -> None: + log_file: Union[str, IO] = "jac_log") -> None: try: if not isinstance(func.__self__, Optimizable): @@ -163,6 +163,10 @@ def __init__(self, func: Callable, self.jac_size = None self.eval_cnt = 1 + # initialize cache + self.x_cache = None + self.jac_cache = None + def __enter__(self): self.mpi_apart() self.init_log() @@ -294,14 +298,14 @@ def mpi_leaders_task(self, *args): logger.debug('mpi leaders task') # x is a buffer for receiving the state vector: - x = np.empty(self.opt.dof_size, dtype='d') + full_x = np.empty(self.opt.full_dof_size, dtype='d') # If we make it here, we must be doing a fd_jac_par # calculation, so receive the state vector: mpi4py has # separate bcast and Bcast functions!! 
comm.Bcast(x, # root=0) - x = self.mpi.comm_leaders.bcast(x, root=0) - logger.debug(f'mpi leaders loop x={x}') - self.opt.x = x + full_x = self.mpi.comm_leaders.bcast(full_x, root=0) + logger.debug(f'mpi leaders loop full_x={full_x}') + self.opt.full_x = full_x self._jac() def mpi_workers_task(self, *args): @@ -333,6 +337,8 @@ def jac(self, x: RealArray = None, *args, **kwargs): """ Called by proc0 """ + if np.all(x == self.x_cache) and (self.jac_cache is not None): + return self.jac_cache ARB_VAL = 100 logger.debug("Entering jac evaluation") @@ -352,7 +358,9 @@ def jac(self, x: RealArray = None, *args, **kwargs): dtype=np.int32) self.mpi.mobilize_leaders(ARB_VAL) # Any value not equal to STOP - self.mpi.comm_leaders.bcast(x, root=0) + full_x = self.opt.full_x + self.mpi.comm_leaders.bcast(full_x, root=0) + self.opt.full_x = full_x jac, xs, evals = self._jac(x) logger.debug(f'jac is {jac}') @@ -378,4 +386,8 @@ def jac(self, x: RealArray = None, *args, **kwargs): self.eval_cnt += nevals + # cache it + self.x_cache = x + self.jac_cache = jac + return jac diff --git a/src/simsopt/objectives/utilities.py b/src/simsopt/objectives/utilities.py index 8e63954ff..2fe22c2e2 100644 --- a/src/simsopt/objectives/utilities.py +++ b/src/simsopt/objectives/utilities.py @@ -5,7 +5,7 @@ from .._core.optimizable import Optimizable from .._core.derivative import Derivative, derivative_dec -__all__ = ['MPIObjective', 'QuadraticPenalty', 'Weight', 'forward_backward'] +__all__ = ['MPIOptimizable', 'MPIObjective', 'QuadraticPenalty', 'Weight', 'forward_backward'] def forward_backward(P, L, U, rhs, iterative_refinement=False): @@ -49,6 +49,59 @@ def sum_across_comm(derivative, comm): return Derivative(newdict) +class MPIOptimizable(Optimizable): + + def __init__(self, optimizables, attributes, comm): + r""" + Ensures that a list of Optimizables on separate ranks have a consistent set of attributes on all ranks. + For example, say that all ranks have the list ``optimizables``. 
Rank ``i`` modifies attributes + of ``optimizable[i]``. The value attribute ``attr``, i.e., ``optimizables[i].attr`` potentially + will be different on ranks ``i`` and ``j``, for ``i`` not equal to ``j``. This class ensures that + if the cache is invalidated on the ``Optimizables`` in the list ``optimizables``, then when the list + is accessed, the attributes in ``attributes`` will be communicated across all ranks. + + Args: + optimizables: A python list of ``Optimizables`` with attributes in ``attributes`` that can be + communicated using ``mpi4py``. + attributes: A python list of strings corresponding to the list of attributes that are to be + kept consistent across all ranks. + comm: The MPI communicator to use. + """ + + from simsopt._core.util import parallel_loop_bounds + startidx, endidx = parallel_loop_bounds(comm, len(optimizables)) + self.local_optimizables = optimizables[startidx:endidx] + self.global_optimizables = optimizables + + self.comm = comm + self.attributes = attributes + Optimizable.__init__(self, x0=np.asarray([]), depends_on=optimizables) + + for opt in optimizables: + for attr in self.attributes: + if not hasattr(opt, attr): + raise Exception(f'All Optimizables in the optimizable list must contain the attribute {attr}') + + def __getitem__(self, key): + if self.need_to_communicate: + self.communicate() + return self.global_optimizables[key] + + def communicate(self): + if self.need_to_communicate: + for attr in self.attributes: + local_vals = [getattr(J, attr) for J in self.local_optimizables] + global_vals = local_vals if self.comm is None else [i for o in self.comm.allgather(local_vals) for i in o] + for val, J in zip(global_vals, self.global_optimizables): + if J in self.local_optimizables: + continue + setattr(J, attr, val) + self.need_to_communicate = False + + def recompute_bell(self, parent=None): + self.need_to_communicate = True + + class MPIObjective(Optimizable): def __init__(self, objectives, comm, needs_splitting=False):