Feature/add probabilistic iterative methods #983

Draft · wants to merge 13 commits into main
Changes from 9 commits
14 changes: 13 additions & 1 deletion benchmark/blobs_benchmark.py
@@ -38,6 +38,7 @@
from sklearn.datasets import make_blobs

from coreax import Data, SlicedScoreMatching
from coreax.benchmark_util import IterativeKernelHerding
from coreax.kernels import (
    SquaredExponentialKernel,
    SteinKernel,
@@ -46,7 +47,6 @@
from coreax.metrics import KSD, MMD
from coreax.solvers import (
    CompressPlusPlus,
    IterativeKernelHerding,
    KernelHerding,
    KernelThinning,
    RandomSample,
@@ -188,6 +188,18 @@ def setup_solvers(
                num_iterations=5,
            ),
        ),
        (
            "CubicProbIterativeHerding",
            IterativeKernelHerding(
                coreset_size=coreset_size,
                kernel=sq_exp_kernel,
                probabilistic=True,
                temperature=0.001,
                random_key=random_key,
                num_iterations=10,
                t_schedule=1 / jnp.linspace(10, 100, 10) ** 3,
            ),
        ),
    ]
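
Reviewer note: the cubic `t_schedule` above anneals the selection temperature from 1e-3 on the first refinement down to 1e-6 on the last, so early iterations sample broadly and later ones become nearly deterministic. A standalone sketch of the schedule, assuming only jax.numpy (illustrative, not part of the diff):

import jax.numpy as jnp

# Temperature for refinement iteration i is 1 / (10 + 10*i)**3 for i = 0..9.
t_schedule = 1 / jnp.linspace(10, 100, 10) ** 3
print(t_schedule[0])   # 0.001 -> most exploratory, first iteration
print(t_schedule[-1])  # 1e-06 -> near-deterministic, final iteration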


81 changes: 75 additions & 6 deletions coreax/benchmark_util.py
@@ -21,18 +21,18 @@
"""

from collections.abc import Callable
from typing import Optional, Union
from typing import Optional, TypeVar, Union

import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Float

from coreax import Data
from coreax import Coresubset, Data, SupervisedData
from coreax.kernels import SquaredExponentialKernel, SteinKernel, median_heuristic
from coreax.score_matching import KernelDensityMatching
from coreax.solvers import (
    CompressPlusPlus,
    IterativeKernelHerding,
    HerdingState,
    KernelHerding,
    KernelThinning,
    MapReduce,
@@ -43,6 +43,46 @@
)
from coreax.util import KeyArrayLike

_Data = TypeVar("_Data", Data, SupervisedData)


class IterativeKernelHerding(KernelHerding[_Data]):  # pylint: disable=too-many-ancestors
    r"""
    Iterative Kernel Herding - perform multiple refinements of Kernel Herding.

    Wrapper around :meth:`~coreax.solvers.KernelHerding.reduce_iterative` for
    benchmarking purposes.

    :param num_iterations: Number of refinement iterations
    :param t_schedule: An :class:`Array` of length `num_iterations`, where
        `t_schedule[i]` is the temperature parameter used for iteration i. If None,
        standard Kernel Herding is used
    """

    num_iterations: int = 1
    t_schedule: Optional[Array] = None

    def reduce(
        self,
        dataset: _Data,
        solver_state: Optional[HerdingState] = None,
    ) -> tuple[Coresubset[_Data], HerdingState]:
        """
        Perform Kernel Herding reduction followed by additional refinement iterations.

        :param dataset: The dataset to process.
        :param solver_state: Optional solver state.
        :return: Refined coresubset and final solver state.
        """
        coreset, reduced_solver_state = self.reduce_iterative(
            dataset,
            solver_state,
            num_iterations=self.num_iterations,
            t_schedule=self.t_schedule,
        )

        return coreset, reduced_solver_state
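
Reviewer note: a minimal usage sketch of this benchmark wrapper (hypothetical standalone example assuming this branch is installed; the dataset, key, and kernel values are assumptions, not part of the diff):

import jax.numpy as jnp
import jax.random as jr

from coreax import Data
from coreax.benchmark_util import IterativeKernelHerding
from coreax.kernels import SquaredExponentialKernel

key = jr.key(0)
data = Data(jr.normal(key, (1_000, 2)))  # hypothetical dataset
solver = IterativeKernelHerding(
    coreset_size=100,
    kernel=SquaredExponentialKernel(length_scale=1.0),
    probabilistic=True,
    temperature=0.001,
    random_key=key,
    num_iterations=10,
    t_schedule=1 / jnp.linspace(10, 100, 10) ** 3,
)
# reduce() defers to KernelHerding.reduce_iterative with the stored settings.
coreset, _ = solver.reduce(data)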


def calculate_delta(n: int) -> Float[Array, "1"]:
r"""
@@ -100,8 +140,7 @@ def initialise_solvers(  # noqa: C901
    # Set up kernel using median heuristic
    num_data_points = len(train_data_umap)
    num_samples_length_scale = min(num_data_points, 300)
    random_seed = 45
    generator = np.random.default_rng(random_seed)
    generator = np.random.default_rng(seed=45)
    idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
    length_scale = median_heuristic(jnp.asarray(train_data_umap[idx]))
    kernel = SquaredExponentialKernel(length_scale=length_scale)
@@ -256,13 +295,43 @@ def _get_iterative_herding_solver(
            return herding_solver
        return MapReduce(herding_solver, leaf_size=leaf_size)

    def _get_cubic_iterative_herding_solver(
        _size: int,
    ) -> Union[IterativeKernelHerding, MapReduce]:
        """
        Set up IterativeKernelHerding with probabilistic selection.

        If `leaf_size` is provided, the solver uses ``MapReduce`` to reduce
        datasets.

        :param _size: The size of the coreset to be generated.
        :return: An `IterativeKernelHerding` solver if `leaf_size` is `None`, otherwise
            a `MapReduce` solver with `IterativeKernelHerding` as the base solver.
        """
        n_iter = 10
        t_schedule = 1 / jnp.linspace(10, 100, n_iter) ** 3

        herding_solver = IterativeKernelHerding(
            coreset_size=_size,
            kernel=kernel,
            probabilistic=True,
            temperature=0.001,
            random_key=key,
            num_iterations=n_iter,
            t_schedule=t_schedule,
        )
        if leaf_size is None:
            return herding_solver
        return MapReduce(herding_solver, leaf_size=leaf_size)

    return {
        "Random Sample": _get_random_solver,
        "RP Cholesky": _get_rp_solver,
        "Kernel Herding": _get_herding_solver,
        "Stein Thinning": _get_stein_solver,
        "Kernel Thinning": _get_thinning_solver,
        "Compress++": _get_compress_solver,
        "Probabilistic Iterative Herding": _get_probabilistic_herding_solver,
        "Iterative Probabilistic Herding (constant)": _get_probabilistic_herding_solver,
        "Iterative Herding": _get_iterative_herding_solver,
        "Iterative Probabilistic Herding (cubic)": _get_cubic_iterative_herding_solver,
    }
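
Reviewer note: each entry in this mapping is a factory taking only the desired coreset size, which keeps the benchmark loop uniform across solvers. A sketch of how it might be consumed (the initialise_solvers arguments are not shown in this diff, so they are elided; names otherwise come from this file):

# Hypothetical consumption of the factory mapping; signature elided as '...'.
solver_factories = initialise_solvers(...)  # returns the dict above
factory = solver_factories["Iterative Probabilistic Herding (cubic)"]
solver = factory(100)  # each factory takes the coreset size, here 100
coreset, _ = solver.reduce(Data(train_data_umap))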
2 changes: 0 additions & 2 deletions coreax/solvers/__init__.py
@@ -27,7 +27,6 @@
    GreedyKernelPoints,
    GreedyKernelPointsState,
    HerdingState,
    IterativeKernelHerding,
    KernelHerding,
    KernelThinning,
    RandomSample,
@@ -62,5 +61,4 @@
"CaratheodoryRecombination",
"TreeRecombination",
"CompressPlusPlus",
"IterativeKernelHerding",
]
125 changes: 51 additions & 74 deletions coreax/solvers/coresubset.py
@@ -347,6 +347,57 @@ def selection_function(
        )
        return refined_coreset, HerdingState(gramian_row_mean)

    def reduce_iterative(
        self,
        dataset: _Data,
        solver_state: Optional[HerdingState] = None,
        num_iterations: int = 1,
        t_schedule: Optional[Shaped[Array, " {num_iterations}"]] = None,
    ) -> tuple[Coresubset[_Data], HerdingState]:
        """
        Reduce a dataset to a coreset by refining iteratively.

        :param dataset: Dataset to reduce
        :param solver_state: Solution state information, primarily used to cache
            expensive intermediate solution step values
        :param num_iterations: Number of iterations of the refine method
        :param t_schedule: An :class:`Array` of length `num_iterations`, where
            `t_schedule[i]` is the temperature parameter used for iteration i. If
            None, standard Kernel Herding is used
        :return: A coresubset and relevant intermediate solver state information
        """
        initial_coreset = _initial_coresubset(0, self.coreset_size, dataset)
        if solver_state is None:
            x, bs, un = initial_coreset.pre_coreset_data, self.block_size, self.unroll
            solver_state = HerdingState(
                self.kernel.gramian_row_mean(x, block_size=bs, unroll=un)
            )

        def refine_iteration(i: int, coreset: Coresubset) -> Coresubset:
            """
            Perform one iteration of the refine method.

            :param i: Iteration number
            :param coreset: Coreset to be refined
            """
            # Update the random key
            new_solver = eqx.tree_at(
                lambda x: x.random_key, self, jr.fold_in(self.random_key, i)
            )
            # If the temperature schedule is provided, update temperature too
            if t_schedule is not None:
                new_solver = eqx.tree_at(
                    lambda x: x.temperature, new_solver, t_schedule[i]
                )

            coreset, _ = new_solver.refine(coreset, solver_state)
            return coreset

        return (
            jax.lax.fori_loop(0, num_iterations, refine_iteration, initial_coreset),
            solver_state,
        )
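
Reviewer note: the refinement loop relies on Equinox's functional update to vary the key (and temperature) per iteration without mutating the immutable solver. A self-contained sketch of that `eqx.tree_at` + `jr.fold_in` pattern on a generic module (illustrative; not coreax code):

import equinox as eqx
import jax
import jax.random as jr


class Sampler(eqx.Module):
    # Minimal stand-in for a solver that carries a random key.
    random_key: jax.Array

    def draw(self) -> jax.Array:
        return jr.normal(self.random_key, ())


sampler = Sampler(random_key=jr.key(0))
for i in range(3):
    # tree_at returns a copy with one leaf replaced; fold_in derives a fresh,
    # deterministic per-iteration key from the base key and the loop index.
    fresh = eqx.tree_at(
        lambda s: s.random_key, sampler, jr.fold_in(sampler.random_key, i)
    )
    print(fresh.draw())  # different value each iteration, reproducible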


class SteinThinning(
    RefinementSolver[_Data, None], ExplicitSizeSolver, PaddingInvariantSolver
@@ -1439,77 +1490,3 @@ def _compress_plus_plus(indices: Array) -> Array:

        plus_plus_indices = _compress_plus_plus(clipped_indices)
        return Coresubset(Data(plus_plus_indices), dataset), None


class IterativeKernelHerding(ExplicitSizeSolver):
    r"""
    Iterative Kernel Herding - perform multiple refinements of Kernel Herding.

    Reduce using :class:`~coreax.solvers.KernelHerding`, then refine a set number of
    times. All parameters except `num_iterations` are passed to
    :class:`~coreax.solvers.KernelHerding`.

    :param coreset_size: The desired size of the solved coreset.
    :param kernel: :class:`~coreax.kernels.ScalarValuedKernel` instance implementing a
        kernel function
        :math:`k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}`.
    :param unique: Boolean that ensures the resulting coresubset will only contain
        unique elements.
    :param block_size: Block size passed to
        :meth:`~coreax.kernels.ScalarValuedKernel.compute_mean`.
    :param unroll: Unroll parameter passed to
        :meth:`~coreax.kernels.ScalarValuedKernel.compute_mean`.
    :param probabilistic: If :data:`True`, the elements are chosen probabilistically at
        each iteration. Otherwise, standard Kernel Herding is run.
    :param temperature: Temperature parameter, which controls how uniform the
        probabilities are for probabilistic selection.
    :param random_key: Key for random number generation, only used if probabilistic
    :param num_iterations: Number of refinement iterations.
    """

    num_iterations: int
    kernel: ScalarValuedKernel
    unique: bool = True
    block_size: Optional[Union[int, tuple[Optional[int], Optional[int]]]] = None
    unroll: Union[int, bool, tuple[Union[int, bool], Union[int, bool]]] = 1
    probabilistic: bool = False
    temperature: Union[float, Scalar] = eqx.field(default=1.0)
    random_key: KeyArrayLike = eqx.field(default_factory=lambda: jax.random.key(0))

    def reduce(
        self,
        dataset: _Data,
        solver_state: Optional[HerdingState] = None,
    ) -> tuple[Coresubset[_Data], HerdingState]:
        """
        Perform Kernel Herding reduction followed by additional refinement iterations.

        :param dataset: The dataset to process.
        :param solver_state: Optional solver state.
        :return: Refined coresubset and final solver state.
        """
        herding_solver = KernelHerding(
            coreset_size=self.coreset_size,
            kernel=self.kernel,
            unique=self.unique,
            block_size=self.block_size,
            unroll=self.unroll,
            probabilistic=self.probabilistic,
            temperature=self.temperature,
            random_key=self.random_key,
        )

        coreset, reduced_solver_state = herding_solver.reduce(dataset, solver_state)

        def refine_step(_, state):
            coreset, reduced_solver_state = state
            coreset, reduced_solver_state = herding_solver.refine(
                coreset, reduced_solver_state
            )
            return (coreset, reduced_solver_state)

        coreset, reduced_solver_state = lax.fori_loop(
            0, self.num_iterations, refine_step, (coreset, reduced_solver_state)
        )

        return coreset, reduced_solver_state
9 changes: 4 additions & 5 deletions documentation/source/benchmark.rst
@@ -3,9 +3,9 @@ Benchmarking Coreset Algorithms

In this benchmark, we assess the performance of different coreset algorithms:
:class:`~coreax.solvers.KernelHerding`, :class:`~coreax.solvers.SteinThinning`,
:class:`~coreax.solvers.RandomSample`, :class:`~coreax.solvers.RPCholesky` and
:class:`~coreax.solvers.KernelThinning`, :class:`~coreax.solvers.CompressPlusPlus`,
:class:`~coreax.solvers.IterativeKernelHerding`. Each of these algorithms is evaluated
:class:`~coreax.solvers.RandomSample`, :class:`~coreax.solvers.RPCholesky`,
:class:`~coreax.solvers.KernelThinning`, and :class:`~coreax.solvers.CompressPlusPlus`.
Each of these algorithms is evaluated
across four different tests, providing a comparison of their performance and
applicability to various datasets.

@@ -29,8 +29,7 @@ these steps:

4. **Coreset Generation**: Coresets of various sizes are generated using the
different coreset algorithms. For :class:`~coreax.solvers.KernelHerding`,
:class:`~coreax.solvers.SteinThinning`, :class:`~coreax.solvers.KernelThinning`, and
:class:`~coreax.solvers.IterativeKernelHerding`,
:class:`~coreax.solvers.SteinThinning`, and :class:`~coreax.solvers.KernelThinning`,
:class:`~coreax.solvers.MapReduce` is employed to handle large-scale data.

5. **Training**: The model is trained using the selected coresets, and accuracy is