Commit
Merge pull request #893 from gchq/feature/video_benchmarking
feat: Initial implementation of video benchmark
bk958178 authored Dec 19, 2024
2 parents c7ef0f7 + 874467a commit 6d9546f
Showing 7 changed files with 348 additions and 246 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
and KSD metrics) (https://github.com/gchq/coreax/pull/802)
- David (extract pixel locations and values from an image and plot coresets side by
side for visual benchmarking) (https://github.com/gchq/coreax/pull/880)
- Pounce (extract frames from a video and use coreset algorithms to select the best
frames) (https://github.com/gchq/coreax/issues/892)
- `benchmark` dependency group for benchmarking dependencies.
(https://github.com/gchq/coreax/pull/888)
- Added a method `SquaredExponentialKernel.get_sqrt_kernel` which returns a square
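The new "Pounce" entry describes the video benchmark added by this pull request. As a rough illustration of that idea only, here is a minimal, hypothetical sketch of extracting frames from a video and selecting representative frames with a coreset solver. The function `extract_frames`, the OpenCV dependency, and the exact coreax import paths and solver choice are assumptions for illustration; they are not taken from this diff, and the real implementation lives in the benchmark file added by this commit (not shown in this excerpt).

```python
# Hypothetical sketch only -- not the implementation added in this commit.
import cv2  # assumed dependency for frame extraction
import jax.numpy as jnp

from coreax import Data
from coreax.kernels import SquaredExponentialKernel
from coreax.solvers import KernelHerding


def extract_frames(video_path: str) -> jnp.ndarray:
    """Read a video and return its frames as a (num_frames, height * width * 3) array."""
    capture = cv2.VideoCapture(video_path)
    frames = []
    while True:
        success, frame = capture.read()
        if not success:
            break
        frames.append(jnp.asarray(frame, dtype=jnp.float32).ravel())
    capture.release()
    return jnp.stack(frames)


frames = extract_frames("input_video.mp4")
solver = KernelHerding(coreset_size=20, kernel=SquaredExponentialKernel())
coreset, _ = solver.reduce(Data(frames))  # the selected "best" frames
```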
18 changes: 14 additions & 4 deletions benchmark/mnist_benchmark.py
@@ -111,12 +111,19 @@ def compute_metrics(logits: jnp.ndarray, labels: jnp.ndarray) -> dict[str, jnp.n


class MLP(nn.Module):
"""Multi-layer perceptron with optional batch normalization and dropout."""
"""
Multi-layer perceptron with optional batch normalisation and dropout.
:param hidden_size: Number of units in the hidden layer.
:param output_size: Number of output units.
:param use_batchnorm: Whether to apply batch norm.
:param dropout_rate: Dropout rate to use during training.
"""

hidden_size: int
output_size: int = 10
use_batchnorm: bool = True
dropout_rate: float = 0.5
dropout_rate: float = 0.2

@nn.compact
def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
@@ -128,10 +135,11 @@ def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
:return: Output logits of the network.
"""
x = nn.Dense(self.hidden_size)(x)
if training:
x = nn.Dropout(rate=self.dropout_rate, deterministic=False)(x)
if self.use_batchnorm:
x = nn.BatchNorm(use_running_average=not training)(x)
x = nn.relu(x)
x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x)
x = nn.Dense(self.output_size)(x)
return x

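For context on the updated `MLP` (dropout applied only when `training`, optional batch normalisation, dropout rate default lowered to 0.2), here is a hedged usage sketch showing how such a Flax module is typically initialised and applied. The batch shape, seed handling, and variable names are illustrative assumptions and do not appear in this diff.

```python
# Hedged usage sketch (not part of the diff).
import jax
import jax.numpy as jnp

model = MLP(hidden_size=64)   # output_size defaults to 10, dropout_rate to 0.2
x = jnp.ones((32, 784))       # dummy batch of flattened 28x28 images (assumed shape)

params_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
# Initialise in evaluation mode: no dropout RNG needed, batch_stats still created.
variables = model.init(params_key, x, training=False)

# Training-mode forward pass: dropout is active and batch statistics are updated.
logits, updates = model.apply(
    variables,
    x,
    training=True,
    rngs={"dropout": dropout_key},
    mutable=["batch_stats"],
)
```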
@@ -474,7 +482,9 @@ def _get_stein_solver(_size: int) -> MapReduce:
train_data_umap[idx]
)
stein_kernel = SteinKernel(kernel, score_function)
stein_solver = SteinThinning(coreset_size=_size, kernel=stein_kernel)
stein_solver = SteinThinning(
coreset_size=_size, kernel=stein_kernel, regularise=False
)
return MapReduce(stein_solver, leaf_size=3 * _size)

def _get_random_solver(_size: int) -> RandomSample:
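The hunk above builds a `SteinThinning` solver (now with `regularise=False`) wrapped in `MapReduce`. As a hedged sketch of how a solver returned by one of these getters might then be used, assuming the usual coreax `reduce` workflow (this call is not shown in the diff):

```python
# Assumed usage, not taken from this commit.
from coreax import Data

mapreduce_solver = _get_stein_solver(100)                 # MapReduce-wrapped SteinThinning
coreset, _ = mapreduce_solver.reduce(Data(train_data_umap))
```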
1 change: 0 additions & 1 deletion benchmark/mnist_benchmark_coresets_only.py
@@ -99,7 +99,6 @@ def main() -> None:

# Run the experiment with 5 different random keys
for i in range(5):
print(f"Run {i + 1} of 5:")
key = jax.random.PRNGKey(i)
solvers = initialise_solvers(train_data_umap, key)
for getter in solvers:
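Only the seed loop and the removed progress print appear in this hunk. As a hedged sketch of how results from the five seeded runs might be aggregated (the evaluation helper `metrics_for_run` is a hypothetical stand-in for the benchmark's real evaluation code, which is not shown here):

```python
# Assumed aggregation sketch; only the seed loop itself is in this diff.
import jax
import jax.numpy as jnp

all_results = []
for i in range(5):
    key = jax.random.PRNGKey(i)
    solvers = initialise_solvers(train_data_umap, key)
    all_results.append(metrics_for_run(solvers))  # hypothetical evaluation helper

# Report the mean and spread of the benchmark metric across the five random keys.
results = jnp.asarray(all_results)
mean_metric = jnp.mean(results, axis=0)
std_metric = jnp.std(results, axis=0)
```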