diff --git a/sklbench/runner/arguments.py b/sklbench/runner/arguments.py
index 1ba47daa..3f114dd0 100644
--- a/sklbench/runner/arguments.py
+++ b/sklbench/runner/arguments.py
@@ -130,6 +130,12 @@ def add_runner_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentPa
         action="store_true",
         help="Load all requested datasets in parallel before running benchmarks.",
     )
+    parser.add_argument(
+        "--describe-datasets",
+        default=False,
+        action="store_true",
+        help="Load all requested datasets in parallel and show their parameters.",
+    )
     # workflow control
     parser.add_argument(
         "--exit-on-error",
diff --git a/sklbench/runner/implementation.py b/sklbench/runner/implementation.py
index 2375e4b7..47b10962 100644
--- a/sklbench/runner/implementation.py
+++ b/sklbench/runner/implementation.py
@@ -16,7 +16,9 @@
 
 import argparse
+import gc
 import json
+import sys
 from multiprocessing import Pool
 from typing import Dict, List, Tuple, Union
 
@@ -94,7 +96,7 @@ def run_benchmarks(args: argparse.Namespace) -> int:
     bench_cases = early_filtering(bench_cases, param_filters)
 
     # prefetch datasets
-    if args.prefetch_datasets:
+    if args.prefetch_datasets or args.describe_datasets:
         # trick: get unique dataset names only to avoid loading of same dataset
         # by different cases/processes
         dataset_cases = {get_data_name(case): case for case in bench_cases}
@@ -102,7 +104,18 @@ def run_benchmarks(args: argparse.Namespace) -> int:
         n_proc = min([16, cpu_count(), len(dataset_cases)])
         logger.info(f"Prefetching datasets with {n_proc} processes")
         with Pool(n_proc) as pool:
-            pool.map(load_data, dataset_cases.values())
+            datasets = pool.map(load_data, dataset_cases.values())
+        if args.describe_datasets:
+            for (data, data_description), data_name in zip(
+                datasets, dataset_cases.keys()
+            ):
+                print(
+                    f"{data_name}:\n\tshape: {data['x'].shape}\n\tparameters: {data_description}"
+                )
+            sys.exit(0)
+        # free memory used by prefetched datasets
+        del datasets
+        gc.collect()
 
     # run bench_cases
     return_code, result = call_benchmarks(
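
A note on behavior: --describe-datasets reuses the prefetch path, prints each unique dataset's name, shape, and parameters, then stops via sys.exit(0) before any benchmark runs; with --prefetch-datasets alone, the loaded results are instead discarded (del plus gc.collect()) so the memory is returned before the benchmark loop starts. A possible invocation, assuming the usual python -m sklbench entry point (the rest of the command line is elided, not part of this diff):

    python -m sklbench --describe-datasets ...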
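
The prefetch path also illustrates a general pattern: deduplicate work items under a stable key, then fan the unique items out to a process pool. Below is a minimal, self-contained sketch of that pattern; the case dictionaries and the loader are hypothetical stand-ins (sklbench's real load_data and get_data_name have different signatures), shown only to illustrate the dedup-then-Pool.map flow:

    from multiprocessing import Pool, cpu_count


    def load_data(case):
        # Hypothetical loader standing in for sklbench's load_data, which
        # returns a (data, data_description) tuple for a benchmark case.
        data = {"x": [[0.0] * case["n_features"] for _ in range(case["n_samples"])]}
        return data, {"source": "synthetic", **case}


    if __name__ == "__main__":
        bench_cases = [
            {"dataset": "synth_small", "n_samples": 100, "n_features": 4},
            {"dataset": "synth_small", "n_samples": 100, "n_features": 4},  # duplicate
            {"dataset": "synth_large", "n_samples": 1000, "n_features": 8},
        ]
        # Keying by dataset name keeps exactly one case per dataset, so the
        # same dataset is never loaded twice by different cases/processes.
        dataset_cases = {case["dataset"]: case for case in bench_cases}
        n_proc = min(16, cpu_count(), len(dataset_cases))
        with Pool(n_proc) as pool:
            datasets = pool.map(load_data, dataset_cases.values())
        # dicts preserve insertion order, so keys() pairs up with the results.
        for (data, description), name in zip(datasets, dataset_cases.keys()):
            print(f"{name}: {len(data['x'])}x{len(data['x'][0])} -> {description}")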