Skip to content

Commit

Permalink
Benchmark model as well
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Sep 18, 2021
1 parent 17a25bd commit dd1c8a2
Showing 1 changed file with 64 additions and 16 deletions.
80 changes: 64 additions & 16 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
import time

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from rasterio.crs import CRS
from torch.utils.data import DataLoader
from torchvision.models import resnet18

from torchgeo.datasets import CDL, Landsat8
from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler, RandomGeoSampler
Expand Down Expand Up @@ -69,7 +73,7 @@ def set_up_parser() -> argparse.ArgumentParser:
parser.add_argument(
"-p",
"--patch-size",
default=2 ** 8,
default=224,
type=int,
help="height/width of each patch",
metavar="SIZE",
Expand Down Expand Up @@ -120,24 +124,14 @@ def main(args: argparse.Namespace) -> None:
Args:
args: command-line arguments
"""
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]

# Benchmark samplers

# Initialize datasets
crs = CRS.from_epsg(32610) # UTM, Zone 10
res = 15
landsat = Landsat8(
args.landsat_root,
crs,
res,
cache=args.cache,
bands=[
"B1",
"B2",
"B3",
"B4",
"B5",
"B6",
"B7",
],
)
landsat = Landsat8(args.landsat_root, crs, res, cache=args.cache, bands=bands)
cdl = CDL(args.cdl_root, crs, res, cache=args.cache)
dataset = landsat + cdl

Expand Down Expand Up @@ -219,6 +213,60 @@ def main(args: argparse.Namespace) -> None:
}
)

# Benchmark model
model = resnet18()
# Change number of input channels to match Landsat
model.conv1 = nn.Conv2d( # type: ignore[attr-defined]
len(bands), 64, kernel_size=7, stride=2, padding=3, bias=False
)

criterion = nn.CrossEntropyLoss() # type: ignore[attr-defined]
params = model.parameters()
optimizer = optim.SGD(params, lr=0.0001)

device = torch.device( # type: ignore[attr-defined]
"cuda" if torch.cuda.is_available() else "cpu"
)
model = model.to(device)

tic = time.time()
num_total_patches = 0
for _ in range(num_batches):
num_total_patches += args.batch_size
x = torch.rand(args.batch_size, len(bands), args.patch_size, args.patch_size)
# y = torch.randint(0, 256, (args.batch_size, args.patch_size, args.patch_size))
y = torch.randint(0, 256, (args.batch_size,)) # type: ignore[attr-defined]
x = x.to(device)
y = y.to(device)

optimizer.zero_grad()
prediction = model(x)
loss = criterion(prediction, y)
loss.backward()
optimizer.step()

toc = time.time()
duration = toc - tic

if args.verbose:
print("\nResNet-18:")
print(f" duration: {duration:.3f} sec")
print(f" count: {num_total_patches} patches")
print(f" rate: {num_total_patches / duration:.3f} patches/sec")

results_rows.append(
{
"cached": args.cache,
"seed": args.seed,
"duration": duration,
"count": num_total_patches,
"rate": num_total_patches / duration,
"sampler": "resnet18",
"batch_size": args.batch_size,
"num_workers": args.num_workers,
}
)

fieldnames = [
"cached",
"seed",
Expand Down

0 comments on commit dd1c8a2

Please sign in to comment.