diff --git a/.gitignore b/.gitignore
index 1bc313eebfe..a4e199b950b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
 /data/
 /logs/
 /output/
+*.csv
 
 # Spack
 .spack-env/
diff --git a/benchmark.py b/benchmark.py
new file mode 100755
index 00000000000..912a1dbe948
--- /dev/null
+++ b/benchmark.py
@@ -0,0 +1,310 @@
+#!/usr/bin/env python3
+
+"""Dataset and sampler benchmarking script."""
+
+import argparse
+import csv
+import os
+import time
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision.models import resnet18
+
+from torchgeo.datasets import CDL, BoundingBox, Landsat8
+from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler, RandomGeoSampler
+
+
+def set_up_parser() -> argparse.ArgumentParser:
+    """Set up the argument parser.
+
+    Returns:
+        the argument parser
+    """
+    parser = argparse.ArgumentParser(
+        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--landsat-root",
+        default=os.path.join("data", "landsat"),
+        help="directory containing Landsat data",
+        metavar="ROOT",
+    )
+    parser.add_argument(
+        "--cdl-root",
+        default=os.path.join("data", "cdl"),
+        help="directory containing CDL data",
+        metavar="ROOT",
+    )
+    parser.add_argument(
+        "-d", "--device", default=0, type=int, help="CPU/GPU ID to use", metavar="ID"
+    )
+    parser.add_argument(
+        "-c",
+        "--cache",
+        action="store_true",
+        help="cache file handles during data loading",
+    )
+    parser.add_argument(
+        "-b",
+        "--batch-size",
+        default=2 ** 4,
+        type=int,
+        help="number of samples in each mini-batch",
+        metavar="SIZE",
+    )
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument(
+        "-n",
+        "--num-batches",
+        type=int,
+        help="number of batches to load",
+        metavar="SIZE",
+    )
+    group.add_argument(
+        "-e",
+        "--epoch-size",
+        type=int,
+        help="number of samples to load, should be evenly divisible by batch size",
+        metavar="SIZE",
+    )
+    parser.add_argument(
+        "-p",
+        "--patch-size",
+        default=224,
+        type=int,
+        help="height/width of each patch",
+        metavar="SIZE",
+    )
+    parser.add_argument(
+        "-s",
+        "--stride",
+        default=2 ** 7,
+        type=int,
+        help="sampling stride for GridGeoSampler",
+    )
+    parser.add_argument(
+        "-w",
+        "--num-workers",
+        default=0,
+        type=int,
+        help="number of workers for parallel data loading",
+        metavar="NUM",
+    )
+    parser.add_argument(
+        "--seed",
+        default=0,
+        type=int,
+        help="random seed for reproducibility",
+    )
+    parser.add_argument(
+        "--output-fn",
+        default="benchmark-results.csv",
+        type=str,
+        help="path to the CSV file to write results",
+        metavar="FILE",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        action="store_true",
+        help="print results to stdout",
+    )
+
+    return parser
+
+
+def main(args: argparse.Namespace) -> None:
+    """High-level pipeline.
+
+    Benchmarks performance of various samplers with and without caching.
+
+    Args:
+        args: command-line arguments
+    """
+    bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
+
+    # Benchmark samplers
+
+    # Initialize datasets
+    cdl = CDL(args.cdl_root, cache=args.cache)
+    landsat = Landsat8(
+        args.landsat_root, crs=cdl.crs, res=cdl.res, cache=args.cache, bands=bands
+    )
+    dataset = landsat + cdl
+
+    # Initialize samplers
+    if args.epoch_size:
+        length = args.epoch_size
+        num_batches = args.epoch_size // args.batch_size
+    elif args.num_batches:
+        length = args.num_batches * args.batch_size
+        num_batches = args.num_batches
+
+    # Workaround for https://github.com/microsoft/torchgeo/issues/149
+    roi = BoundingBox(
+        -2000000, 2200000, 280000, 3170000, dataset.bounds.mint, dataset.bounds.maxt
+    )
+    samplers = [
+        RandomGeoSampler(
+            landsat.index,
+            size=args.patch_size,
+            length=length,
+            roi=roi,
+        ),
+        GridGeoSampler(
+            landsat.index, size=args.patch_size, stride=args.stride, roi=roi
+        ),
+        RandomBatchGeoSampler(
+            landsat.index,
+            size=args.patch_size,
+            batch_size=args.batch_size,
+            length=length,
+            roi=roi,
+        ),
+    ]
+
+    results_rows = []
+    for sampler in samplers:
+        if args.verbose:
+            print(f"\n{sampler.__class__.__name__}:")
+
+        if isinstance(sampler, RandomBatchGeoSampler):
+            dataloader = DataLoader(
+                dataset,
+                batch_sampler=sampler,  # type: ignore[arg-type]
+                num_workers=args.num_workers,
+            )
+        else:
+            dataloader = DataLoader(
+                dataset,
+                batch_size=args.batch_size,
+                sampler=sampler,  # type: ignore[arg-type]
+                num_workers=args.num_workers,
+            )
+
+        tic = time.time()
+        num_total_patches = 0
+        for i, batch in enumerate(dataloader):
+            num_total_patches += args.batch_size
+            # This is to stop the GridGeoSampler from enumerating everything
+            if i == num_batches - 1:
+                break
+        toc = time.time()
+        duration = toc - tic
+
+        if args.verbose:
+            print(f"  duration: {duration:.3f} sec")
+            print(f"  count: {num_total_patches} patches")
+            print(f"  rate: {num_total_patches / duration:.3f} patches/sec")
+
+        if args.cache:
+            if args.verbose:
+                print(landsat._cached_load_warp_file.cache_info())
+
+            # Clear cache for fair comparison between samplers
+            # Both `landsat` and `cdl` share the same cache
+            landsat._cached_load_warp_file.cache_clear()
+
+        results_rows.append(
+            {
+                "cached": args.cache,
+                "seed": args.seed,
+                "duration": duration,
+                "count": num_total_patches,
+                "rate": num_total_patches / duration,
+                "sampler": sampler.__class__.__name__,
+                "batch_size": args.batch_size,
+                "num_workers": args.num_workers,
+            }
+        )
+
+    # Benchmark model
+    model = resnet18()
+    # Change number of input channels to match Landsat
+    model.conv1 = nn.Conv2d(  # type: ignore[attr-defined]
+        len(bands), 64, kernel_size=7, stride=2, padding=3, bias=False
+    )
+
+    criterion = nn.CrossEntropyLoss()  # type: ignore[attr-defined]
+    params = model.parameters()
+    optimizer = optim.SGD(params, lr=0.0001)
+
+    device = torch.device(  # type: ignore[attr-defined]
+        "cuda" if torch.cuda.is_available() else "cpu", args.device
+    )
+    model = model.to(device)
+
+    tic = time.time()
+    num_total_patches = 0
+    for _ in range(num_batches):
+        num_total_patches += args.batch_size
+        x = torch.rand(args.batch_size, len(bands), args.patch_size, args.patch_size)
+        # y = torch.randint(0, 256, (args.batch_size, args.patch_size, args.patch_size))
+        y = torch.randint(0, 256, (args.batch_size,))  # type: ignore[attr-defined]
+        x = x.to(device)
+        y = y.to(device)
+
+        optimizer.zero_grad()
+        prediction = model(x)
+        loss = criterion(prediction, y)
+        loss.backward()
+        optimizer.step()
+
+    toc = time.time()
+    duration = toc - tic
+
+    if args.verbose:
+        print("\nResNet-18:")
+        print(f"  duration: {duration:.3f} sec")
+        print(f"  count: {num_total_patches} patches")
+        print(f"  rate: {num_total_patches / duration:.3f} patches/sec")
+
+    results_rows.append(
+        {
+            "cached": args.cache,
+            "seed": args.seed,
+            "duration": duration,
+            "count": num_total_patches,
+            "rate": num_total_patches / duration,
+            "sampler": "resnet18",
+            "batch_size": args.batch_size,
+            "num_workers": args.num_workers,
+        }
+    )
+
+    fieldnames = [
+        "cached",
+        "seed",
+        "duration",
+        "count",
+        "rate",
+        "sampler",
+        "batch_size",
+        "num_workers",
+    ]
+    if not os.path.exists(args.output_fn):
+        with open(args.output_fn, "w") as f:
+            writer = csv.DictWriter(f, fieldnames=fieldnames)
+            writer.writeheader()
+    with open(args.output_fn, "a") as f:
+        writer = csv.DictWriter(f, fieldnames=fieldnames)
+        writer.writerows(results_rows)
+
+
+if __name__ == "__main__":
+    os.environ["GDAL_CACHEMAX"] = "50%"
+
+    parser = set_up_parser()
+    args = parser.parse_args()
+
+    if args.epoch_size:
+        assert args.epoch_size % args.batch_size == 0
+
+    pl.seed_everything(args.seed)
+
+    main(args)
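The cache_info() and cache_clear() calls on landsat._cached_load_warp_file above rely on the warp-file loader being memoized, which is what exposes those two methods. A minimal sketch of that pattern, assuming a functools.lru_cache-style decorator (the load_warp_file function below is illustrative, not TorchGeo's actual internals):

from functools import lru_cache


@lru_cache(maxsize=128)
def load_warp_file(filepath: str) -> object:
    """Open and warp a raster file.

    Repeated calls with the same path are served from the cache
    instead of touching disk again.
    """
    ...


# Mirroring benchmark.py: inspect hit/miss counts after a run, then
# reset the cache so the next sampler starts cold.
print(load_warp_file.cache_info())
load_warp_file.cache_clear()

Clearing between samplers matters because all three samplers read the same files; without it, the second and third samplers would benchmark warm-cache performance only.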
diff --git a/run_benchmarks_experiments.py b/run_benchmarks_experiments.py
new file mode 100755
index 00000000000..b6a83a46fbc
--- /dev/null
+++ b/run_benchmarks_experiments.py
@@ -0,0 +1,44 @@
+"""Script for running benchmark.py over a sweep of different options."""
+import itertools
+import subprocess
+import time
+from typing import List
+
+NUM_BATCHES = 100
+
+SEED_OPTIONS = [0, 1, 2]
+CACHE_OPTIONS = [True, False]
+BATCH_SIZE_OPTIONS = [16, 32, 64, 128, 256, 512]
+
+total_num_experiments = len(SEED_OPTIONS) * len(CACHE_OPTIONS) * len(BATCH_SIZE_OPTIONS)
+
+if __name__ == "__main__":
+
+    tic = time.time()
+    for i, (cache, batch_size, seed) in enumerate(
+        itertools.product(CACHE_OPTIONS, BATCH_SIZE_OPTIONS, SEED_OPTIONS)
+    ):
+        print(f"{i}/{total_num_experiments} -- {time.time() - tic}")
+        tic = time.time()
+        command: List[str] = [
+            "python",
+            "benchmark.py",
+            "--landsat-root",
+            "/datadrive/landsat",
+            "--cdl-root",
+            "/datadrive/cdl",
+            "-w",
+            "6",
+            "--batch-size",
+            str(batch_size),
+            "--num-batches",
+            str(NUM_BATCHES),
+            "--seed",
+            str(seed),
+            "--verbose",
+        ]
+
+        if cache:
+            command.append("--cache")
+
+        subprocess.call(command)
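Because every invocation of benchmark.py appends its rows to the same CSV (the header is written only when the file is first created), the sweep's combined output can be summarized by averaging the rate column over the three seeds for each configuration. A minimal post-processing sketch, not part of this changeset, using only the column names from the fieldnames list in benchmark.py and assuming the default output path:

import csv
from collections import defaultdict
from statistics import mean

# Average patches/sec over seeds for each experiment configuration.
rates = defaultdict(list)
with open("benchmark-results.csv") as f:
    for row in csv.DictReader(f):
        key = (row["sampler"], row["cached"], row["batch_size"])
        rates[key].append(float(row["rate"]))

for (sampler, cached, batch_size), values in sorted(rates.items()):
    print(
        f"{sampler} cached={cached} batch_size={batch_size}: "
        f"{mean(values):.3f} patches/sec"
    )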