Skip to content

Commit

Permalink
Add support for greedy optimization of acquisition function
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-graham committed Apr 24, 2024
1 parent 35808cf commit 9fa7b99
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 5 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ automatic differentiation.
The package is still in the early stages of development, with only a subset of the
algorithmic variants proposed by Järvenpää, Gutmann, Vehtari and Marttinen (2021)
currently implemented. In particular there is no support yet for models with noisy
likelihood evaluations or greedy strategies for optimizing the acquisition functions.
Expect lots of rough edges!
likelihood evaluations. Expect lots of rough edges!

This project is developed in collaboration with the [Centre for Advanced Research Computing](https://ucl.ac.uk/arc), University College London.

Expand Down
61 changes: 58 additions & 3 deletions src/calibr/calibration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Functions for iteratively calibrating the parameters of a probabilistic model."""

from collections.abc import Callable
from functools import partial
from typing import TypeAlias

import jax.numpy as jnp
import numpy as np
from emul.types import DataDict, ParametersDict, PosteriorPredictiveMeanAndVariance
from jax import Array
Expand Down Expand Up @@ -53,8 +55,8 @@ def get_next_inputs_batch_by_joint_optimization(
batch of inputs when passed a random number generator and batch size. Used
to initialize state for optimization runs.
batch_size: Number of inputs in batch.
minimize_function: Function used to attempt to find global minimum of
acquisition function.
minimize_function: Function used to attempt to find minimum of acquisition
function.
**minimize_function_kwargs: Any keyword arguments to pass to
`minimize_function` function used to optimize acquisition function.
Expand All @@ -80,6 +82,59 @@ def acquisition_function_flat_input(flat_inputs: ArrayLike) -> float:
return flat_inputs.reshape((batch_size, -1)), min_acquisition_function


def get_next_inputs_batch_by_greedy_optimization(
rng: Generator,
acquisition_function: AcquisitionFunction,
sample_initial_inputs: InitialInputSampler,
batch_size: int,
*,
minimize_function: GlobalMinimizer = minimize_with_restarts,
**minimize_function_kwargs,
) -> tuple[Array, float]:
"""
Get next batch of inputs to evaluate by greedily optimizing acquisition function.
Sequentially minimizes acquisition function for `b` in 1 to `batch_size` by fixing
`b - 1` inputs already optimized and minimizing over a single new input in each
iteration.
Args:
rng: NumPy random number generator for initializing optimization runs.
acquisition_function: Scalar-valued function of a batch of inputs to optimize to
find new batch of inputs to evaluate model for.
sample_initial_inputs: Function outputting reasonable random initial values for
batch of inputs when passed a random number generator and batch size. Used
to initialize state for optimization runs.
batch_size: Number of inputs in batch.
minimize_function: Function used to attempt to find minimum of (sequence of)
acquisition functions.
**minimize_function_kwargs: Any keyword arguments to pass to
`minimize_function` function used to optimize acquisition function.
Returns:
Tuple of optimized inputs batch and corresponding value of acquisition function.
"""

def acquisition_function_greedy(
current_input: ArrayLike, fixed_inputs: list[ArrayLike]
) -> float:
return acquisition_function(jnp.stack([current_input, *fixed_inputs]))

fixed_inputs: list[ArrayLike] = []
for _ in range(batch_size):
current_input, min_acquisition_function = minimize_function(
objective_function=partial(
acquisition_function_greedy, fixed_inputs=fixed_inputs
),
sample_initial_state=lambda r: sample_initial_inputs(r, 1).flatten(),
rng=rng,
**minimize_function_kwargs,
)
fixed_inputs.append(current_input)

return np.stack(fixed_inputs), min_acquisition_function


def calibrate( # noqa: PLR0913
num_initial_inputs: int,
batch_size: int,
Expand Down Expand Up @@ -149,7 +204,7 @@ def calibrate( # noqa: PLR0913
given Gaussian process posterior predictive functions.
get_next_inputs_batch: Function which computes next batch of inputs to evaluate
model at by optimizing the current acquisition function. Passed a seeded
random number generator, acquisition functino, input sampler and batch size.
random number generator, acquisition function, input sampler and batch size.
end_of_iteration_callback: Optional callback function evaluate at end of each
calibration iteration, for example for logging metrics or plotting / saving
intermediate outputs. Passed current iteration index, Gaussian process
Expand Down

0 comments on commit 9fa7b99

Please sign in to comment.