From b7d6957b134a87a96490d7c796c2805ec9369d46 Mon Sep 17 00:00:00 2001
From: Matt Graham
Date: Fri, 26 Apr 2024 17:49:38 +0100
Subject: [PATCH] Fix documentation check errors

---
 src/calibr/calibration.py  |  6 +++---
 src/calibr/emulation.py    | 10 +++++-----
 src/calibr/optimization.py | 20 +++++++++++++++++++-
 3 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/src/calibr/calibration.py b/src/calibr/calibration.py
index 3651fca..da7b4fe 100644
--- a/src/calibr/calibration.py
+++ b/src/calibr/calibration.py
@@ -22,7 +22,7 @@
     GaussianProcessParameterFitter,
     fit_gaussian_process_parameters_map,
 )
-from .optimization import GlobalMinimizer, minimize_with_restarts
+from .optimization import GlobalMinimizer, GlobalMinimizerKwarg, minimize_with_restarts
 
 #: Type alias for function generating random initial values for inputs.
 InitialInputSampler: TypeAlias = Callable[[Generator, int], Array]
@@ -40,7 +40,7 @@ def get_next_inputs_batch_by_joint_optimization(
     batch_size: int,
     *,
     minimize_function: GlobalMinimizer = minimize_with_restarts,
-    **minimize_function_kwargs,
+    **minimize_function_kwargs: GlobalMinimizerKwarg,
 ) -> tuple[Array, float]:
     """
     Get next batch of inputs to evaluate by jointly optimizing acquisition function.
@@ -89,7 +89,7 @@ def get_next_inputs_batch_by_greedy_optimization(
     batch_size: int,
     *,
     minimize_function: GlobalMinimizer = minimize_with_restarts,
-    **minimize_function_kwargs,
+    **minimize_function_kwargs: GlobalMinimizerKwarg,
 ) -> tuple[Array, float]:
     """
     Get next batch of inputs to evaluate by greedily optimizing acquisition function.
diff --git a/src/calibr/emulation.py b/src/calibr/emulation.py
index f678f25..bf52338 100644
--- a/src/calibr/emulation.py
+++ b/src/calibr/emulation.py
@@ -17,7 +17,7 @@
 from jax.typing import ArrayLike
 from numpy.random import Generator
 
-from .optimization import GlobalMinimizer, minimize_with_restarts
+from .optimization import GlobalMinimizer, GlobalMinimizerKwarg, minimize_with_restarts
 
 try:
     import mici
@@ -76,11 +76,11 @@ def get_gaussian_process_factory(
         mean_function: Mean function for Gaussian process.
         covariance_function: Covariance function for Gaussian process.
         neg_log_prior_density: Negative logarithm of density of prior distribution on
-        vector of unconstrained parameters for Gaussian process model.
+            vector of unconstrained parameters for Gaussian process model.
         transform_parameters: Function which maps flat unconstrained parameter vector to
-        a dictionary of (potential constrained) parameters, keyed by parameter name.
+            a dictionary of (potential constrained) parameters, keyed by parameter name.
         sample_unconstrained_parameters: Function generating random values for
-        unconstrained vector of Gaussian process parameters.
+            unconstrained vector of Gaussian process parameters.
 
     Returns:
         Gaussian process factory function.
@@ -116,7 +116,7 @@ def fit_gaussian_process_parameters_map(
     gaussian_process: GaussianProcessModel,
     *,
     minimize_function: GlobalMinimizer = minimize_with_restarts,
-    **minimize_function_kwargs,
+    **minimize_function_kwargs: GlobalMinimizerKwarg,
 ) -> ParametersDict:
     """Fit parameters of Gaussian process model by maximimizing posterior density.
 
diff --git a/src/calibr/optimization.py b/src/calibr/optimization.py
index ccf7cd5..3befee5 100644
--- a/src/calibr/optimization.py
+++ b/src/calibr/optimization.py
@@ -2,7 +2,7 @@
 
 from collections.abc import Callable
 from heapq import heappush
-from typing import Protocol, TypeAlias
+from typing import Any, Protocol, TypeAlias
 
 import jax
 from jax.typing import ArrayLike
@@ -22,6 +22,8 @@ class ConvergenceError(Exception):
 #: Type alias for function sampling initial optimization state given random generator
 InitialStateSampler: TypeAlias = Callable[[Generator], ndarray]
 
+GlobalMinimizerKwarg: TypeAlias = Any
+
 
 class GlobalMinimizer(Protocol):
     """Function which attempts to find global minimum of a scalar objective function."""
@@ -31,6 +33,7 @@ def __call__(
         self,
         objective_function: ObjectiveFunction,
         sample_initial_state: InitialStateSampler,
         rng: Generator,
+        **kwargs: GlobalMinimizerKwarg,
     ) -> tuple[jax.Array, float]:
         """
         Minimize a differentiable objective function.
@@ -45,6 +48,7 @@
                 random number generator returns a random initial state for
                 optimization of appropriate dimension.
             rng: Seeded NumPy random number generator.
+            **kwargs: Any keyword arguments to global minimizer function.
 
         Returns:
             Tuple with first entry the state corresponding to the minima point and the
@@ -71,6 +75,14 @@ def hvp(x: ArrayLike, v: ArrayLike) -> jax.Array:
     return hvp
 
 
+def _check_unknown_kwargs(unknown_kwargs: dict[str, GlobalMinimizerKwarg]) -> None:
+    if unknown_kwargs:
+        msg = ". ".join(
+            f"Unknown keyword argument {k}={v}" for k, v in unknown_kwargs.items()
+        )
+        raise ValueError(msg)
+
+
 def minimize_with_restarts(
     objective_function: ObjectiveFunction,
     sample_initial_state: InitialStateSampler,
@@ -82,6 +94,7 @@
     minimize_max_iterations: int | None = None,
     minimize_tol: float | None = None,
     logging_function: Callable[[str], None] = lambda _: None,
+    **unknown_kwargs: GlobalMinimizerKwarg,
 ) -> tuple[jax.Array, float]:
     """Minimize a differentiable objective function with random restarts.
 
@@ -102,6 +115,8 @@
             random number generator returns a random initial state for
             optimization of appropriate dimension.
         rng: Seeded NumPy random number generator.
+
+    Keyword Args:
         number_minima_to_find: Number of candidate minima of objective function to
             try to find.
         maximum_minimize_calls: Maximum number of times to try calling
@@ -120,6 +135,7 @@
         Tuple with first entry the state corresponding to the best minima candidate
         found and the second entry the corresponding objective function value.
     """
+    _check_unknown_kwargs(unknown_kwargs)
    minima_found: list[tuple[jax.Array, int, jax.Array]] = []
     minimize_calls = 0
     while (
@@ -171,6 +187,7 @@ def basin_hopping(
     minimize_method: str = "Newton-CG",
     minimize_max_iterations: int | None = None,
     minimize_tol: float | None = None,
+    **unknown_kwargs: GlobalMinimizerKwarg,
 ) -> tuple[jax.Array, float]:
     """Minimize a differentiable objective function with SciPy basin-hopping algorithm.
 
@@ -201,6 +218,7 @@
         Tuple with first entry the state corresponding to the best minima candidate
         found and the second entry the corresponding objective function value.
     """
+    _check_unknown_kwargs(unknown_kwargs)
     results = _basin_hopping(
         jax.jit(objective_function),
         x0=sample_initial_state(rng),
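
After this patch, keyword arguments not bound to a minimizer's named parameters are
collected into **unknown_kwargs and rejected up front by _check_unknown_kwargs with an
informative ValueError. A minimal usage sketch, assuming only the signatures shown in
the patch above; the quadratic objective, seed and state dimension are illustrative,
not part of calibr:

    import jax.numpy as jnp
    import numpy as np

    from calibr.optimization import minimize_with_restarts

    rng = np.random.default_rng(20240426)

    # Illustrative differentiable objective: a simple quadratic bowl.
    def objective_function(x):
        return jnp.sum(x**2)

    # Matches the InitialStateSampler alias: Generator -> ndarray.
    def sample_initial_state(rng):
        return rng.standard_normal(2)

    # Recognised keyword arguments still bind to the named parameters.
    state, value = minimize_with_restarts(
        objective_function,
        sample_initial_state,
        rng,
        number_minima_to_find=2,
    )

    # A misspelled keyword argument is swept into **unknown_kwargs and rejected
    # explicitly by _check_unknown_kwargs before any optimization work starts.
    try:
        minimize_with_restarts(objective_function, sample_initial_state, rng, tol=1e-8)
    except ValueError as error:
        print(error)  # Unknown keyword argument tol=1e-08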
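
Because GlobalMinimizer.__call__ now takes **kwargs, extra keyword arguments given to
get_next_inputs_batch_by_joint_optimization, get_next_inputs_batch_by_greedy_optimization
or fit_gaussian_process_parameters_map can be forwarded to whichever minimize_function is
plugged in. A sketch of a custom minimizer conforming to the updated protocol; the
random_search function and its number_samples keyword are hypothetical, not part of calibr:

    import jax
    import jax.numpy as jnp
    from numpy.random import Generator

    def random_search(
        objective_function,
        sample_initial_state,
        rng: Generator,
        **kwargs,
    ) -> tuple[jax.Array, float]:
        # Consume the kwargs this minimizer recognises, then reject any
        # leftovers in the same spirit as _check_unknown_kwargs above.
        number_samples = kwargs.pop("number_samples", 100)
        if kwargs:
            raise ValueError(
                ". ".join(f"Unknown keyword argument {k}={v}" for k, v in kwargs.items())
            )
        objective = jax.jit(objective_function)
        # Keep the sampled state with the lowest objective value seen so far.
        best_state, best_value = None, float("inf")
        for _ in range(number_samples):
            state = jnp.asarray(sample_initial_state(rng))
            value = float(objective(state))
            if value < best_value:
                best_state, best_value = state, value
        return best_state, best_value

Under these assumptions such a function could be supplied as, for example,
fit_gaussian_process_parameters_map(..., minimize_function=random_search,
number_samples=1000), with number_samples travelling through the now-annotated
**minimize_function_kwargs.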