diff --git a/docs/index.rst b/docs/index.rst
index b763481e7ea..561a496d67a 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -39,6 +39,7 @@ torchgeo
tutorials/indices
tutorials/trainers
tutorials/pretrained_weights
+ tutorials/earth_surface_water
.. toctree::
:maxdepth: 1
diff --git a/docs/tutorials/earth_surface_water.ipynb b/docs/tutorials/earth_surface_water.ipynb
new file mode 100644
index 00000000000..e44d1e385f7
--- /dev/null
+++ b/docs/tutorials/earth_surface_water.ipynb
@@ -0,0 +1,878 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MAR190Aszv8r"
+ },
+ "source": [
+ "# Earth Water Surface\n",
+ "\n",
+ "_Written by: Mauricio Cordeiro_"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0lZeGJNTz1y5"
+ },
+ "source": [
+ "## Introduction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VXG-oXGUz39X"
+ },
+ "source": [
+ "The objective of this tutorial is to go through the Earth Water Surface dataset and cover the following topics:
\n",
+ "\n",
+ "* Creating RasterDatasets, DataLoaders and Samplers for images and masks;\n",
+ "* Intersection Dataset;\n",
+ "* Normalizing the data;\n",
+ "* Creating spectral indices;\n",
+ "* Creating the segmentation model (DeepLabV3);\n",
+ "* Loss function and metrics; and\n",
+ "* Training loop.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JTt_Ysyl4El5"
+ },
+ "source": [
+ "## Environment"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-U9kWwoL4GqT"
+ },
+ "source": [
+ "For the environment, we will install the torchgeo and scikit-learn packages."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "A14-syGAFahE",
+ "outputId": "d5b0ac1d-5cc2-4532-b2e6-7134638e9389"
+ },
+ "outputs": [],
+ "source": [
+ "%pip install torchgeo scikit-learn"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import tempfile\n",
+ "from collections.abc import Callable, Iterable\n",
+ "from pathlib import Path\n",
+ "\n",
+ "import kornia.augmentation as K\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import rasterio as rio\n",
+ "import torch\n",
+ "from sklearn.metrics import jaccard_score\n",
+ "from torch.utils.data import DataLoader\n",
+ "\n",
+ "from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, utils\n",
+ "from torchgeo.samplers import RandomGeoSampler, Units\n",
+ "from torchgeo.transforms import indices"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "73l1wc-fFWuU"
+ },
+ "source": [
+ "## Dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yY7mCdoq00Yo"
+ },
+ "source": [
+ "The dataset we will use is the Earth Surface Water dataset [1] (licensed under Creative Commons Attribution 4.0 International Public License), which has patches from different parts of the world (Figure below) and its corresponding water masks. The dataset uses optical imagery from Sentinel-2 satellite with 10m of spatial resolution.\n",
+ "\n",
+ "![Image1](https://raw.githubusercontent.com/xinluo2018/WatNet/main/figures/dataset.png)\n",
+ "\n",
+ "[1] Xin Luo. (2021). Earth Surface Water Dataset [Data set]. Zenodo. https://doi.org/10.5281/zenodo.5205674\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Download and extract dataset to a temp folder\n",
+ "tmp_path = Path(tempfile.gettempdir()) / 'surface_water/'\n",
+ "utils.download_and_extract_archive(\n",
+ " 'https://hf.co/datasets/cordmaur/earth_surface_water/resolve/main/earth_surface_water.zip',\n",
+ " tmp_path,\n",
+ ")\n",
+ "\n",
+ "# Set the root to the extracted folder\n",
+ "root = tmp_path / 'dset-s2'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "abnT63f1GOh8"
+ },
+ "source": [
+ "## Creating the Datasets"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7ke5sQ1_4nEq"
+ },
+ "source": [
+ "Now that we have the original dataset already uncompressed in Colab’s environment, we can prepare it to be loaded into a neural network. For that, we will create an instance of the RasterDataset class, provided by TorchGeo, and point to the specific directory, using the following commands. The `scale` function will apply the `1e-4` scale necessary to get the Sentinel-2 values in reflectance. Once the datasets are created, we can combine images with masks (labels) using the `&` operator."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "iXiPVrjXGSes"
+ },
+ "outputs": [],
+ "source": [
+ "def scale(item: dict):\n",
+ " item['image'] = item['image'] / 10000\n",
+ " return item"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yWApUYVYZ1Lr"
+ },
+ "outputs": [],
+ "source": [
+ "train_imgs = RasterDataset(\n",
+ " paths=(root / 'tra_scene').as_posix(), crs='epsg:3395', res=10, transforms=scale\n",
+ ")\n",
+ "train_msks = RasterDataset(\n",
+ " paths=(root / 'tra_truth').as_posix(), crs='epsg:3395', res=10\n",
+ ")\n",
+ "\n",
+ "valid_imgs = RasterDataset(\n",
+ " paths=(root / 'val_scene').as_posix(), crs='epsg:3395', res=10, transforms=scale\n",
+ ")\n",
+ "valid_msks = RasterDataset(\n",
+ " paths=(root / 'val_truth').as_posix(), crs='epsg:3395', res=10\n",
+ ")\n",
+ "\n",
+ "# IMPORTANT\n",
+ "train_msks.is_image = False\n",
+ "valid_msks.is_image = False\n",
+ "\n",
+ "train_dset = train_imgs & train_msks\n",
+ "valid_dset = valid_imgs & valid_msks\n",
+ "\n",
+ "# Create the samplers\n",
+ "\n",
+ "train_sampler = RandomGeoSampler(train_imgs, size=512, length=130, units=Units.PIXELS)\n",
+ "valid_sampler = RandomGeoSampler(valid_imgs, size=512, length=64, units=Units.PIXELS)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MaTZ03eJ5FGa"
+ },
+ "source": [
+ "Note that we are specifying the CRS (Coordinate Reference System) to EPSG:3395. TorchGeo requires that all the images are loaded in the same CRS. However, the patches in the dataset are in different UTM projections and the default behavior of TorchGeo is to use the first CRS found as its default. In this case, we have to inform a CRS that is able to cope with these different regions around the globe. To minimize the deformations due to the huge differences in latitude (I can create a history specific for this purpose) within the patches, I have selected World Mercator as the main CRS for the project. Figure 3 shows the world projected in World Mercator CRS.\n",
+ "\n",
+ "\n",
+ "![Image2](https://miro.medium.com/max/4800/1*sUdRKEfIAbm79jpB3bCShQ.webp)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TqNOU7WaOJ2t"
+ },
+ "source": [
+ "### Understanding the sampler"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8RFrF3bTOSJn"
+ },
+ "source": [
+ "To create training patches that can be fed into a neural network from our dataset, we need to select samples of fixed sizes. TorchGeo has many samplers, but here we will use the `RandomGeoSampler` class. Basically, the sampler selects random bounding boxes of fixed size that belongs to the original image. Then, these bounding boxes are used in the `RasterDataset` to query the portion of the image we want. Here is an exmple using the previously created samplers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "r8VGVIWNPI_W",
+ "outputId": "5b779cd3-25bc-4ec4-e29d-99391e906a4d"
+ },
+ "outputs": [],
+ "source": [
+ "bbox = next(iter(train_sampler))\n",
+ "bbox"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "KFCssvlqPI87",
+ "outputId": "4bb06a55-375e-4c14-d790-2429d4f14d0a"
+ },
+ "outputs": [],
+ "source": [
+ "sample = train_dset[bbox]\n",
+ "sample.keys()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "G-gF-8k1PI6N",
+ "outputId": "60002cb9-8d8c-4e1e-ba18-51edf4191ea3"
+ },
+ "outputs": [],
+ "source": [
+ "sample['image'].shape, sample['mask'].shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dGiqU2TcPcTq"
+ },
+ "source": [
+ "Notice we have now patches of same size (..., 512 x 512)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TBNfy4X5Pn-G"
+ },
+ "source": [
+ "## Creating Dataloaders"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a97dvCYVP5D5"
+ },
+ "source": [
+ "Creating a `DataLoader` in TorchGeo is very straightforward, just like it is with Pytorch (we are actually using the same class). Note below that we are also using the same samplers already defined. Additionally we inform the dataset that the dataloader will use to pull data from, the batch_size (number of samples in each batch) and a collate function that specifies how to “concatenate” the multiple samples into one single batch.\n",
+ "\n",
+ "Finally, we can iterate through the dataloader to grab batches from it. To test it, we will get the first batch."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "IhWa0SYfav2V",
+ "outputId": "6e80207e-30cf-46b8-d507-e15c0619aa9d"
+ },
+ "outputs": [],
+ "source": [
+ "# Adjust the batch size according to your GPU memory\n",
+ "train_dataloader = DataLoader(\n",
+ " train_dset, sampler=train_sampler, batch_size=4, collate_fn=stack_samples\n",
+ ")\n",
+ "valid_dataloader = DataLoader(\n",
+ " valid_dset, sampler=valid_sampler, batch_size=4, collate_fn=stack_samples\n",
+ ")\n",
+ "\n",
+ "train_batch = next(iter(train_dataloader))\n",
+ "valid_batch = next(iter(valid_dataloader))\n",
+ "train_batch.keys(), valid_batch.keys()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "o7VRAzpkQIvr"
+ },
+ "source": [
+ "## Batch Visualization"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "78IKqlWUQMTM"
+ },
+ "source": [
+ "Now that we can draw batches from our datasets, let’s create a function to display the batches.\n",
+ "\n",
+ "The function `plot_batch` will will check automatically the number of items in the batch and if there are masks associated to arrange the output grid accordingly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zynH3etxQQmG"
+ },
+ "outputs": [],
+ "source": [
+ "def plot_imgs(\n",
+ " images: Iterable, axs: Iterable, chnls: list[int] = [2, 1, 0], bright: float = 3.0\n",
+ "):\n",
+ " for img, ax in zip(images, axs):\n",
+ " arr = torch.clamp(bright * img, min=0, max=1).numpy()\n",
+ " rgb = arr.transpose(1, 2, 0)[:, :, chnls]\n",
+ " ax.imshow(rgb)\n",
+ " ax.axis('off')\n",
+ "\n",
+ "\n",
+ "def plot_msks(masks: Iterable, axs: Iterable):\n",
+ " for mask, ax in zip(masks, axs):\n",
+ " ax.imshow(mask.squeeze().numpy(), cmap='Blues')\n",
+ " ax.axis('off')\n",
+ "\n",
+ "\n",
+ "def plot_batch(\n",
+ " batch: dict,\n",
+ " bright: float = 3.0,\n",
+ " cols: int = 4,\n",
+ " width: int = 5,\n",
+ " chnls: list[int] = [2, 1, 0],\n",
+ "):\n",
+ " # Get the samples and the number of items in the batch\n",
+ " samples = unbind_samples(batch.copy())\n",
+ "\n",
+ " # if batch contains images and masks, the number of images will be doubled\n",
+ " n = 2 * len(samples) if ('image' in batch) and ('mask' in batch) else len(samples)\n",
+ "\n",
+ " # calculate the number of rows in the grid\n",
+ " rows = n // cols + (1 if n % cols != 0 else 0)\n",
+ "\n",
+ " # create a grid\n",
+ " _, axs = plt.subplots(rows, cols, figsize=(cols * width, rows * width))\n",
+ "\n",
+ " if ('image' in batch) and ('mask' in batch):\n",
+ " # plot the images on the even axis\n",
+ " plot_imgs(\n",
+ " images=map(lambda x: x['image'], samples),\n",
+ " axs=axs.reshape(-1)[::2],\n",
+ " chnls=chnls,\n",
+ " bright=bright,\n",
+ " )\n",
+ "\n",
+ " # plot the masks on the odd axis\n",
+ " plot_msks(masks=map(lambda x: x['mask'], samples), axs=axs.reshape(-1)[1::2])\n",
+ "\n",
+ " else:\n",
+ " if 'image' in batch:\n",
+ " plot_imgs(\n",
+ " images=map(lambda x: x['image'], samples),\n",
+ " axs=axs.reshape(-1),\n",
+ " chnls=chnls,\n",
+ " bright=bright,\n",
+ " )\n",
+ "\n",
+ " elif 'mask' in batch:\n",
+ " plot_msks(masks=map(lambda x: x['mask'], samples), axs=axs.reshape(-1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 672
+ },
+ "id": "0S4DZa8aQd8Z",
+ "outputId": "755d6f7a-d4fd-4cab-ec1d-c293255a7e8c"
+ },
+ "outputs": [],
+ "source": [
+ "plot_batch(train_batch)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SkGQoaWlQVFC"
+ },
+ "source": [
+ "## Data Standardization and Spectral Indices"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mvce_wKMQ3pT"
+ },
+ "source": [
+ "Normally, machine learning methods (deep learning included) benefit from feature scaling. That means standard deviation around 1 and zero mean, by applying the following formula:
\n",
+ "$X'=\\frac{X-Mean}{\\text{Standard deviation}}$\n",
+ "\n",
+ "To do that, we need to first find the mean and standard deviation for each one of the 6s channels in the dataset.\n",
+ "\n",
+ "Let’s define a function calculate these statistics and write its results in the variables mean and std. We will use our previously installed rasterio package to open the images and perform a simple average over the statistics for each batch/channel. For the standard deviation, this method is an approximation. For a more precise calculation, please refer to: http://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.htm."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "IZ5yXZPzNIdh"
+ },
+ "outputs": [],
+ "source": [
+ "def calc_statistics(dset: RasterDataset):\n",
+ " \"\"\"\n",
+ " Calculate the statistics (mean and std) for the entire dataset\n",
+ " Warning: This is an approximation. The correct value should take into account the\n",
+ " mean for the whole dataset for computing individual stds.\n",
+ " For correctness I suggest checking: http://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html\n",
+ " \"\"\"\n",
+ "\n",
+ " # To avoid loading the entire dataset in memory, we will loop through each img\n",
+ " # The filenames will be retrieved from the dataset's rtree index\n",
+ " files = [\n",
+ " item.object for item in dset.index.intersection(dset.index.bounds, objects=True)\n",
+ " ]\n",
+ "\n",
+ " # Reseting statistics\n",
+ " accum_mean = 0\n",
+ " accum_std = 0\n",
+ "\n",
+ " for file in files:\n",
+ " img = rio.open(file).read() / 10000 # type: ignore\n",
+ " accum_mean += img.reshape((img.shape[0], -1)).mean(axis=1)\n",
+ " accum_std += img.reshape((img.shape[0], -1)).std(axis=1)\n",
+ "\n",
+ " # at the end, we shall have 2 vectors with lenght n=chnls\n",
+ " # we will average them considering the number of images\n",
+ " return accum_mean / len(files), accum_std / len(files)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4VtubMAPxXYq"
+ },
+ "outputs": [],
+ "source": [
+ "# Calculate the statistics (Mean and std) for the dataset\n",
+ "mean, std = calc_statistics(train_imgs)\n",
+ "\n",
+ "# Please, note that we will create spectral indices using the raw (non-normalized) data. Then, when normalizing, the sensors will have more channels (the indices) that should not be normalized.\n",
+ "# To solve this, we will add the indices to the 0's to the mean vector and 1's to the std vectors\n",
+ "mean = np.concat([mean, [0, 0, 0]])\n",
+ "std = np.concat([std, [1, 1, 1]])\n",
+ "\n",
+ "norm = K.Normalize(mean=mean, std=std)\n",
+ "\n",
+ "tfms = torch.nn.Sequential(\n",
+ " indices.AppendNDWI(index_green=1, index_nir=3),\n",
+ " indices.AppendNDWI(index_green=1, index_nir=5),\n",
+ " indices.AppendNDVI(index_nir=3, index_red=2),\n",
+ " norm,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "hVsDAOaVRhiN",
+ "outputId": "b128640d-b3bc-4cf6-82e1-187b5df2a164"
+ },
+ "outputs": [],
+ "source": [
+ "transformed_img = tfms(train_batch['image'])\n",
+ "print(transformed_img.shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HA-jTEKeRuA4"
+ },
+ "source": [
+ "Note that our transformed batch has now 9 channels, instead of 6."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hWfUOS1RR14g"
+ },
+ "source": [
+ "> Important: the normalize method we created will apply the normalization just to the original bands and it will ignore the previously appended indices. That’s important to avoid errors due to distinct shapes between the batch and the mean and std vectors."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JfNmPcwWSNFv"
+ },
+ "source": [
+ "## Segmentation Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bMmXdpZWSSNR"
+ },
+ "source": [
+ "For the semantic segmentation model, we are going to use a predefined architecture that is available in Pytorch. Looking at list (https://pytorch.org/vision/stable/models.html#semantic-segmentation) it is possible to note 3 models available for semantic segmentation, but one (LRASPP) is intended for mobile applications. In our tutorial, we will use the DeepLabV3 model.\n",
+ "\n",
+ "Here, we will create a DeepLabV3 model for 2 classes. In this case, I will skip the pretrained weights, as the weights represent another domain (not water segmentation from multispectral imagery)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "OzhOtV3IyubJ",
+ "outputId": "b7e57773-1f4c-4140-9039-b0e84e59aa58"
+ },
+ "outputs": [],
+ "source": [
+ "from torchvision.models.segmentation import deeplabv3_resnet50\n",
+ "\n",
+ "model = deeplabv3_resnet50(weights=None, num_classes=2)\n",
+ "model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AvsK_LXuSmnL"
+ },
+ "source": [
+ "The first thing we have to pay attention in the model architecture is the number of channels expected in the first convolution (Conv2d), that is defined as 3. That’s because the model is prepared to work with RGB images. After the first convolution, the 3 channels will produce 64 channels in lower resolution, and so on. As we have now 9 channels, we will change this first processing layer to adapt correctly to our model. We can do this by replacing the first convolutional layer for a new one, by following the commands. Finally, we check a mock batch can pass through the model and provide the output with 2 channels (water / no_water) as desired."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "xYPdGBPOSfHJ",
+ "outputId": "7e9267b7-a4e7-428d-a0cd-9292dd2b1923"
+ },
+ "outputs": [],
+ "source": [
+ "backbone = model.get_submodule('backbone')\n",
+ "\n",
+ "conv = torch.nn.modules.conv.Conv2d(\n",
+ " in_channels=9,\n",
+ " out_channels=64,\n",
+ " kernel_size=(7, 7),\n",
+ " stride=(2, 2),\n",
+ " padding=(3, 3),\n",
+ " bias=False,\n",
+ ")\n",
+ "backbone.register_module('conv1', conv)\n",
+ "\n",
+ "pred = model(torch.randn(3, 9, 512, 512))\n",
+ "pred['out'].shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PvvGWtMrTdCE"
+ },
+ "source": [
+ "## Training Loop"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YuUPhlzxTftk"
+ },
+ "source": [
+ "The training function should receive the number of epochs, the model, the dataloaders, the loss function (to be optimized) the accuracy function (to assess the results), the optimizer (that will adjust the parameters of the model in the correct direction) and the transformations to be applied to each batch."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Check if GPU is available\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "device"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "sFlsT8glSq3x"
+ },
+ "outputs": [],
+ "source": [
+ "def train_loop(\n",
+ " epochs: int,\n",
+ " train_dl: DataLoader,\n",
+ " val_dl: DataLoader | None,\n",
+ " model: torch.nn.Module,\n",
+ " loss_fn: Callable,\n",
+ " optimizer: torch.optim.Optimizer,\n",
+ " acc_fns: list | None = None,\n",
+ " batch_tfms: Callable | None = None,\n",
+ "):\n",
+ " # size = len(dataloader.dataset)\n",
+ " cuda_model = model.to(device)\n",
+ "\n",
+ " for epoch in range(epochs):\n",
+ " accum_loss = 0\n",
+ " for batch in train_dl:\n",
+ " if batch_tfms is not None:\n",
+ " X = batch_tfms(batch['image']).to(device)\n",
+ " else:\n",
+ " X = batch['image'].to(device)\n",
+ "\n",
+ " y = batch['mask'].type(torch.long).to(device)\n",
+ " pred = cuda_model(X)['out']\n",
+ " loss = loss_fn(pred, y)\n",
+ "\n",
+ " # BackProp\n",
+ " optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ " # update the accum loss\n",
+ " accum_loss += float(loss) / len(train_dl)\n",
+ "\n",
+ " # Testing against the validation dataset\n",
+ " if acc_fns is not None and val_dl is not None:\n",
+ " # reset the accuracies metrics\n",
+ " acc = [0.0] * len(acc_fns)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " for batch in val_dl:\n",
+ " if batch_tfms is not None:\n",
+ " X = batch_tfms(batch['image']).to(device)\n",
+ " else:\n",
+ " X = batch['image'].type(torch.float32).to(device)\n",
+ "\n",
+ " y = batch['mask'].type(torch.long).to(device)\n",
+ "\n",
+ " pred = cuda_model(X)['out']\n",
+ "\n",
+ " for i, acc_fn in enumerate(acc_fns):\n",
+ " acc[i] = float(acc[i] + acc_fn(pred, y) / len(val_dl))\n",
+ "\n",
+ " # at the end of the epoch, print the errors, etc.\n",
+ " print(\n",
+ " f'Epoch {epoch}: Train Loss={accum_loss:.5f} - Accs={[round(a, 3) for a in acc]}'\n",
+ " )\n",
+ " else:\n",
+ " print(f'Epoch {epoch}: Train Loss={accum_loss:.5f}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kUIPlgndUB_9"
+ },
+ "source": [
+ "## Loss and Accuracy Functions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Xs5LOs8LUS7j"
+ },
+ "source": [
+ "For the loss function, normally the Cross Entropy Loss should work, but it requires the mask to have shape (N, d1, d2). In this case, we will need to squeeze our second dimension manually."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "O8K3e5NwTrYg"
+ },
+ "outputs": [],
+ "source": [
+ "def oa(pred, y):\n",
+ " flat_y = y.squeeze()\n",
+ " flat_pred = pred.argmax(dim=1)\n",
+ " acc = torch.count_nonzero(flat_y == flat_pred) / torch.numel(flat_y)\n",
+ " return acc\n",
+ "\n",
+ "\n",
+ "def iou(pred, y):\n",
+ " flat_y = y.cpu().numpy().squeeze()\n",
+ " flat_pred = pred.argmax(dim=1).detach().cpu().numpy()\n",
+ " return jaccard_score(flat_y.reshape(-1), flat_pred.reshape(-1), zero_division=1.0)\n",
+ "\n",
+ "\n",
+ "def loss(p, t):\n",
+ " return torch.nn.functional.cross_entropy(p, t.squeeze())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UW6YfmnyUa1A"
+ },
+ "source": [
+ "## Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xajn1unSUh3h"
+ },
+ "source": [
+ "> To train the model it is important to have CUDA GPUs available. In Colab, it can be done by changing the runtime type and re-running the notebook. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "moU9ol78UUqP",
+ "outputId": "a1b5af92-b7fe-43c0-c4cd-f92906a46fb3"
+ },
+ "outputs": [],
+ "source": [
+ "# adjust number of epochs depending on the device\n",
+ "if torch.cuda.is_available():\n",
+ " num_epochs = 2\n",
+ "else:\n",
+ " # if GPU is not available, just make 1 pass and limit the size of the datasets\n",
+ " num_epochs = 1\n",
+ "\n",
+ " # by limiting the length of the sampler we limit the iterations in each epoch\n",
+ " train_dataloader.sampler.length = 8\n",
+ " valid_dataloader.sampler.length = 8\n",
+ "\n",
+ "# train the model\n",
+ "optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.01)\n",
+ "train_loop(\n",
+ " num_epochs,\n",
+ " train_dataloader,\n",
+ " valid_dataloader,\n",
+ " model,\n",
+ " loss,\n",
+ " optimizer,\n",
+ " acc_fns=[oa, iou],\n",
+ " batch_tfms=tfms,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Additional Reading"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This tutorial is also available as a 3 parts Medium story:
https://medium.com/towards-data-science/artificial-intelligence-for-geospatial-analysis-with-pytorchs-torchgeo-part-1-52d17e409f09"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "authorship_tag": "ABX9TyPk0gtwHzQoTqfC6uudCTRe",
+ "include_colab_link": true,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}