Skip to content

Commit

Permalink
Merge.
Browse files Browse the repository at this point in the history
  • Loading branch information
ejpaul committed Aug 19, 2024
1 parent 4559e25 commit 415e6e5
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 9 deletions.
28 changes: 20 additions & 8 deletions src/simsopt/_core/finite_difference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding: utf-8
# Copyright (c) HiddenSymmetries Development Team.
# Distributed under the terms of the LGPL License
# Distributed under the terms of the MIT License

"""
This module provides Jacobian evaluated with finite difference scheme
Expand All @@ -13,7 +13,7 @@
import collections
from time import time
from datetime import datetime
from typing import Callable, Sequence
from typing import Callable, Union, IO
from numbers import Real

import numpy as np
Expand Down Expand Up @@ -132,7 +132,7 @@ def __init__(self, func: Callable,
abs_step: Real = 1.0e-7,
rel_step: Real = 0.0,
diff_method: str = "forward",
log_file: Union[str, typing.IO] = "jac_log") -> None:
log_file: Union[str, IO] = "jac_log") -> None:

try:
if not isinstance(func.__self__, Optimizable):
Expand Down Expand Up @@ -163,6 +163,10 @@ def __init__(self, func: Callable,
self.jac_size = None
self.eval_cnt = 1

# initialize cache
self.x_cache = None
self.jac_cache = None

def __enter__(self):
self.mpi_apart()
self.init_log()
Expand Down Expand Up @@ -294,14 +298,14 @@ def mpi_leaders_task(self, *args):
logger.debug('mpi leaders task')

# x is a buffer for receiving the state vector:
x = np.empty(self.opt.dof_size, dtype='d')
full_x = np.empty(self.opt.full_dof_size, dtype='d')
# If we make it here, we must be doing a fd_jac_par
# calculation, so receive the state vector: mpi4py has
# separate bcast and Bcast functions!! comm.Bcast(x,
# root=0)
x = self.mpi.comm_leaders.bcast(x, root=0)
logger.debug(f'mpi leaders loop x={x}')
self.opt.x = x
full_x = self.mpi.comm_leaders.bcast(full_x, root=0)
logger.debug(f'mpi leaders loop full_x={full_x}')
self.opt.full_x = full_x
self._jac()

def mpi_workers_task(self, *args):
Expand Down Expand Up @@ -333,6 +337,8 @@ def jac(self, x: RealArray = None, *args, **kwargs):
"""
Called by proc0
"""
if np.all(x == self.x_cache) and (self.jac_cache is not None):
return self.jac_cache

ARB_VAL = 100
logger.debug("Entering jac evaluation")
Expand All @@ -352,7 +358,9 @@ def jac(self, x: RealArray = None, *args, **kwargs):
dtype=np.int32)

self.mpi.mobilize_leaders(ARB_VAL) # Any value not equal to STOP
self.mpi.comm_leaders.bcast(x, root=0)
full_x = self.opt.full_x
self.mpi.comm_leaders.bcast(full_x, root=0)
self.opt.full_x = full_x

jac, xs, evals = self._jac(x)
logger.debug(f'jac is {jac}')
Expand All @@ -378,4 +386,8 @@ def jac(self, x: RealArray = None, *args, **kwargs):

self.eval_cnt += nevals

# cache it
self.x_cache = x
self.jac_cache = jac

return jac
55 changes: 54 additions & 1 deletion src/simsopt/objectives/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .._core.optimizable import Optimizable
from .._core.derivative import Derivative, derivative_dec

__all__ = ['MPIObjective', 'QuadraticPenalty', 'Weight', 'forward_backward']
__all__ = ['MPIOptimizable', 'MPIObjective', 'QuadraticPenalty', 'Weight', 'forward_backward']


def forward_backward(P, L, U, rhs, iterative_refinement=False):
Expand Down Expand Up @@ -49,6 +49,59 @@ def sum_across_comm(derivative, comm):
return Derivative(newdict)


class MPIOptimizable(Optimizable):

def __init__(self, optimizables, attributes, comm):
r"""
Ensures that a list of Optimizables on separate ranks have a consistent set of attributes on all ranks.
For example, say that all ranks have the list ``optimizables``. Rank ``i`` modifies attributes
of ``optimizable[i]``. The value attribute ``attr``, i.e., ``optimizables[i].attr`` potentially
will be different on ranks ``i`` and ``j``, for ``i`` not equal to ``j``. This class ensures that
if the cache is invalidated on the ``Optimizables`` in the list ``optimizables``, then when the list
is accessed, the attributes in ``attributes`` will be communicated accross all ranks.
Args:
objectives: A python list of ``Optimizables`` with attributes in ``attributes`` that can be
communicated using ``mpi4py``.
attributes: A python list of strings corresponding to the list of attributes that is to be
maintained consistent across all ranks.
comm: The MPI communicator to use.
"""

from simsopt._core.util import parallel_loop_bounds
startidx, endidx = parallel_loop_bounds(comm, len(optimizables))
self.local_optimizables = optimizables[startidx:endidx]
self.global_optimizables = optimizables

self.comm = comm
self.attributes = attributes
Optimizable.__init__(self, x0=np.asarray([]), depends_on=optimizables)

for opt in optimizables:
for attr in self.attributes:
if not hasattr(opt, attr):
raise Exception(f'All Optimizables in the optimizable list must contain the attribute {attr}')

def __getitem__(self, key):
if self.need_to_communicate:
self.communicate()
return self.global_optimizables[key]

def communicate(self):
if self.need_to_communicate:
for attr in self.attributes:
local_vals = [getattr(J, attr) for J in self.local_optimizables]
global_vals = local_vals if self.comm is None else [i for o in self.comm.allgather(local_vals) for i in o]
for val, J in zip(global_vals, self.global_optimizables):
if J in self.local_optimizables:
continue
setattr(J, attr, val)
self.need_to_communicate = False

def recompute_bell(self, parent=None):
self.need_to_communicate = True


class MPIObjective(Optimizable):

def __init__(self, objectives, comm, needs_splitting=False):
Expand Down

0 comments on commit 415e6e5

Please sign in to comment.