Add runner argument to describe datasets #167

Draft · wants to merge 2 commits into main
6 changes: 6 additions & 0 deletions sklbench/runner/arguments.py
@@ -130,6 +130,12 @@ def add_runner_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         action="store_true",
         help="Load all requested datasets in parallel before running benchmarks.",
     )
+    parser.add_argument(
+        "--describe-datasets",
+        default=False,
+        action="store_true",
+        help="Load all requested datasets in parallel and show their parameters.",
+    )
     # workflow control
     parser.add_argument(
         "--exit-on-error",
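For reference, a minimal standalone sketch (not part of the diff) of how the new store_true flag behaves when parsed: args.describe_datasets defaults to False and flips to True only when the option is passed on the command line.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--describe-datasets", default=False, action="store_true")

print(parser.parse_args([]).describe_datasets)                       # False
print(parser.parse_args(["--describe-datasets"]).describe_datasets)  # True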
17 changes: 15 additions & 2 deletions sklbench/runner/implementation.py
@@ -16,7 +16,9 @@


 import argparse
+import gc
 import json
+import sys
 from multiprocessing import Pool
 from typing import Dict, List, Tuple, Union

@@ -94,15 +96,26 @@ def run_benchmarks(args: argparse.Namespace) -> int:
     bench_cases = early_filtering(bench_cases, param_filters)

     # prefetch datasets
-    if args.prefetch_datasets:
+    if args.prefetch_datasets or args.describe_datasets:
         # trick: get unique dataset names only to avoid loading of same dataset
         # by different cases/processes
         dataset_cases = {get_data_name(case): case for case in bench_cases}
         logger.debug(f"Unique dataset names to load:\n{list(dataset_cases.keys())}")
         n_proc = min([16, cpu_count(), len(dataset_cases)])
         logger.info(f"Prefetching datasets with {n_proc} processes")
         with Pool(n_proc) as pool:
-            pool.map(load_data, dataset_cases.values())
+            datasets = pool.map(load_data, dataset_cases.values())
+        if args.describe_datasets:
+            for (data, data_description), data_name in zip(
+                datasets, dataset_cases.keys()
+            ):
+                print(
+                    f"{data_name}:\n\tshape: {data['x'].shape}\n\tparameters: {data_description}"
+                )
+            sys.exit(0)
+        # free memory used by prefetched datasets
+        del datasets
+        gc.collect()

     # run bench_cases
     return_code, result = call_benchmarks(
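Below is a standalone sketch of the prefetch-and-describe flow introduced above. It assumes a toy load_data stub in place of sklbench's real loader (which, per the diff, returns a (data, data_description) pair with data["x"] holding the feature array) and hypothetical dataset cases; it loads the unique cases in a process pool, prints each dataset's shape and parameters, and exits before any benchmarks would run.

import sys
from multiprocessing import Pool, cpu_count

import numpy as np


def load_data(case):
    # Hypothetical stub standing in for sklbench's loader:
    # returns (data, data_description) like the real function.
    rows, cols = case["n_rows"], case["n_cols"]
    data = {"x": np.zeros((rows, cols))}
    return data, {"source": "synthetic", "n_rows": rows, "n_cols": cols}


if __name__ == "__main__":
    # Unique dataset cases, keyed by name (illustrative values only).
    dataset_cases = {
        "small dataset": {"n_rows": 100, "n_cols": 4},
        "large dataset": {"n_rows": 10_000, "n_cols": 16},
    }
    n_proc = min([16, cpu_count(), len(dataset_cases)])
    with Pool(n_proc) as pool:
        datasets = pool.map(load_data, dataset_cases.values())
    for (data, data_description), data_name in zip(datasets, dataset_cases.keys()):
        print(f"{data_name}:\n\tshape: {data['x'].shape}\n\tparameters: {data_description}")
    sys.exit(0)

In the actual runner, this path would presumably be triggered by passing --describe-datasets, which per the changes above also implies prefetching and exits before any benchmark cases run.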