Skip to content

Commit

Permalink
Fix documentation check errors
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-graham committed Apr 26, 2024
1 parent 78f6fcc commit b7d6957
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
6 changes: 3 additions & 3 deletions src/calibr/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
GaussianProcessParameterFitter,
fit_gaussian_process_parameters_map,
)
from .optimization import GlobalMinimizer, minimize_with_restarts
from .optimization import GlobalMinimizer, GlobalMinimizerKwarg, minimize_with_restarts

#: Type alias for function generating random initial values for inputs.
InitialInputSampler: TypeAlias = Callable[[Generator, int], Array]
Expand All @@ -40,7 +40,7 @@ def get_next_inputs_batch_by_joint_optimization(
batch_size: int,
*,
minimize_function: GlobalMinimizer = minimize_with_restarts,
**minimize_function_kwargs,
**minimize_function_kwargs: GlobalMinimizerKwarg,
) -> tuple[Array, float]:
"""
Get next batch of inputs to evaluate by jointly optimizing acquisition function.
Expand Down Expand Up @@ -89,7 +89,7 @@ def get_next_inputs_batch_by_greedy_optimization(
batch_size: int,
*,
minimize_function: GlobalMinimizer = minimize_with_restarts,
**minimize_function_kwargs,
**minimize_function_kwargs: GlobalMinimizerKwarg,
) -> tuple[Array, float]:
"""
Get next batch of inputs to evaluate by greedily optimizing acquisition function.
Expand Down
10 changes: 5 additions & 5 deletions src/calibr/emulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from jax.typing import ArrayLike
from numpy.random import Generator

from .optimization import GlobalMinimizer, minimize_with_restarts
from .optimization import GlobalMinimizer, GlobalMinimizerKwarg, minimize_with_restarts

try:
import mici
Expand Down Expand Up @@ -76,11 +76,11 @@ def get_gaussian_process_factory(
mean_function: Mean function for Gaussian process.
covariance_function: Covariance function for Gaussian process.
neg_log_prior_density: Negative logarithm of density of prior distribution on
vector of unconstrained parameters for Gaussian process model.
vector of unconstrained parameters for Gaussian process model.
transform_parameters: Function which maps flat unconstrained parameter vector to
a dictionary of (potentially constrained) parameters, keyed by parameter name.
a dictionary of (potential constrained) parameters, keyed by parameter name.
sample_unconstrained_parameters: Function generating random values for
unconstrained vector of Gaussian process parameters.
unconstrained vector of Gaussian process parameters.
Returns:
Gaussian process factory function.
Expand Down Expand Up @@ -116,7 +116,7 @@ def fit_gaussian_process_parameters_map(
gaussian_process: GaussianProcessModel,
*,
minimize_function: GlobalMinimizer = minimize_with_restarts,
**minimize_function_kwargs,
**minimize_function_kwargs: GlobalMinimizerKwarg,
) -> ParametersDict:
"""Fit parameters of Gaussian process model by maximimizing posterior density.
Expand Down
20 changes: 19 additions & 1 deletion src/calibr/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Callable
from heapq import heappush
from typing import Protocol, TypeAlias
from typing import Any, Protocol, TypeAlias

import jax
from jax.typing import ArrayLike
Expand All @@ -22,6 +22,8 @@ class ConvergenceError(Exception):
#: Type alias for function sampling initial optimization state given random generator
InitialStateSampler: TypeAlias = Callable[[Generator], ndarray]

GlobalMinimizerKwarg: TypeAlias = Any


class GlobalMinimizer(Protocol):
"""Function which attempts to find global minimum of a scalar objective function."""
Expand All @@ -31,6 +33,7 @@ def __call__(
objective_function: ObjectiveFunction,
sample_initial_state: InitialStateSampler,
rng: Generator,
**kwargs: GlobalMinimizerKwarg,
) -> tuple[jax.Array, float]:
"""
Minimize a differentiable objective function.
Expand All @@ -45,6 +48,7 @@ def __call__(
random number generator returns a random initial state for optimization
of appropriate dimension.
rng: Seeded NumPy random number generator.
**kwargs: Any keyword arguments to global minimizer function.
Returns:
Tuple with first entry the state corresponding to the minima point and the
Expand All @@ -71,6 +75,14 @@ def hvp(x: ArrayLike, v: ArrayLike) -> jax.Array:
return hvp


def _check_unknown_kwargs(unknown_kwargs: dict[str, GlobalMinimizerKwarg]) -> None:
if unknown_kwargs:
msg = ". ".join(
f"Unknown keyword argument {k}={v}" for k, v in unknown_kwargs.items()
)
raise ValueError(msg)


def minimize_with_restarts(
objective_function: ObjectiveFunction,
sample_initial_state: InitialStateSampler,
Expand All @@ -82,6 +94,7 @@ def minimize_with_restarts(
minimize_max_iterations: int | None = None,
minimize_tol: float | None = None,
logging_function: Callable[[str], None] = lambda _: None,
**unknown_kwargs: GlobalMinimizerKwarg,
) -> tuple[jax.Array, float]:
"""Minimize a differentiable objective function with random restarts.
Expand All @@ -102,6 +115,8 @@ def minimize_with_restarts(
random number generator returns a random initial state for optimization of
appropriate dimension.
rng: Seeded NumPy random number generator.
Keyword Args:
number_minima_to_find: Number of candidate minima of objective function to try
to find.
maximum_minimize_calls: Maximum number of times to try calling
Expand All @@ -120,6 +135,7 @@ def minimize_with_restarts(
Tuple with first entry the state corresponding to the best minima candidate
found and the second entry the corresponding objective function value.
"""
_check_unknown_kwargs(unknown_kwargs)
minima_found: list[tuple[jax.Array, int, jax.Array]] = []
minimize_calls = 0
while (
Expand Down Expand Up @@ -171,6 +187,7 @@ def basin_hopping(
minimize_method: str = "Newton-CG",
minimize_max_iterations: int | None = None,
minimize_tol: float | None = None,
**unknown_kwargs: GlobalMinimizerKwarg,
) -> tuple[jax.Array, float]:
"""Minimize a differentiable objective function with SciPy basin-hopping algorithm.
Expand Down Expand Up @@ -201,6 +218,7 @@ def basin_hopping(
Tuple with first entry the state corresponding to the best minima candidate
found and the second entry the corresponding objective function value.
"""
_check_unknown_kwargs(unknown_kwargs)
results = _basin_hopping(
jax.jit(objective_function),
x0=sample_initial_state(rng),
Expand Down

0 comments on commit b7d6957

Please sign in to comment.