Skip to content

Commit

Permalink
Add toy dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Dec 8, 2024
1 parent 0966b11 commit 9c5257b
Showing 1 changed file with 52 additions and 52 deletions.
104 changes: 52 additions & 52 deletions docs/tutorials/torchgeo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@
"source": [
"import os\n",
"import tempfile\n",
"from datetime import datetime\n",
"\n",
"from matplotlib import pyplot as plt\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from torchgeo.datasets import CDL, BoundingBox, Landsat7, Landsat8, stack_samples\n",
"from torchgeo.datasets.utils import download_url\n",
"from torchgeo.datasets.utils import download_and_extract_archive\n",
"from torchgeo.samplers import GridGeoSampler, RandomGeoSampler"
]
},
Expand Down Expand Up @@ -102,7 +104,7 @@
"\n",
"Traditionally, people either performed classification on a single pixel at a time or curated their own benchmark dataset. This works fine for training, but isn't really useful for inference. What we would really like to be able to do is sample small pixel-aligned pairs of input images and output masks from the region of overlap between both datasets. This exact situation is illustrated in the following figure:\n",
"\n",
"![Landsat CDL intersection]()\n",
"![Landsat CDL intersection](https://github.com/microsoft/torchgeo/blob/main/images/geodataset.png?raw=true)\n",
"\n",
"Now, let's see what features TorchGeo has to support this kind of use case."
]
Expand Down Expand Up @@ -141,18 +143,24 @@
"source": [
"landsat_root = os.path.join(tempfile.gettempdir(), 'landsat')\n",
"\n",
"download_url()\n",
"download_url()\n",
"url = 'https://hf.co/datasets/torchgeo/tutorials/resolve/ff30b729e3cbf906148d69a4441cc68023898924/'\n",
"landsat7_url = url + 'LE07_L2SP_022032_20230725_20230820_02_T1.tar.gz'\n",
"landsat8_url = url + 'LC08_L2SP_023032_20230831_20230911_02_T1.tar.gz'\n",
"\n",
"landsat7 = Landsat7(\n",
" paths=landsat_root, bands=['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7']\n",
")\n",
"landsat8 = Landsat8(\n",
" paths=landsat_root, bands=['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8']\n",
")\n",
"download_and_extract_archive(landsat7_url, landsat_root)\n",
"download_and_extract_archive(landsat8_url, landsat_root)\n",
"\n",
"landsat7_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']\n",
"landsat8_bands = ['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']\n",
"\n",
"landsat7 = Landsat7(paths=landsat_root, bands=landsat7_bands)\n",
"landsat8 = Landsat8(paths=landsat_root, bands=landsat8_bands)\n",
"\n",
"print(landsat7)\n",
"print(landsat8)"
"print(landsat8)\n",
"\n",
"print(landsat7.crs)\n",
"print(landsat8.crs)"
]
},
{
Expand Down Expand Up @@ -186,11 +194,14 @@
"source": [
"cdl_root = os.path.join(tempfile.gettempdir(), 'cdl')\n",
"\n",
"download_url()\n",
"cdl_url = url + '2023_30m_cdls.zip'\n",
"\n",
"download_and_extract_archive(cdl_url, cdl_root)\n",
"\n",
"cdl = CDL(paths=cdl_root)\n",
"\n",
"print(cdl)"
"print(cdl)\n",
"print(cdl.crs)"
]
},
{
Expand All @@ -201,8 +212,8 @@
"Again, the following details are worth noting:\n",
"\n",
"* We could actually ask the `CDL` dataset to download our data for us by adding `download=True`\n",
"* All three datasets have different spatial extends\n",
"* All three datasets have different CRSs"
"* All datasets have different spatial extents\n",
"* All datasets have different CRSs"
]
},
{
Expand All @@ -223,7 +234,8 @@
"outputs": [],
"source": [
"landsat = landsat7 | landsat8\n",
"print(landsat)"
"print(landsat)\n",
"print(landsat.crs)"
]
},
{
Expand All @@ -242,7 +254,8 @@
"outputs": [],
"source": [
"dataset = landsat & cdl\n",
"print(dataset)"
"print(dataset)\n",
"print(dataset.crs)"
]
},
{
Expand All @@ -262,7 +275,7 @@
"\n",
"How did we do this? TorchGeo uses a data structure called an *R-tree* to store the spatiotemporal bounding box of every file in the dataset. \n",
"\n",
"![R-tree]()\n",
"![R-tree](https://raw.githubusercontent.com/davidmoten/davidmoten.github.io/master/resources/rtree-3d/plot2.png)\n",
"\n",
"TorchGeo extracts the spatial bounding box from the metadata of each file, and the timestamp from the filename. This geospatial and geotemporal metadata allows us to efficiently compute the intersection or union of two datasets. It also lets us quickly retrieve an image and corresponding mask for a particular location in space and time."
]
Expand All @@ -274,11 +287,21 @@
"metadata": {},
"outputs": [],
"source": [
"bbox = BoundingBox()\n",
"sample = dataset[sample]\n",
"size = 256\n",
"\n",
"landsat.plot(sample)\n",
"cdl.plot(sample)"
"xmin = 925000\n",
"xmax = xmin + size * 30\n",
"ymin = 4470000\n",
"ymax = ymin + size * 30\n",
"tmin = datetime(2023, 1, 1).timestamp()\n",
"tmax = datetime(2023, 12, 31).timestamp()\n",
"\n",
"bbox = BoundingBox(xmin, xmax, ymin, ymax, tmin, tmax)\n",
"sample = dataset[bbox]\n",
"\n",
"landsat8.plot(sample)\n",
"cdl.plot(sample)\n",
"plt.show()"
]
},
{
Expand All @@ -289,15 +312,6 @@
"TorchGeo uses *windowed-reading* to only read the blocks of memory needed to load a small patch from a large raster tile. It also automatically reprojects all data to the same CRS and resolution (from the first dataset). This can be controlled by explicitly passing `crs` or `res` to the dataset."
]
},
{
"cell_type": "markdown",
"id": "02368e20-3391-4be7-bbe5-5a3c367ab398",
"metadata": {},
"source": [
"### Geospatial splitting\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "e2e4221e-dfb7-4966-96a6-e52400ae266c",
Expand Down Expand Up @@ -327,8 +341,8 @@
"metadata": {},
"outputs": [],
"source": [
"train_sampler = RandomGeoSampler(dataset, size=256, length=1000)\n",
"print(next(train_sampler))"
"train_sampler = RandomGeoSampler(dataset, size=size, length=1000)\n",
"next(iter(train_sampler))"
]
},
{
Expand All @@ -338,7 +352,7 @@
"source": [
"### Gridded sampling\n",
"\n",
"At evaluation time, this actually becomes a problem. We want to make sure we aren't making multiple predictions for the same location. We also want to make sure we don't miss any locations. To achieve this, TorchGeo also provides a `GridGeoSampler`. We can tell the sampler the size of each image patch and the stride of our sliding window (defaults to patch size)."
"At evaluation time, this actually becomes a problem. We want to make sure we aren't making multiple predictions for the same location. We also want to make sure we don't miss any locations. To achieve this, TorchGeo also provides a `GridGeoSampler`. We can tell the sampler the size of each image patch and the stride of our sliding window."
]
},
{
Expand All @@ -348,8 +362,8 @@
"metadata": {},
"outputs": [],
"source": [
"test_sampler = GridGeoSampler(dataset, size=256)\n",
"print(next(test_sampler))"
"test_sampler = GridGeoSampler(dataset, size=size, stride=size)\n",
"next(iter(test_sampler))"
]
},
{
Expand Down Expand Up @@ -379,16 +393,10 @@
},
{
"cell_type": "markdown",
"id": "3518c7d9-1bb3-4bc2-8216-53044d0b4009",
"id": "e46e8453-df25-4265-a85b-75dce7dea047",
"metadata": {},
"source": [
"\n",
"* Transforms?\n",
"* Models\n",
" * U-Net + pre-trained ResNet\n",
" * Model pre-trained directly on satellite imagery\n",
"* Training and evaluation\n",
" * Copy everything else from "
"Now that we have working data loaders, we can copy-n-paste our training code from the Introduction to PyTorch tutorial. We only need to change our model to one designed for semantic segmentation, such as a U-Net. Every other line of code would be identical to how you would do this in your normal PyTorch workflow."
]
},
{
Expand All @@ -403,14 +411,6 @@
"* [TorchGeo: Deep Learning With Geospatial Data](https://arxiv.org/abs/2111.08872)\n",
"* [Geospatial deep learning with TorchGeo](https://pytorch.org/blog/geospatial-deep-learning-with-torchgeo/)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38e60635-69b2-47c9-8df2-fd7c872abdd9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit 9c5257b

Please sign in to comment.