From f1189768c3ba2edcc998d05c498e3d00933e2e8e Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Fri, 13 Dec 2024 12:49:18 +0000 Subject: [PATCH 1/8] feat: Initial implementation of video benchmark #892 --- benchmark/pounce_benchmark.py | 87 +++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 benchmark/pounce_benchmark.py diff --git a/benchmark/pounce_benchmark.py b/benchmark/pounce_benchmark.py new file mode 100644 index 00000000..ec0029af --- /dev/null +++ b/benchmark/pounce_benchmark.py @@ -0,0 +1,87 @@ +# © Crown Copyright GCHQ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Benchmark performance of different coreset algorithms on frames of a video. + +The benchmarking process follows these steps: +1. Load an input GIF and preprocess its frames. +2. Reshape the frame data and apply UMAP for dimensionality reduction. +3. Generate coresets using each algorithm and save the selected frames as GIFs. +4. Print the time taken to generate each coreset. +""" + +import os +import time +from pathlib import Path + +import imageio +import numpy as np +import umap +from jax import random +from mnist_benchmark import get_solver_name, initialise_solvers + +from coreax.data import Data + + +def benchmark_coreset_algorithms( + in_path: Path = Path("../examples/data/pounce/pounce.gif"), + out_dir: Path = Path("pounce"), + coreset_size: int = 10, +): + """ + Benchmark coreset algorithms by processing a video GIF. + + :param in_path: Path to the input GIF file. + :param out_dir: Directory to save the output GIFs for each coreset algorithm. + :param coreset_size: The size of the coreset. + """ + base_dir = os.path.dirname(os.path.abspath(__file__)) + + # Ensure paths are absolute and output directory exists + in_path = Path(os.path.join(base_dir, in_path)).resolve() + out_dir = Path(os.path.join(base_dir, out_dir)).resolve() + out_dir.mkdir(parents=True, exist_ok=True) + + # Load and preprocess video frames + _, *image_data = imageio.v2.mimread(in_path) + raw_data = np.asarray(image_data) + reshaped_data = raw_data.reshape(raw_data.shape[0], -1) + print(type(reshaped_data)) + + umap_model = umap.UMAP(densmap=True, n_components=25) + umap_data = umap_model.fit_transform(reshaped_data) + + solvers = initialise_solvers(umap_data, random.PRNGKey(45)) + + for get_solver in solvers: + solver = get_solver(coreset_size) + solver_name = get_solver_name(solver) + data = Data(umap_data) + + start_time = time.perf_counter() + coreset, _ = solver.reduce(data) + duration = time.perf_counter() - start_time + + selected_indices = np.sort(np.asarray(coreset.unweighted_indices)) + + # Extract corresponding frames from original data and save GIF + coreset_frames = raw_data[selected_indices] + output_gif_path = out_dir / f"{solver_name}_coreset.gif" + imageio.mimsave(output_gif_path, coreset_frames) + print(f"Saved {solver_name} coreset GIF to {output_gif_path}") + print(f"time taken: {solver_name:<25} {duration:<30.4f}") + + +if __name__ == "__main__": + benchmark_coreset_algorithms() From bbfd64e1004b2ca32edaef603ea2b500972bdc53 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Fri, 13 Dec 2024 14:45:30 +0000 Subject: [PATCH 2/8] docs: Add video benchmarking to existing benchmarks #892 --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 09b1ab71..8b3c93d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 and KSD metrics) (https://github.com/gchq/coreax/pull/802) - David (extract pixel locations and values from an image and plot coresets side by side for visual benchmarking) (https://github.com/gchq/coreax/pull/880) + - Pounce (extract frames from a video and use coreset algorithms to select the best + frames) (https://github.com/gchq/coreax/issues/892) - `benchmark` dependency group for benchmarking dependencies. (https://github.com/gchq/coreax/pull/888) - Added a method `SquaredExponentialKernel.get_sqrt_kernel` which returns a square From ac022c88cd422b686364eed51ac4e128bb676782 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Fri, 13 Dec 2024 15:29:52 +0000 Subject: [PATCH 3/8] chore: remove pylint comments to address pre-commit complains on remote #892 --- benchmark/mnist_benchmark_coresets_only.py | 1 - examples/david_map_reduce_weighted.py | 1 - 2 files changed, 2 deletions(-) diff --git a/benchmark/mnist_benchmark_coresets_only.py b/benchmark/mnist_benchmark_coresets_only.py index 6fb5398e..f1aa0950 100644 --- a/benchmark/mnist_benchmark_coresets_only.py +++ b/benchmark/mnist_benchmark_coresets_only.py @@ -99,7 +99,6 @@ def main() -> None: # Run the experiment with 5 different random keys for i in range(5): - print(f"Run {i + 1} of 5:") key = jax.random.PRNGKey(i) solvers = initialise_solvers(train_data_umap, key) for getter in solvers: diff --git a/examples/david_map_reduce_weighted.py b/examples/david_map_reduce_weighted.py index d0acebf0..58133088 100644 --- a/examples/david_map_reduce_weighted.py +++ b/examples/david_map_reduce_weighted.py @@ -63,7 +63,6 @@ # Examples are written to be easy to read, copy and paste by users, so we ignore the # pylint warnings raised that go against this approach -# pylint: disable=no-member # pylint: disable=too-many-locals # pylint: disable=too-many-statements # pylint: disable=duplicate-code From 7cbb306ef8cca5eaaf39c237c282aa43c7c6041c Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Tue, 17 Dec 2024 12:47:34 +0000 Subject: [PATCH 4/8] chore: Set regularise to False for stein_solver as it does not behave properly for high dimensional data with regularise enabled and pounce_benchmark does not use MapReduce #892 --- benchmark/mnist_benchmark.py | 4 +++- benchmark/pounce_benchmark.py | 8 ++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/benchmark/mnist_benchmark.py b/benchmark/mnist_benchmark.py index bf0a49f2..94c9b483 100644 --- a/benchmark/mnist_benchmark.py +++ b/benchmark/mnist_benchmark.py @@ -474,7 +474,9 @@ def _get_stein_solver(_size: int) -> MapReduce: train_data_umap[idx] ) stein_kernel = SteinKernel(kernel, score_function) - stein_solver = SteinThinning(coreset_size=_size, kernel=stein_kernel) + stein_solver = SteinThinning( + coreset_size=_size, kernel=stein_kernel, regularise=False + ) return MapReduce(stein_solver, leaf_size=3 * _size) def _get_random_solver(_size: int) -> RandomSample: diff --git a/benchmark/pounce_benchmark.py b/benchmark/pounce_benchmark.py index ec0029af..1d3d8c77 100644 --- a/benchmark/pounce_benchmark.py +++ b/benchmark/pounce_benchmark.py @@ -63,7 +63,11 @@ def benchmark_coreset_algorithms( umap_data = umap_model.fit_transform(reshaped_data) solvers = initialise_solvers(umap_data, random.PRNGKey(45)) - + # There is no need to use MapReduce as the data-size is small + solvers = [ + solver.base_solver if solver.__class__.__name__ == "MapReduce" else solver + for solver in solvers + ] for get_solver in solvers: solver = get_solver(coreset_size) solver_name = get_solver_name(solver) @@ -78,7 +82,7 @@ def benchmark_coreset_algorithms( # Extract corresponding frames from original data and save GIF coreset_frames = raw_data[selected_indices] output_gif_path = out_dir / f"{solver_name}_coreset.gif" - imageio.mimsave(output_gif_path, coreset_frames) + imageio.mimsave(output_gif_path, coreset_frames, loop=0) print(f"Saved {solver_name} coreset GIF to {output_gif_path}") print(f"time taken: {solver_name:<25} {duration:<30.4f}") From c8e08edf67434e5e940b62a3dc85fb36b2ae2266 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Tue, 17 Dec 2024 14:47:26 +0000 Subject: [PATCH 5/8] chore: Add pylint disable for code repetition. #892 --- tests/performance/compare.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/performance/compare.py b/tests/performance/compare.py index e507e3d9..d9113836 100644 --- a/tests/performance/compare.py +++ b/tests/performance/compare.py @@ -90,6 +90,7 @@ def parse_args() -> Tuple[Path, Path, str, Path]: ) +# pylint: disable=R0801 def date_from_filename(path: Path) -> Optional[Tuple[datetime.datetime, str]]: """ Extract the date from a performance data file name. @@ -129,6 +130,7 @@ def date_from_filename(path: Path) -> Optional[Tuple[datetime.datetime, str]]: ), git_hash +# pylint: disable=R0801 def get_most_recent_historic_data( reference_directory: Path, ) -> FullPerformanceData: From 2f7570d781f25d112c91fdffec6829a698d2e64d Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Thu, 19 Dec 2024 10:43:41 +0000 Subject: [PATCH 6/8] chore: push with regularise set to False to run MNIST benchmark again on the GPU instance. --- benchmark/mnist_benchmark.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/benchmark/mnist_benchmark.py b/benchmark/mnist_benchmark.py index 94c9b483..8b622c30 100644 --- a/benchmark/mnist_benchmark.py +++ b/benchmark/mnist_benchmark.py @@ -111,12 +111,19 @@ def compute_metrics(logits: jnp.ndarray, labels: jnp.ndarray) -> dict[str, jnp.n class MLP(nn.Module): - """Multi-layer perceptron with optional batch normalization and dropout.""" + """ + Multi-layer perceptron with optional batch normalisation and dropout. + + :param hidden_size: Number of units in the hidden layer. + :param output_size: Number of output units. + :param use_batchnorm: Whether to apply batch norm. + :param dropout_rate: Dropout rate to use during training. + """ hidden_size: int output_size: int = 10 use_batchnorm: bool = True - dropout_rate: float = 0.5 + dropout_rate: float = 0.2 @nn.compact def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray: @@ -128,10 +135,11 @@ def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray: :return: Output logits of the network. """ x = nn.Dense(self.hidden_size)(x) + if training: + x = nn.Dropout(rate=self.dropout_rate, deterministic=False)(x) if self.use_batchnorm: x = nn.BatchNorm(use_running_average=not training)(x) x = nn.relu(x) - x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x) x = nn.Dense(self.output_size)(x) return x From 2e96257a3ba8cd390fe5a8c8663d91366cbd411c Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Thu, 19 Dec 2024 11:24:01 +0000 Subject: [PATCH 7/8] Rerun MNIST benchmark on GPU machine. --- benchmark/mnist_benchmark_results.json | 482 ++++++++++++------------- 1 file changed, 241 insertions(+), 241 deletions(-) diff --git a/benchmark/mnist_benchmark_results.json b/benchmark/mnist_benchmark_results.json index de81c888..456c1dd0 100644 --- a/benchmark/mnist_benchmark_results.json +++ b/benchmark/mnist_benchmark_results.json @@ -2,537 +2,537 @@ "RandomSample": { "25": { "0": { - "accuracy": 0.45518216490745544, - "time_taken": 36.396069547000025 + "accuracy": 0.47499004006385803, + "time_taken": 21.129530332 }, "1": { - "accuracy": 0.510104238986969, - "time_taken": 39.19832314600001 + "accuracy": 0.5094035863876343, + "time_taken": 21.620098648999942 }, "2": { - "accuracy": 0.43347349762916565, - "time_taken": 36.08024689900003 + "accuracy": 0.4429771900177002, + "time_taken": 23.587054420000072 }, "3": { - "accuracy": 0.5120047926902771, - "time_taken": 25.586340088000043 + "accuracy": 0.5266107320785522, + "time_taken": 20.613561648000086 }, "4": { - "accuracy": 0.4719884693622589, - "time_taken": 49.6169613750003 + "accuracy": 0.4708881676197052, + "time_taken": 23.3405117719999 } }, "50": { "0": { - "accuracy": 0.6087996363639832, - "time_taken": 12.642228439999997 + "accuracy": 0.6092996001243591, + "time_taken": 7.480583503999981 }, "1": { - "accuracy": 0.6126995086669922, - "time_taken": 19.304289521999976 + "accuracy": 0.6098998785018921, + "time_taken": 9.492431753999995 }, "2": { - "accuracy": 0.5459998846054077, - "time_taken": 18.504275409 + "accuracy": 0.5316997766494751, + "time_taken": 7.016472408000027 }, "3": { - "accuracy": 0.6301995515823364, - "time_taken": 14.654397104999816 + "accuracy": 0.6038997769355774, + "time_taken": 7.433440193000024 }, "4": { - "accuracy": 0.6123995780944824, - "time_taken": 22.13362233299995 + "accuracy": 0.6276999711990356, + "time_taken": 15.224367368999992 } }, "100": { "0": { - "accuracy": 0.7077999114990234, - "time_taken": 10.597413770000003 + "accuracy": 0.704299807548523, + "time_taken": 5.593184323999992 }, "1": { - "accuracy": 0.7308999300003052, - "time_taken": 8.02459470000008 + "accuracy": 0.7321001887321472, + "time_taken": 5.385782182999947 }, "2": { - "accuracy": 0.6015002131462097, - "time_taken": 3.7315618499999346 + "accuracy": 0.7243001461029053, + "time_taken": 9.138968591000094 }, "3": { - "accuracy": 0.7232002019882202, - "time_taken": 6.287373247999994 + "accuracy": 0.7279003262519836, + "time_taken": 3.660555053000053 }, "4": { - "accuracy": 0.6919002532958984, - "time_taken": 8.955407454000124 + "accuracy": 0.6853998899459839, + "time_taken": 6.721555722999938 } }, "500": { "0": { - "accuracy": 0.8521634936332703, - "time_taken": 5.262510822999957 + "accuracy": 0.849459171295166, + "time_taken": 4.1964316430000395 }, "1": { - "accuracy": 0.8403445482254028, - "time_taken": 3.346613910999963 + "accuracy": 0.8390424847602844, + "time_taken": 2.6543577679999544 }, "2": { - "accuracy": 0.8541666865348816, - "time_taken": 3.4243159590000687 + "accuracy": 0.8586738705635071, + "time_taken": 3.7907679610000287 }, "3": { - "accuracy": 0.8365384936332703, - "time_taken": 3.701081079000005 + "accuracy": 0.8433493971824646, + "time_taken": 3.211584036999966 }, "4": { - "accuracy": 0.8471554517745972, - "time_taken": 3.271475813000052 + "accuracy": 0.8508613705635071, + "time_taken": 4.339650910000046 } }, "1000": { "0": { - "accuracy": 0.8824118971824646, - "time_taken": 4.746199452000042 + "accuracy": 0.8806089758872986, + "time_taken": 3.810672871999998 }, "1": { - "accuracy": 0.8761017918586731, - "time_taken": 3.0665597200001002 + "accuracy": 0.8755007982254028, + "time_taken": 2.6573232700000062 }, "2": { - "accuracy": 0.8830128312110901, - "time_taken": 2.602093793999984 + "accuracy": 0.8828125, + "time_taken": 2.6886053809998884 }, "3": { - "accuracy": 0.8777043223381042, - "time_taken": 4.245425575000127 + "accuracy": 0.8731971383094788, + "time_taken": 2.8564365010001893 }, "4": { - "accuracy": 0.8834134936332703, - "time_taken": 2.963790925000012 + "accuracy": 0.8818109035491943, + "time_taken": 3.027272230000108 } }, "5000": { "0": { - "accuracy": 0.9229767918586731, - "time_taken": 5.274251015000004 + "accuracy": 0.9258814454078674, + "time_taken": 4.428039572999978 }, "1": { - "accuracy": 0.926682710647583, - "time_taken": 5.173077755999998 + "accuracy": 0.9238781929016113, + "time_taken": 3.322857870000007 }, "2": { - "accuracy": 0.9291867017745972, - "time_taken": 4.965678029000173 + "accuracy": 0.9277844429016113, + "time_taken": 3.3129011699998046 }, "3": { - "accuracy": 0.9264823794364929, - "time_taken": 4.216483555999957 + "accuracy": 0.9291867017745972, + "time_taken": 4.800562907000085 }, "4": { - "accuracy": 0.9259815812110901, - "time_taken": 4.042372738000267 + "accuracy": 0.9294871687889099, + "time_taken": 3.492984100000058 } } }, "RPCholesky": { "25": { "0": { - "accuracy": 0.5006002187728882, - "time_taken": 41.42257340599997 + "accuracy": 0.47358962893486023, + "time_taken": 19.895207870000036 }, "1": { - "accuracy": 0.556322455406189, - "time_taken": 25.34968879200005 + "accuracy": 0.5369148850440979, + "time_taken": 14.838139795000075 }, "2": { - "accuracy": 0.5873348712921143, - "time_taken": 37.2711402450002 + "accuracy": 0.5467190146446228, + "time_taken": 17.321421823000037 }, "3": { - "accuracy": 0.45908376574516296, - "time_taken": 41.109968239999944 + "accuracy": 0.4283711016178131, + "time_taken": 20.84139633099994 }, "4": { - "accuracy": 0.47118860483169556, - "time_taken": 41.38185949600029 + "accuracy": 0.4912963807582855, + "time_taken": 17.36893893699994 } }, "50": { "0": { - "accuracy": 0.5983994603157043, - "time_taken": 16.419000583999946 + "accuracy": 0.625399649143219, + "time_taken": 13.748352533000002 }, "1": { - "accuracy": 0.6796995997428894, - "time_taken": 15.655551776000038 + "accuracy": 0.6291998624801636, + "time_taken": 8.75706090899996 }, "2": { - "accuracy": 0.5455998182296753, - "time_taken": 23.019029731000046 + "accuracy": 0.5662998557090759, + "time_taken": 6.946477493999964 }, "3": { - "accuracy": 0.6141995191574097, - "time_taken": 20.926626421000037 + "accuracy": 0.6297994256019592, + "time_taken": 11.587993342000118 }, "4": { - "accuracy": 0.5852997303009033, - "time_taken": 19.28357562400015 + "accuracy": 0.5483998656272888, + "time_taken": 14.694229403999998 } }, "100": { "0": { - "accuracy": 0.7150002121925354, - "time_taken": 10.279756443999986 + "accuracy": 0.725000262260437, + "time_taken": 8.696777773000008 }, "1": { - "accuracy": 0.695099949836731, - "time_taken": 9.367277195000042 + "accuracy": 0.7001000046730042, + "time_taken": 5.543955067999946 }, "2": { - "accuracy": 0.6307001709938049, - "time_taken": 6.254802073000064 + "accuracy": 0.6311002373695374, + "time_taken": 4.008666183000059 }, "3": { - "accuracy": 0.671200156211853, - "time_taken": 9.271331661999966 + "accuracy": 0.6850997805595398, + "time_taken": 7.576385713000036 }, "4": { - "accuracy": 0.6927000284194946, - "time_taken": 10.04333458400015 + "accuracy": 0.646899938583374, + "time_taken": 9.464091529000143 } }, "500": { "0": { - "accuracy": 0.8525640964508057, - "time_taken": 6.090376355999979 + "accuracy": 0.8492588400840759, + "time_taken": 5.850811007999994 }, "1": { - "accuracy": 0.8218148946762085, - "time_taken": 3.980846933000066 + "accuracy": 0.8279246687889099, + "time_taken": 3.400919842999997 }, "2": { - "accuracy": 0.8170072436332703, - "time_taken": 5.630132798999966 + "accuracy": 0.8026843070983887, + "time_taken": 4.008895577999965 }, "3": { - "accuracy": 0.8092948794364929, - "time_taken": 4.590955021999889 + "accuracy": 0.8110977411270142, + "time_taken": 4.032388095999977 }, "4": { - "accuracy": 0.840745210647583, - "time_taken": 4.741569309999704 + "accuracy": 0.8451522588729858, + "time_taken": 3.68542980899997 } }, "1000": { "0": { - "accuracy": 0.8627804517745972, - "time_taken": 5.964579186000037 + "accuracy": 0.8804086446762085, + "time_taken": 5.837933116000045 }, "1": { - "accuracy": 0.8595753312110901, - "time_taken": 4.2806823729999905 + "accuracy": 0.8630809187889099, + "time_taken": 4.17050370100003 }, "2": { - "accuracy": 0.8712940812110901, - "time_taken": 5.8189128090000395 + "accuracy": 0.8666867017745972, + "time_taken": 4.835822229000087 }, "3": { - "accuracy": 0.8412460088729858, - "time_taken": 4.6377100379997955 + "accuracy": 0.8480569124221802, + "time_taken": 4.71791837700016 }, "4": { - "accuracy": 0.8733974695205688, - "time_taken": 4.602818280000065 + "accuracy": 0.8788061141967773, + "time_taken": 4.380343813000081 } }, "5000": { "0": { - "accuracy": 0.9246795177459717, - "time_taken": 30.759691682999915 + "accuracy": 0.9239783883094788, + "time_taken": 29.51819806499998 }, "1": { - "accuracy": 0.9284855723381042, - "time_taken": 28.251680083999872 + "accuracy": 0.9287860989570618, + "time_taken": 27.981147562000046 }, "2": { - "accuracy": 0.9284855723381042, - "time_taken": 28.629509002000077 + "accuracy": 0.9247796535491943, + "time_taken": 27.539323438999872 }, "3": { - "accuracy": 0.9278846383094788, - "time_taken": 29.119471999000098 + "accuracy": 0.9285857677459717, + "time_taken": 27.830323820999865 }, "4": { - "accuracy": 0.9231771230697632, - "time_taken": 28.0593601999999 + "accuracy": 0.9238781929016113, + "time_taken": 28.524467420000065 } } }, "KernelHerding": { "25": { "0": { - "accuracy": 0.5423170924186707, - "time_taken": 52.57508162400006 + "accuracy": 0.5501202940940857, + "time_taken": 29.94118039799997 }, "1": { - "accuracy": 0.5205084085464478, - "time_taken": 39.45232254999996 + "accuracy": 0.5024006366729736, + "time_taken": 19.38333930700003 }, "2": { - "accuracy": 0.46438610553741455, - "time_taken": 35.902196312000115 + "accuracy": 0.5388156771659851, + "time_taken": 16.329524966000008 }, "3": { - "accuracy": 0.5081029534339905, - "time_taken": 44.60431595299997 + "accuracy": 0.5127052068710327, + "time_taken": 16.15010204400005 }, "4": { - "accuracy": 0.5087032914161682, - "time_taken": 41.437250334000055 + "accuracy": 0.539915919303894, + "time_taken": 17.088837638000086 } }, "50": { "0": { - "accuracy": 0.6092994213104248, - "time_taken": 20.360353013999998 + "accuracy": 0.6299993395805359, + "time_taken": 17.985974345999978 }, "1": { - "accuracy": 0.6002998352050781, - "time_taken": 14.649381372000107 + "accuracy": 0.6076997518539429, + "time_taken": 10.080095013000005 }, "2": { - "accuracy": 0.5516994595527649, - "time_taken": 9.946566445000144 + "accuracy": 0.6290996670722961, + "time_taken": 10.87603085700016 }, "3": { - "accuracy": 0.5886996984481812, - "time_taken": 17.334741768999947 + "accuracy": 0.6168995499610901, + "time_taken": 12.538436354000169 }, "4": { - "accuracy": 0.6138997673988342, - "time_taken": 17.77710858399996 + "accuracy": 0.5985994935035706, + "time_taken": 9.817184417000135 } }, "100": { "0": { - "accuracy": 0.7017998099327087, - "time_taken": 9.876689930999987 + "accuracy": 0.7302002310752869, + "time_taken": 13.26493101899996 }, "1": { - "accuracy": 0.6914001107215881, - "time_taken": 5.265936994000185 + "accuracy": 0.7350001931190491, + "time_taken": 7.117937374000007 }, "2": { - "accuracy": 0.7000998258590698, - "time_taken": 5.393372559999989 + "accuracy": 0.723000168800354, + "time_taken": 6.991755536000028 }, "3": { - "accuracy": 0.6898000836372375, - "time_taken": 6.27198206900016 + "accuracy": 0.7188998460769653, + "time_taken": 6.8836636179999005 }, "4": { - "accuracy": 0.698900043964386, - "time_taken": 4.835163652000119 + "accuracy": 0.716400146484375, + "time_taken": 5.2314696280000135 } }, "500": { "0": { - "accuracy": 0.8225160241127014, - "time_taken": 7.761980211000036 + "accuracy": 0.8127003312110901, + "time_taken": 8.29185179000001 }, "1": { - "accuracy": 0.8360376954078674, - "time_taken": 4.213319348000141 + "accuracy": 0.8148037195205688, + "time_taken": 3.6316947710000704 }, "2": { - "accuracy": 0.8356370329856873, - "time_taken": 3.8930105310000727 + "accuracy": 0.8196113705635071, + "time_taken": 3.252079743999957 }, "3": { - "accuracy": 0.810396671295166, - "time_taken": 3.072564359000353 + "accuracy": 0.8093950152397156, + "time_taken": 3.455513597000163 }, "4": { - "accuracy": 0.8274238705635071, - "time_taken": 4.178343173999565 + "accuracy": 0.8156049847602844, + "time_taken": 3.3011748870001156 } }, "1000": { "0": { - "accuracy": 0.8530648946762085, - "time_taken": 6.340398364999942 + "accuracy": 0.8576722741127014, + "time_taken": 8.381767883000009 }, "1": { - "accuracy": 0.8548678159713745, - "time_taken": 3.055124343999978 + "accuracy": 0.8568710088729858, + "time_taken": 3.56256258399992 }, "2": { - "accuracy": 0.8617788553237915, - "time_taken": 3.762878829000101 + "accuracy": 0.8465545177459717, + "time_taken": 3.135869393999883 }, "3": { - "accuracy": 0.8614783883094788, - "time_taken": 3.4879642840000997 + "accuracy": 0.860276460647583, + "time_taken": 3.4447939570000017 }, "4": { - "accuracy": 0.8601762652397156, - "time_taken": 3.4571287429998847 + "accuracy": 0.8583734035491943, + "time_taken": 3.61459772000012 } }, "5000": { "0": { - "accuracy": 0.9297876954078674, - "time_taken": 7.996410788999924 + "accuracy": 0.9268830418586731, + "time_taken": 9.059691469000086 }, "1": { - "accuracy": 0.9332932829856873, - "time_taken": 6.031857453999919 + "accuracy": 0.9315905570983887, + "time_taken": 5.148555465000072 }, "2": { - "accuracy": 0.9241787195205688, - "time_taken": 4.493304755000054 + "accuracy": 0.930588960647583, + "time_taken": 5.6641410270001415 }, "3": { - "accuracy": 0.9307892918586731, - "time_taken": 6.0171785980001005 + "accuracy": 0.9274839758872986, + "time_taken": 5.935124415000018 }, "4": { - "accuracy": 0.9332932829856873, - "time_taken": 6.450509147000048 + "accuracy": 0.9286859035491943, + "time_taken": 4.589270231 } } }, "SteinThinning": { "25": { "0": { - "accuracy": 0.38285332918167114, - "time_taken": 61.79404670500003 + "accuracy": 0.36124444007873535, + "time_taken": 32.886642288000075 }, "1": { - "accuracy": 0.37535014748573303, - "time_taken": 55.07096734299989 + "accuracy": 0.3338334560394287, + "time_taken": 41.821553507999965 }, "2": { - "accuracy": 0.38295334577560425, - "time_taken": 59.987997941 + "accuracy": 0.31372541189193726, + "time_taken": 34.356144773000096 }, "3": { - "accuracy": 0.3390357494354248, - "time_taken": 40.479311096999936 + "accuracy": 0.27040842175483704, + "time_taken": 33.386861657000054 }, "4": { - "accuracy": 0.36194491386413574, - "time_taken": 39.54609884499996 + "accuracy": 0.35444176197052, + "time_taken": 36.694613785 } }, "50": { "0": { - "accuracy": 0.4300001561641693, - "time_taken": 33.57654272299999 + "accuracy": 0.46370017528533936, + "time_taken": 37.51419123599999 }, "1": { - "accuracy": 0.43349993228912354, - "time_taken": 31.334013471999924 + "accuracy": 0.43250009417533875, + "time_taken": 33.460962724999945 }, "2": { - "accuracy": 0.4227999746799469, - "time_taken": 32.19259820599996 + "accuracy": 0.4118999242782593, + "time_taken": 29.25043843399999 }, "3": { - "accuracy": 0.39500007033348083, - "time_taken": 33.17475527899978 + "accuracy": 0.40100017189979553, + "time_taken": 27.707170819999874 }, "4": { - "accuracy": 0.4095003008842468, - "time_taken": 35.54211925799973 + "accuracy": 0.45050016045570374, + "time_taken": 28.452544991999957 } }, "100": { "0": { - "accuracy": 0.4675999581813812, - "time_taken": 31.202895077999983 + "accuracy": 0.4580000340938568, + "time_taken": 23.420479623999995 }, "1": { - "accuracy": 0.49570003151893616, - "time_taken": 27.87190726899985 + "accuracy": 0.4966999590396881, + "time_taken": 26.137647441000013 }, "2": { - "accuracy": 0.47530001401901245, - "time_taken": 26.891571627000076 + "accuracy": 0.500999927520752, + "time_taken": 24.4042449760002 }, "3": { - "accuracy": 0.48059985041618347, - "time_taken": 24.772397474999707 + "accuracy": 0.4515998661518097, + "time_taken": 22.571522415000118 }, "4": { - "accuracy": 0.44819989800453186, - "time_taken": 24.519068094999966 + "accuracy": 0.4860999286174774, + "time_taken": 27.187762750000047 } }, "500": { "0": { - "accuracy": 0.5233373641967773, - "time_taken": 32.518364321000035 + "accuracy": 0.5588942170143127, + "time_taken": 20.343176127999982 }, "1": { - "accuracy": 0.5446714758872986, - "time_taken": 21.10401904199989 + "accuracy": 0.5828325152397156, + "time_taken": 19.380917395999973 }, "2": { - "accuracy": 0.5258413553237915, - "time_taken": 21.467372049999994 + "accuracy": 0.5852363705635071, + "time_taken": 19.87485118199993 }, "3": { - "accuracy": 0.5215344429016113, - "time_taken": 21.456703629999993 + "accuracy": 0.5499799847602844, + "time_taken": 22.195402222999974 }, "4": { - "accuracy": 0.5278445482254028, - "time_taken": 22.62931420799987 + "accuracy": 0.5618990659713745, + "time_taken": 19.422155341999996 } }, "1000": { "0": { - "accuracy": 0.5904447436332703, - "time_taken": 27.6814611100001 + "accuracy": 0.5892428159713745, + "time_taken": 21.69759744099997 }, "1": { - "accuracy": 0.5819311141967773, - "time_taken": 24.345591108000008 + "accuracy": 0.5992588400840759, + "time_taken": 22.000713124999947 }, "2": { - "accuracy": 0.5763221383094788, - "time_taken": 22.27670924300014 + "accuracy": 0.5953525900840759, + "time_taken": 23.080259337999905 }, "3": { - "accuracy": 0.5704126954078674, - "time_taken": 25.645224296999913 + "accuracy": 0.5842347741127014, + "time_taken": 20.16238621299999 }, "4": { - "accuracy": 0.5785256624221802, - "time_taken": 22.827015231999667 + "accuracy": 0.5821314454078674, + "time_taken": 20.503732982999963 } }, "5000": { "0": { - "accuracy": 0.6521434187889099, - "time_taken": 38.882569371000045 + "accuracy": 0.6593549847602844, + "time_taken": 29.339991162000047 }, "1": { - "accuracy": 0.6533453464508057, - "time_taken": 34.77338421800005 + "accuracy": 0.6642628312110901, + "time_taken": 29.7060228900001 }, "2": { - "accuracy": 0.6462339758872986, - "time_taken": 37.378769906000116 + "accuracy": 0.6545472741127014, + "time_taken": 30.334544582000035 }, "3": { - "accuracy": 0.6377203464508057, - "time_taken": 34.397472245000245 + "accuracy": 0.6499398946762085, + "time_taken": 29.475541733 }, "4": { - "accuracy": 0.6459335088729858, - "time_taken": 38.09663688599994 + "accuracy": 0.6581530570983887, + "time_taken": 30.22798960299997 } } } -} +} \ No newline at end of file From 874467ad780c36e3576cb482d90c2fe7ef97ed20 Mon Sep 17 00:00:00 2001 From: qh681248 <181246904+qh681248@users.noreply.github.com> Date: Thu, 19 Dec 2024 15:23:26 +0000 Subject: [PATCH 8/8] feat: used path lib consistently #893 --- benchmark/mnist_benchmark_results.json | 2 +- benchmark/pounce_benchmark.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/benchmark/mnist_benchmark_results.json b/benchmark/mnist_benchmark_results.json index 456c1dd0..63d18543 100644 --- a/benchmark/mnist_benchmark_results.json +++ b/benchmark/mnist_benchmark_results.json @@ -535,4 +535,4 @@ } } } -} \ No newline at end of file +} diff --git a/benchmark/pounce_benchmark.py b/benchmark/pounce_benchmark.py index 1d3d8c77..d5ee7801 100644 --- a/benchmark/pounce_benchmark.py +++ b/benchmark/pounce_benchmark.py @@ -21,11 +21,11 @@ 4. Print the time taken to generate each coreset. """ -import os import time from pathlib import Path import imageio +import jax.numpy as jnp import numpy as np import umap from jax import random @@ -46,18 +46,17 @@ def benchmark_coreset_algorithms( :param out_dir: Directory to save the output GIFs for each coreset algorithm. :param coreset_size: The size of the coreset. """ - base_dir = os.path.dirname(os.path.abspath(__file__)) + base_dir = Path(__file__).resolve().parent # Ensure paths are absolute and output directory exists - in_path = Path(os.path.join(base_dir, in_path)).resolve() - out_dir = Path(os.path.join(base_dir, out_dir)).resolve() + in_path = (base_dir / in_path).resolve() + out_dir = (base_dir / out_dir).resolve() out_dir.mkdir(parents=True, exist_ok=True) # Load and preprocess video frames _, *image_data = imageio.v2.mimread(in_path) raw_data = np.asarray(image_data) reshaped_data = raw_data.reshape(raw_data.shape[0], -1) - print(type(reshaped_data)) umap_model = umap.UMAP(densmap=True, n_components=25) umap_data = umap_model.fit_transform(reshaped_data) @@ -71,7 +70,7 @@ def benchmark_coreset_algorithms( for get_solver in solvers: solver = get_solver(coreset_size) solver_name = get_solver_name(solver) - data = Data(umap_data) + data = Data(jnp.array(umap_data)) start_time = time.perf_counter() coreset, _ = solver.reduce(data)