
Add dataset/sampler benchmarking script #115

Merged · 17 commits · Sep 22, 2021
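For context, a representative invocation of the new script might look like the following; the data roots shown are the parser defaults, so point them at wherever the Landsat 8 and CDL rasters actually live:

python benchmark.py --landsat-root data/landsat --cdl-root data/cdl --num-batches 100 --batch-size 16 --num-workers 6 --cache --verbose
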
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
/data/
/logs/
/output/
*.csv

# Spack
.spack-env/
307 changes: 307 additions & 0 deletions benchmark.py
@@ -0,0 +1,307 @@
#!/usr/bin/env python3

"""dataset and sampler benchmarking script."""

import argparse
import csv
import os
import time

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18

from torchgeo.datasets import CDL, BoundingBox, Landsat8
from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler, RandomGeoSampler


def set_up_parser() -> argparse.ArgumentParser:
"""Set up the argument parser.

Returns:
the argument parser
"""
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--landsat-root",
default=os.path.join("data", "landsat"),
help="directory containing Landsat data",
metavar="ROOT",
)
parser.add_argument(
"--cdl-root",
default=os.path.join("data", "cdl"),
help="directory containing CDL data",
metavar="ROOT",
)
parser.add_argument(
"-c",
"--cache",
action="store_true",
help="cache file handles during data loading",
)
parser.add_argument(
"-b",
"--batch-size",
default=2 ** 4,
type=int,
help="number of samples in each mini-batch",
metavar="SIZE",
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-n",
"--num-batches",
type=int,
help="number of batches to load",
metavar="SIZE",
)
group.add_argument(
"-e",
"--epoch-size",
type=int,
help="number of samples to load, should be evenly divisible by batch size",
metavar="SIZE",
)
parser.add_argument(
"-p",
"--patch-size",
default=224,
type=int,
help="height/width of each patch",
metavar="SIZE",
)
parser.add_argument(
"-s",
"--stride",
default=2 ** 7,
type=int,
help="sampling stride for GridGeoSampler",
)
parser.add_argument(
"-w",
"--num-workers",
default=0,
type=int,
help="number of workers for parallel data loading",
metavar="NUM",
)
parser.add_argument(
"--seed",
default=0,
type=int,
help="random seed for reproducibility",
)
parser.add_argument(
"--output-fn",
default="benchmark-results.csv",
type=str,
help="path to the CSV file to write results",
metavar="FILE",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="print results to stdout",
)

return parser


def main(args: argparse.Namespace) -> None:
"""High-level pipeline.

Benchmarks performance of various samplers with and without caching.

Args:
args: command-line arguments
"""
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]

# Benchmark samplers

# Initialize datasets
cdl = CDL(args.cdl_root, cache=args.cache)
landsat = Landsat8(
args.landsat_root, crs=cdl.crs, res=cdl.res, cache=args.cache, bands=bands
)
dataset = landsat + cdl

# Initialize samplers
if args.epoch_size:
length = args.epoch_size
num_batches = args.epoch_size // args.batch_size
elif args.num_batches:
length = args.num_batches * args.batch_size
num_batches = args.num_batches

# Workaround for https://github.com/microsoft/torchgeo/issues/149
roi = BoundingBox(
-2000000, 2200000, 280000, 3170000, dataset.bounds.mint, dataset.bounds.maxt
)
samplers = [
RandomGeoSampler(
landsat.index,
size=args.patch_size,
length=length,
roi=roi,
),
GridGeoSampler(
landsat.index, size=args.patch_size, stride=args.stride, roi=roi
),
RandomBatchGeoSampler(
landsat.index,
size=args.patch_size,
batch_size=args.batch_size,
length=length,
roi=roi,
),
]

results_rows = []
for sampler in samplers:
if args.verbose:
print(f"\n{sampler.__class__.__name__}:")

if isinstance(sampler, RandomBatchGeoSampler):
dataloader = DataLoader(
dataset,
batch_sampler=sampler, # type: ignore[arg-type]
num_workers=args.num_workers,
)
else:
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler, # type: ignore[arg-type]
num_workers=args.num_workers,
)

tic = time.time()
num_total_patches = 0
for i, batch in enumerate(dataloader):
num_total_patches += args.batch_size
# This is to stop the GridGeoSampler from enumerating everything
if i == num_batches - 1:
break
toc = time.time()
duration = toc - tic

if args.verbose:
print(f" duration: {duration:.3f} sec")
print(f" count: {num_total_patches} patches")
print(f" rate: {num_total_patches / duration:.3f} patches/sec")

if args.cache:
if args.verbose:
print(landsat._cached_load_warp_file.cache_info())

# Clear cache for fair comparison between samplers
# Both `landsat` and `cdl` share the same cache
landsat._cached_load_warp_file.cache_clear()

results_rows.append(
{
"cached": args.cache,
"seed": args.seed,
"duration": duration,
"count": num_total_patches,
"rate": num_total_patches / duration,
"sampler": sampler.__class__.__name__,
"batch_size": args.batch_size,
"num_workers": args.num_workers,
}
)

# Benchmark model
model = resnet18()
# Change number of input channels to match Landsat
model.conv1 = nn.Conv2d( # type: ignore[attr-defined]
len(bands), 64, kernel_size=7, stride=2, padding=3, bias=False
)

criterion = nn.CrossEntropyLoss() # type: ignore[attr-defined]
params = model.parameters()
optimizer = optim.SGD(params, lr=0.0001)

device = torch.device( # type: ignore[attr-defined]
"cuda" if torch.cuda.is_available() else "cpu"
)
model = model.to(device)

tic = time.time()
num_total_patches = 0
for _ in range(num_batches):
num_total_patches += args.batch_size
x = torch.rand(args.batch_size, len(bands), args.patch_size, args.patch_size)
# y = torch.randint(0, 256, (args.batch_size, args.patch_size, args.patch_size))
Review comment from the PR author: This line is left over from when I was thinking about benchmarking a segmentation model instead of ResNet. I think segmentation is one of the more common tasks in remote sensing, and the models are more complex and therefore slower. If we want a slower model to compare against our data loading rates, it might be good to use something like Mask R-CNN instead. (A sketch of that idea follows this file's diff.)

y = torch.randint(0, 256, (args.batch_size,)) # type: ignore[attr-defined]
x = x.to(device)
y = y.to(device)

optimizer.zero_grad()
prediction = model(x)
loss = criterion(prediction, y)
loss.backward()
optimizer.step()

toc = time.time()
duration = toc - tic

if args.verbose:
print("\nResNet-18:")
print(f" duration: {duration:.3f} sec")
print(f" count: {num_total_patches} patches")
print(f" rate: {num_total_patches / duration:.3f} patches/sec")

results_rows.append(
{
"cached": args.cache,
"seed": args.seed,
"duration": duration,
"count": num_total_patches,
"rate": num_total_patches / duration,
"sampler": "resnet18",
"batch_size": args.batch_size,
"num_workers": args.num_workers,
}
)

fieldnames = [
"cached",
"seed",
"duration",
"count",
"rate",
"sampler",
"batch_size",
"num_workers",
]
if not os.path.exists(args.output_fn):
with open(args.output_fn, "w") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
with open(args.output_fn, "a") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writerows(results_rows)


if __name__ == "__main__":
os.environ["GDAL_CACHEMAX"] = "50%"

parser = set_up_parser()
args = parser.parse_args()

if args.epoch_size:
assert args.epoch_size % args.batch_size == 0

pl.seed_everything(args.seed)

main(args)
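
For reference, here is a minimal sketch of the segmentation variant floated in the review comment above. This is not part of the PR: it substitutes torchvision's FCN-ResNet50 for the suggested Mask R-CNN (whose detection-style targets would complicate a synthetic benchmark), and the 256-class per-pixel target simply mirrors the commented-out line.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models.segmentation import fcn_resnet50

num_bands, num_classes = 7, 256  # 7 Landsat bands; 256 mirrors the commented-out target
batch_size, patch_size = 16, 224

# Untrained FCN-ResNet50 with the 3-channel stem swapped for a 7-channel one,
# the same trick the script uses for ResNet-18
model = fcn_resnet50(num_classes=num_classes)
model.backbone.conv1 = nn.Conv2d(
    num_bands, 64, kernel_size=7, stride=2, padding=3, bias=False
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.0001)

# One synthetic step; in the script this would replace the classification block
x = torch.rand(batch_size, num_bands, patch_size, patch_size, device=device)
y = torch.randint(0, num_classes, (batch_size, patch_size, patch_size), device=device)

optimizer.zero_grad()
prediction = model(x)["out"]  # (batch, num_classes, H, W) logits
loss = criterion(prediction, y)
loss.backward()
optimizer.step()

Dropping this into the timed loop would give the slower, segmentation-shaped workload the comment describes.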
44 changes: 44 additions & 0 deletions run_benchmarks_experiments.py
@@ -0,0 +1,44 @@
"""Script for running the benchmark script over a sweep of different options."""
import itertools
import subprocess
import time
from typing import List

NUM_BATCHES = 100

SEED_OPTIONS = [0, 1, 2]
CACHE_OPTIONS = [True, False]
BATCH_SIZE_OPTIONS = [16, 32, 64, 128, 256, 512]

total_num_experiments = len(SEED_OPTIONS) * len(CACHE_OPTIONS) * len(BATCH_SIZE_OPTIONS)

if __name__ == "__main__":

tic = time.time()
for i, (cache, batch_size, seed) in enumerate(
itertools.product(CACHE_OPTIONS, BATCH_SIZE_OPTIONS, SEED_OPTIONS)
):
print(f"{i}/{total_num_experiments} -- {time.time() - tic}")
tic = time.time()
command: List[str] = [
"python",
"benchmark.py",
"--landsat-root",
"/datadrive/landsat",
"--cdl-root",
"/datadrive/cdl",
"-w",
"6",
"--batch-size",
str(batch_size),
"--num-batches",
str(NUM_BATCHES),
"--seed",
str(seed),
"--verbose",
]

if cache:
command.append("--cache")

subprocess.call(command)
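
Each run of the sweep appends rows to benchmark-results.csv, with the columns given by the fieldnames list in benchmark.py. A minimal sketch, assuming pandas is installed, of summarizing throughput across the three seeds:

import pandas as pd

df = pd.read_csv("benchmark-results.csv")

# One row per sampler per run, plus a "resnet18" row for the model benchmark;
# average patches/sec over seeds for each configuration
summary = (
    df.groupby(["sampler", "cached", "batch_size"])["rate"]
    .agg(["mean", "std"])
    .reset_index()
)
print(summary.to_string(index=False))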