Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CropHarvest Dataset #1677

Merged
merged 49 commits into from
Jan 19, 2024
Merged
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
e7e8d90
initial commit
Oct 18, 2023
8aaba4f
Merge branch 'microsoft:main' into add-cropharvest-dataset
GeorgeHuber Oct 18, 2023
e8e38f7
Added functionality to cropharvest dataset
Nov 2, 2023
48f2832
Added test coverage
Nov 3, 2023
b69553c
test fixes
Nov 13, 2023
4c2837a
mdpy typing
Nov 15, 2023
2db6cc3
Merge branch 'main' into add-cropharvest-dataset
GeorgeHuber Nov 15, 2023
faec50b
flake8 revision
Nov 15, 2023
5c2eb11
Merge branch 'add-cropharvest-dataset' of https://github.com/GeorgeHu…
Nov 15, 2023
3d7633c
added docs
Nov 15, 2023
c044c29
fixed h5py import
Nov 15, 2023
8c78e8b
fix .rst underline
Nov 15, 2023
1c118d7
updated tests to mock h5py module
Nov 15, 2023
6e53470
fixed documentation
Nov 29, 2023
3ef552c
fixed black formating
Nov 29, 2023
dc2f025
turn labels to tensors
Nov 29, 2023
6b5a25a
fix data generationa and mdpy for tensor encoding
Nov 29, 2023
3747805
update verify model
Nov 30, 2023
3289c76
doc style
Nov 30, 2023
cca9bf6
test coverage
Nov 30, 2023
923a9ec
Merge branch 'main' into add-cropharvest-dataset
GeorgeHuber Nov 30, 2023
d7dee4e
fix test coverage leaks
Nov 30, 2023
71bdc53
Update torchgeo/datasets/cropharvest.py
GeorgeHuber Nov 30, 2023
59c94a4
Update torchgeo/datasets/cropharvest.py
GeorgeHuber Nov 30, 2023
50ce700
Update torchgeo/datasets/cropharvest.py
GeorgeHuber Nov 30, 2023
dd2f1b6
Update torchgeo/datasets/cropharvest.py
GeorgeHuber Nov 30, 2023
3444bb7
update test data path and monkeypatch
Nov 30, 2023
c4a0f6e
Update torchgeo/datasets/cropharvest.py
calebrob6 Dec 2, 2023
773bdb4
Update torchgeo/datasets/cropharvest.py
calebrob6 Dec 2, 2023
9dc5053
remove hard coded classes
Dec 29, 2023
6992afa
fixed plot and label one hot encoding
Jan 1, 2024
a115255
refactor datasetnotfounderror
Jan 9, 2024
daba36d
resolve conflict with main
Jan 9, 2024
431ae85
Merge branch 'main' into add-cropharvest-dataset
GeorgeHuber Jan 9, 2024
3a8655c
refactored importerror
Jan 9, 2024
9408b37
mdpy
Jan 9, 2024
0d5002c
Merge branch 'main' into add-cropharvest-dataset
GeorgeHuber Jan 10, 2024
f8e05bb
Update torchgeo/datasets/cropharvest.py
GeorgeHuber Jan 15, 2024
4fcd286
Update torchgeo/datasets/cropharvest.py
GeorgeHuber Jan 15, 2024
cd43eaa
Update torchgeo/datasets/cropharvest.py
GeorgeHuber Jan 15, 2024
e71e1c0
Update torchgeo/datasets/cropharvest.py
GeorgeHuber Jan 15, 2024
be55d3a
Update torchgeo/datasets/cropharvest.py
GeorgeHuber Jan 15, 2024
4f87ac2
Update torchgeo/datasets/cropharvest.py
GeorgeHuber Jan 15, 2024
a2c403c
Update torchgeo/datasets/cropharvest.py
GeorgeHuber Jan 15, 2024
5ac3e55
Update torchgeo/datasets/cropharvest.py
GeorgeHuber Jan 15, 2024
c96efcd
Merge branch 'main' into add-cropharvest-dataset
GeorgeHuber Jan 15, 2024
2616637
Remove empty class and correct csv
Jan 16, 2024
a030bd1
Update cropharvest.py
GeorgeHuber Jan 19, 2024
7d4e57f
formatting changes
Jan 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
@@ -206,6 +206,11 @@ COWC
.. autoclass:: COWCCounting
.. autoclass:: COWCDetection

CropHarvest
^^^^^^^^^^^

.. autoclass:: CropHarvest

Kenya Crop Type
^^^^^^^^^^^^^^^

3 changes: 2 additions & 1 deletion docs/api/non_geo_datasets.csv
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI
`Cloud Cover Detection`_,S,Sentinel-2,"CC-BY-4.0","22,728",2,512x512,10,MSI
`COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","AGPL-3.0-only","388,435",2,256x256,0.15,RGB
`CropHarvest`_,"C","Sentinel-1/2, SRTM, ERA5","CC-BY-SA-4.0","70,213",351,1x1,10,"SAR, MSI, SRTM"
`Kenya Crop Type`_,S,Sentinel-2,"CC-BY-SA-4.0","4,688",7,"3,035x2,016",10,MSI
`DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,-,803,7,"2,448x2,448",0.5,RGB
`DFC2022`_,S,Aerial,"CC-BY-4.0","3,981",15,"2,000x2,000",0.5,RGB
@@ -49,4 +50,4 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`VHR-10`_,I,"Google Earth, Vaihingen","MIT",800,10,"358--1,728",0.08--2,RGB
`Western USA Live Fuel Moisture`_,R,"Landsat8, Sentinel-1","CC-BY-NC-ND-4.0",2615,-,-,-,-
`xView2`_,CD,Maxar,"CC-BY-NC-SA-4.0","3,732",4,"1,024x1,024",0.8,RGB
`ZueriCrop`_,"I, T",Sentinel-2,-,116K,48,24x24,10,MSI
`ZueriCrop`_,"I, T",Sentinel-2,-,116K,48,24x24,10,MSI
153 changes: 153 additions & 0 deletions tests/data/cropharvest/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import json
import os
import shutil

import h5py
import numpy as np

# Not referenced anywhere else in this script.
SIZE = 32

# Seed the global NumPy RNG so the generated arrays are the same on every run.
np.random.seed(0)

# Directory (relative to the working directory) that receives the fake samples.
_ARRAY_DIR = os.path.join("cropharvest", "features", "arrays")

# One "<index>_<dataset>.h5" file per fake sample: three for TestDataset1,
# then two for TestDataset2. The order matters because each file consumes
# values from the seeded RNG in turn.
PATHS = [
    os.path.join(_ARRAY_DIR, f"{index}_{dataset}.h5")
    for dataset, count in (("TestDataset1", 3), ("TestDataset2", 2))
    for index in range(count)
]


def create_geojson():
    """Build the fake ``labels.geojson`` content as a plain dict.

    The result is a GeoJSON ``FeatureCollection`` with five features. Every
    feature shares the same unit-square polygon and differs only in its
    ``properties`` (``dataset``, ``index``, ``is_crop``, ``label``).

    NOTE(review): the records below do not line up one-to-one with ``PATHS``:
    there is no record for index 2 of TestDataset1 or index 0 of TestDataset2,
    and index 0 of TestDataset1 appears twice. Confirm this is intentional
    before relying on every generated file having a label.

    Returns:
        dict that can be serialized with :func:`json.dump`
    """
    # (dataset, index, is_crop, label) for each feature, in output order.
    records = (
        ("TestDataset1", 0, 1, "soybean"),
        ("TestDataset1", 0, 1, "alfalfa"),
        ("TestDataset1", 1, 1, None),
        ("TestDataset2", 2, 1, "maize"),
        ("TestDataset2", 1, 0, None),
    )

    def unit_square():
        # Fresh nested lists per feature so no two features share objects.
        return [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]

    # Key insertion order is kept identical to the serialized file so that
    # regenerating it does not change its checksum.
    features = [
        {
            "type": "Feature",
            "properties": {
                "dataset": dataset,
                "index": index,
                "is_crop": is_crop,
                "label": label,
            },
            "geometry": {"type": "Polygon", "coordinates": unit_square()},
        }
        for dataset, index, is_crop, label in records
    ]
    return {"type": "FeatureCollection", "crs": {}, "features": features}


def create_file(path: str) -> None:
    """Write one fake sample to ``path`` as an HDF5 file.

    The file holds a single dataset named ``"array"`` containing a 12x18
    int64 array of random integers drawn from ``[0, 4000)`` using the global
    NumPy RNG.

    Args:
        path: destination of the ``.h5`` file (its directory must exist)
    """
    values = np.random.randint(4000, size=(12, 18), dtype=np.int64)
    with h5py.File(path, "w") as handle:
        handle.create_dataset("array", data=values)


if __name__ == "__main__":
    # All paths are relative, so run this script from the directory that
    # contains (or should contain) the "cropharvest" folder.
    directory = "cropharvest"
    source_dir = os.path.join(directory, "features")
    archive_path = os.path.join(directory, "features.tar.gz")
    label_path = os.path.join(directory, "labels.geojson")

    # Remove old data. Two of the three entries are regular files, so handle
    # both directories and files; a directory-only check would leave the old
    # archive and label file in place.
    for path in (source_dir, archive_path, label_path):
        if os.path.isdir(path):
            shutil.rmtree(path)
        elif os.path.exists(path):
            os.remove(path)

    # Write the label file.
    geojson = create_geojson()
    os.makedirs(os.path.dirname(label_path), exist_ok=True)
    with open(label_path, "w") as f:
        json.dump(geojson, f)

    # Write one HDF5 file per fake sample.
    for path in PATHS:
        os.makedirs(os.path.dirname(path), exist_ok=True)
        create_file(path)

    # Compress the "features" directory into cropharvest/features.tar.gz.
    shutil.make_archive(source_dir, "gztar", directory, "features")

    # Compute the checksums that the unit tests need to be updated with.
    with open(label_path, "rb") as f:
        md5 = hashlib.md5(f.read()).hexdigest()
        print(f"{label_path}: {md5}")

    with open(archive_path, "rb") as f:
        md5 = hashlib.md5(f.read()).hexdigest()
        print(f"zipped features: {md5}")
Binary file added tests/data/cropharvest/features.tar.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions tests/data/cropharvest/labels.geojson
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type": "FeatureCollection", "crs": {}, "features": [{"type": "Feature", "properties": {"dataset": "TestDataset1", "index": 0, "is_crop": 1, "label": "soybean"}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"dataset": "TestDataset1", "index": 0, "is_crop": 1, "label": "alfalfa"}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"dataset": "TestDataset1", "index": 1, "is_crop": 1, "label": null}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"dataset": "TestDataset2", "index": 2, "is_crop": 1, "label": "maize"}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"dataset": "TestDataset2", "index": 1, "is_crop": 0, "label": null}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}]}
100 changes: 100 additions & 0 deletions tests/datasets/test_cropharvest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from pytest import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import CropHarvest, DatasetNotFoundError

pytest.importorskip("h5py", minversion="3")


def download_url(url: str, root: str, filename: str, md5: str) -> None:
    """Offline replacement for the real downloader used by the dataset.

    Copies the local file at ``url`` to ``filename`` inside ``root``.
    ``md5`` is accepted for signature compatibility and is not used.
    """
    destination = os.path.join(root, filename)
    shutil.copy(url, destination)


class TestCropHarvest:
    """Tests for :class:`CropHarvest` backed by the fake data in tests/data."""

    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        """Make every ``import h5py`` raise ImportError for the test's duration."""
        real_import = builtins.__import__

        def fake_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name != "h5py":
                return real_import(name, *args, **kwargs)
            raise ImportError()

        monkeypatch.setattr(builtins, "__import__", fake_import)

    @pytest.fixture
    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest:
        """Return a dataset "downloaded" from the local test files into tmp_path."""
        # Replace the network download with a local file copy.
        monkeypatch.setattr(torchgeo.datasets.cropharvest, "download_url", download_url)

        data_dir = os.path.join("tests", "data", "cropharvest")
        # (file_dict entry, key, replacement), applied in this order.
        overrides = (
            ("features", "md5", "ef6f4f00c0b3b50ed8380b0044928572"),
            ("labels", "md5", "1d93b6bfcec7b6797b75acbd9d284b92"),
            ("features", "url", os.path.join(data_dir, "features.tar.gz")),
            ("labels", "url", os.path.join(data_dir, "labels.geojson")),
        )
        for entry, key, value in overrides:
            monkeypatch.setitem(CropHarvest.file_dict[entry], key, value)

        return CropHarvest(str(tmp_path), nn.Identity(), download=True, checksum=True)

    def test_getitem(self, dataset: CropHarvest) -> None:
        """Samples are dicts of tensors with a 12x18 array."""
        sample = dataset[0]
        assert isinstance(sample, dict)
        for key in ("array", "label"):
            assert isinstance(sample[key], torch.Tensor)
        assert sample["array"].shape == (12, 18)
        other = dataset[2]
        assert other["label"] == 1

    def test_len(self, dataset: CropHarvest) -> None:
        """All five fake samples are discovered."""
        assert len(dataset) == 5

    def test_already_downloaded(self, dataset: CropHarvest, tmp_path: Path) -> None:
        """A second instance finds the extracted data without downloading."""
        CropHarvest(root=str(tmp_path), download=False)

    def test_downloaded_zipped(self, dataset: CropHarvest, tmp_path: Path) -> None:
        """With only the archive present, the features are re-extracted."""
        shutil.rmtree(os.path.join(tmp_path, "features"))
        CropHarvest(root=str(tmp_path), download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root without ``download=True`` raises DatasetNotFoundError."""
        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
            CropHarvest(str(tmp_path))

    def test_plot(self, dataset: CropHarvest) -> None:
        """Plotting a sample with a subtitle does not raise."""
        sample = dataset[0].copy()
        dataset.plot(sample, subtitle="Test")
        plt.close()

    def test_mock_missing_module(
        self, dataset: CropHarvest, tmp_path: Path, mock_missing_module: None
    ) -> None:
        """Without h5py, using the dataset raises an informative ImportError."""
        expected = "h5py is not installed and is required to use this dataset"
        with pytest.raises(ImportError, match=expected):
            CropHarvest(root=str(tmp_path), download=True)[0]
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@
from .cloud_cover import CloudCoverDetection
from .cms_mangrove_canopy import CMSGlobalMangroveCanopy
from .cowc import COWC, COWCCounting, COWCDetection
from .cropharvest import CropHarvest
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCyclone
from .deepglobelandcover import DeepGlobeLandCover
@@ -150,6 +151,7 @@
"ChesapeakeWV",
"ChesapeakeCVPR",
"CMSGlobalMangroveCanopy",
"CropHarvest",
"EDDMapS",
"Esri2020",
"EUDEM",
318 changes: 318 additions & 0 deletions torchgeo/datasets/cropharvest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""CropHarvest datasets."""

import glob
import json
import os
from typing import Callable, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from torch import Tensor

from .geo import NonGeoDataset
from .utils import DatasetNotFoundError, download_url, extract_archive


class CropHarvest(NonGeoDataset):
    """CropHarvest dataset.

    `CropHarvest <https://github.com/nasaharvest/cropharvest>`__ is a
    crop classification dataset.

    Dataset features:

    * single pixel time series with crop-type labels
    * 18 bands per image over 12 months

    Dataset format:

    * arrays are 12x18 with 18 bands over 12 months

    Dataset properties:

    1. is_crop - whether or not a single pixel contains cropland
    2. classification_label - optional field identifying a specific crop type
    3. dataset - source dataset for the imagery
    4. lat - latitude
    5. lon - longitude

    If you use this dataset in your research, please cite the following paper:

    * https://openreview.net/forum?id=JtjzUXPEaCu

    This dataset requires the following additional library to be installed:

    * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset

    .. versionadded:: 0.6
    """

    # https://github.com/nasaharvest/cropharvest/blob/main/cropharvest/bands.py
    all_bands = [
        "VV",
        "VH",
        "B2",
        "B3",
        "B4",
        "B5",
        "B6",
        "B7",
        "B8",
        "B8A",
        "B9",
        "B11",
        "B12",
        "temperature_2m",
        "total_precipitation",
        "elevation",
        "slope",
        "NDVI",
    ]
    # Bands used (in R, G, B order) by :meth:`plot`.
    rgb_bands = ["B4", "B3", "B2"]

    features_url = "https://zenodo.org/records/7257688/files/features.tar.gz?download=1"
    labels_url = "https://zenodo.org/records/7257688/files/labels.geojson?download=1"
    # For each downloadable file: where it comes from, the name it is stored
    # under in ``root``, the path that exists once it is ready to use, and the
    # expected MD5 of the download.
    file_dict = {
        "features": {
            "url": features_url,
            "filename": "features.tar.gz",
            "extracted_filename": os.path.join("features", "arrays"),
            "md5": "cad4df655c75caac805a80435e46ee3e",
        },
        "labels": {
            "url": labels_url,
            "filename": "labels.geojson",
            "extracted_filename": "labels.geojson",
            "md5": "bf7bae6812fc7213481aff6a2e34517d",
        },
    }

    def __init__(
        self,
        root: str = "data",
        transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new CropHarvest dataset instance.

        Args:
            root: root directory where dataset can be found
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
            ImportError: If h5py is not installed
        """
        # Fail early, before any download, if the arrays could not be read.
        try:
            import h5py  # noqa: F401
        except ImportError:
            raise ImportError(
                "h5py is not installed and is required to use this dataset"
            )

        self.root = root
        self.transforms = transforms
        self.checksum = checksum
        self.download = download

        self._verify()

        self.files = self._load_features(self.root)
        self.labels = self._load_labels(self.root)

        # Class names: every distinct non-null label found in the label file,
        # preceded by two synthetic classes used by ``_load_label``:
        # index 0 "None" (unlabeled, not cropland) and index 1 "Other"
        # (cropland without a specific crop type).
        self.classes = self.labels["properties.label"].unique()
        self.classes = self.classes[self.classes != np.array(None)]
        self.classes = np.insert(self.classes, 0, ["None", "Other"])

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return the sample at a given index.

        Args:
            index: index to return

        Returns:
            single pixel time-series array and label at that index
        """
        files = self.files[index]
        data = self._load_array(files["chip"])
        label = self._load_label(files["index"], files["dataset"])
        sample = {"array": data, "label": label}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.files)

    def _load_features(self, root: str) -> list[dict[str, str]]:
        """Return the paths of the files in the dataset.

        Feature files are named ``<index>_<dataset>.h5``. Only the first
        underscore separates the two parts, so dataset names that themselves
        contain underscores are preserved intact.

        Args:
            root: root dir of dataset

        Returns:
            list of dicts, sorted by file name, each holding the path of one
            hdf5 single pixel time series (``chip``) and the ``index`` and
            ``dataset`` keys used to look up its label
        """
        feature_dir = os.path.join(
            root, self.file_dict["features"]["extracted_filename"]
        )
        files = []
        # Every path shares the same directory prefix, so sorting the full
        # paths orders them by file name.
        for chip_path in sorted(glob.glob(os.path.join(feature_dir, "*.h5"))):
            stem, _ = os.path.splitext(os.path.basename(chip_path))
            index, dataset = stem.split("_", 1)
            files.append(dict(chip=chip_path, index=index, dataset=dataset))
        return files

    def _load_labels(self, root: str) -> pd.DataFrame:
        """Load the label file into a flat table.

        Args:
            root: root dir of dataset

        Returns:
            pandas dataframe with one row per GeoJSON feature; nested keys are
            flattened into dotted column names such as ``properties.label``
        """
        filename = self.file_dict["labels"]["extracted_filename"]
        with open(os.path.join(root, filename), encoding="utf8") as f:
            data = json.load(f)
        return pd.json_normalize(data["features"])

    def _load_array(self, path: str) -> Tensor:
        """Load an individual single pixel time series.

        Args:
            path: path to the hdf5 file

        Returns:
            the contents of the file's ``array`` dataset as a tensor
        """
        import h5py

        with h5py.File(path, "r") as f:
            # ``[()]`` reads the whole dataset into memory before the file closes.
            array = f.get("array")[()]
        return torch.from_numpy(array)

    def _load_label(self, idx: str, dataset: str) -> Tensor:
        """Load the crop-type label for a single pixel time series.

        Args:
            idx: sample index in labels.geojson
            dataset: dataset name to query labels.geojson

        Returns:
            the crop-type label as an index into ``self.classes``

        Raises:
            IndexError: if labels.geojson has no entry for this sample
        """
        index = int(idx)
        matches = self.labels[
            (self.labels["properties.index"] == index)
            & (self.labels["properties.dataset"] == dataset)
        ]
        if matches.empty:
            raise IndexError(
                f"No label found for index {index} in dataset '{dataset}'"
            )
        # If several rows match, the first one is used.
        row = matches.to_dict(orient="records")[0]

        if row["properties.label"]:
            label = row["properties.label"]
        elif row["properties.is_crop"] == 1:
            # Known to be cropland, but no specific crop type was recorded.
            label = "Other"
        else:
            label = "None"

        return torch.tensor(np.where(self.classes == label)[0][0])

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        feature_path = os.path.join(
            self.root, self.file_dict["features"]["extracted_filename"]
        )
        feature_path_zip = os.path.join(
            self.root, self.file_dict["features"]["filename"]
        )
        label_path = os.path.join(
            self.root, self.file_dict["labels"]["extracted_filename"]
        )

        # The labels are always required; the features may be present either
        # already extracted or still as the downloaded archive.
        if os.path.exists(label_path):
            if os.path.exists(feature_path):
                return
            if os.path.exists(feature_path_zip):
                self._extract()
                return

        # Check if the user requested to download the dataset
        if not self.download:
            raise DatasetNotFoundError(self)

        # Download and extract the dataset
        self._download()
        self._extract()

    def _download(self) -> None:
        """Download the feature archive and the label file into the root directory."""
        for entry in (self.file_dict["features"], self.file_dict["labels"]):
            download_url(
                entry["url"],
                self.root,
                filename=entry["filename"],
                md5=entry["md5"] if self.checksum else None,
            )

    def _extract(self) -> None:
        """Extract the feature archive."""
        features_path = os.path.join(self.root, self.file_dict["features"]["filename"])
        extract_archive(features_path)

    def plot(self, sample: dict[str, Tensor], subtitle: Optional[str] = None) -> Figure:
        """Plot a sample from the dataset using bands for Agriculture RGB composite.

        The time series is drawn as a single row of 12 pixels, one per month.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            subtitle: optional subtitle to use for figure

        Returns:
            a matplotlib Figure with the rendered sample
        """
        fig, axs = plt.subplots()
        bands = [self.all_bands.index(band) for band in self.rgb_bands]
        # Scale the raw band values for display and clip to the [0, 1] range
        # that imshow expects for float RGB data (values above 3000 saturate).
        rgb = np.clip(np.array(sample["array"])[:, bands] / 3000, 0, 1)
        axs.imshow(rgb[None, ...])
        axs.set_title(f"Crop type: {self.classes[int(sample['label'])]}")
        axs.set_xticks(np.arange(12))
        axs.set_xticklabels(np.arange(12) + 1)
        axs.set_yticks([])
        axs.set_xlabel("Month")
        if subtitle is not None:
            fig.suptitle(subtitle)

        return fig