diff --git a/benchmark.py b/benchmark.py index 6e48c35312d..a7aed4a565d 100755 --- a/benchmark.py +++ b/benchmark.py @@ -8,8 +8,12 @@ import time import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.optim as optim from rasterio.crs import CRS from torch.utils.data import DataLoader +from torchvision.models import resnet18 from torchgeo.datasets import CDL, Landsat8 from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler, RandomGeoSampler @@ -69,7 +73,7 @@ def set_up_parser() -> argparse.ArgumentParser: parser.add_argument( "-p", "--patch-size", - default=2 ** 8, + default=224, type=int, help="height/width of each patch", metavar="SIZE", @@ -120,24 +124,14 @@ def main(args: argparse.Namespace) -> None: Args: args: command-line arguments """ + bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"] + + # Benchmark samplers + # Initialize datasets crs = CRS.from_epsg(32610) # UTM, Zone 10 res = 15 - landsat = Landsat8( - args.landsat_root, - crs, - res, - cache=args.cache, - bands=[ - "B1", - "B2", - "B3", - "B4", - "B5", - "B6", - "B7", - ], - ) + landsat = Landsat8(args.landsat_root, crs, res, cache=args.cache, bands=bands) cdl = CDL(args.cdl_root, crs, res, cache=args.cache) dataset = landsat + cdl @@ -219,6 +213,60 @@ def main(args: argparse.Namespace) -> None: } ) + # Benchmark model + model = resnet18() + # Change number of input channels to match Landsat + model.conv1 = nn.Conv2d( # type: ignore[attr-defined] + len(bands), 64, kernel_size=7, stride=2, padding=3, bias=False + ) + + criterion = nn.CrossEntropyLoss() # type: ignore[attr-defined] + params = model.parameters() + optimizer = optim.SGD(params, lr=0.0001) + + device = torch.device( # type: ignore[attr-defined] + "cuda" if torch.cuda.is_available() else "cpu" + ) + model = model.to(device) + + tic = time.time() + num_total_patches = 0 + for _ in range(num_batches): + num_total_patches += args.batch_size + x = torch.rand(args.batch_size, len(bands), args.patch_size, args.patch_size) + # y = torch.randint(0, 256, (args.batch_size, args.patch_size, args.patch_size)) + y = torch.randint(0, 256, (args.batch_size,)) # type: ignore[attr-defined] + x = x.to(device) + y = y.to(device) + + optimizer.zero_grad() + prediction = model(x) + loss = criterion(prediction, y) + loss.backward() + optimizer.step() + + toc = time.time() + duration = toc - tic + + if args.verbose: + print("\nResNet-18:") + print(f" duration: {duration:.3f} sec") + print(f" count: {num_total_patches} patches") + print(f" rate: {num_total_patches / duration:.3f} patches/sec") + + results_rows.append( + { + "cached": args.cache, + "seed": args.seed, + "duration": duration, + "count": num_total_patches, + "rate": num_total_patches / duration, + "sampler": "resnet18", + "batch_size": args.batch_size, + "num_workers": args.num_workers, + } + ) + fieldnames = [ "cached", "seed",