Add Digital Typhoon dataset (microsoft#1748)
* analysis task dataset

* implement sequence sampling

* add outline datamodule

* add datamodule with two way splitting capabilities

* add plotting function

* download and verify

* add unit tests but they fail

* fix tests

* fix style

* trainer testing yaml

* test split logic

* fix tests

* fix tests2

* found bug

* try to fix mypy

* h5py error docs

* fix docs

* fix one mypy error

* mypy on test file

* fix coverage

* fix tests for trainers

* fix mypy

* try typed dict

* try to fix docs

* fix pytest

* linters

* suggested changes and normalization procedure

* regression target normalization

* update dataset splitting

* fix test

* quotes

* strings

* ruff

* quotes

* ruff format on all

* docs

* lazy import

* h5py

* h5py datamodule

* typo

* tests

* review

* pass tests

* fix tests

* list -> tuple

* mypy fix

* rename

* tests

* Remove Analysis

* min pandas 2.2.0

* resolve tests

---------

Co-authored-by: Adam J. Stewart <[email protected]>
nilsleh and adamjstewart authored Aug 29, 2024
1 parent e8e309f commit b9a09f5
Showing 47 changed files with 954 additions and 2 deletions.
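
In brief, the commit adds a `DigitalTyphoon` dataset (Himawari infrared imagery with typhoon track metadata, supporting classification and regression targets) plus a `DigitalTyphoonDataModule` that can split sample sequences either across typhoon ids or along time. A minimal usage sketch, assuming torchgeo's usual sample-dict convention; the 'image'/'label' key names and tensor shapes are not spelled out in this diff:

```python
from torchgeo.datamodules import DigitalTyphoonDataModule
from torchgeo.datasets import DigitalTyphoon

# Direct dataset access: each sample covers `sequence_length` consecutive frames
ds = DigitalTyphoon(root='tests/data/digital_typhoon', sequence_length=3)
sample = ds[0]
print(sample['image'].shape)  # assumed: (sequence_length, H, W) infrared stack
print(sample['label'])        # assumed: regression target such as wind speed

# Datamodule with the two splitting strategies added in this commit
dm = DigitalTyphoonDataModule(
    root='tests/data/digital_typhoon',
    split_by='typhoon_id',  # or 'time'
    batch_size=2,
    num_workers=0,
)
```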
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
@@ -77,6 +77,11 @@ Deep Globe Land Cover Challenge

.. autoclass:: DeepGlobeLandCoverDataModule

Digital Typhoon
^^^^^^^^^^^^^^^

.. autoclass:: DigitalTyphoonDataModule

ETCI2021 Flood Detection
^^^^^^^^^^^^^^^^^^^^^^^^

6 changes: 6 additions & 0 deletions docs/api/datasets.rst
@@ -259,6 +259,12 @@ DFC2022

.. autoclass:: DFC2022


Digital Typhoon
^^^^^^^^^^^^^^^

.. autoclass:: DigitalTyphoon

ETCI2021 Flood Detection
^^^^^^^^^^^^^^^^^^^^^^^^

1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
@@ -11,6 +11,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`Kenya Crop Type`_,S,Sentinel-2,"CC-BY-SA-4.0","4,688",7,"3,035x2,016",10,MSI
`DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,-,803,7,"2,448x2,448",0.5,RGB
`DFC2022`_,S,Aerial,"CC-BY-4.0","3,981",15,"2,000x2,000",0.5,RGB
`Digital Typhoon`_,"C, R",Himawari,"CC-BY-4.0","189,364",8,512,5000,Infrared
`ETCI2021 Flood Detection`_,S,Sentinel-1,-,"66,810",2,256x256,5--20,SAR
`EuroSAT`_,C,Sentinel-2,"MIT","27,000",10,64x64,10,MSI
`FAIR1M`_,OD,Gaofen/Google Earth,"CC-BY-NC-SA-3.0","15,000",37,"1,024x1,024",0.3--0.8,RGB
18 changes: 18 additions & 0 deletions tests/conf/digital_typhoon_id.yaml
@@ -0,0 +1,18 @@
model:
  class_path: RegressionTask
  init_args:
    model: "resnet18"
    num_outputs: 1
    in_channels: 3
    loss: "mse"
data:
  class_path: DigitalTyphoonDataModule
  init_args:
    batch_size: 1
    split_by: "typhoon_id"
  dict_kwargs:
    root: "tests/data/digital_typhoon"
    download: true
    min_feature_value:
      wind: 10
    sequence_length: 3
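
This config mirrors the datamodule constructor: `init_args` map to named parameters, while `dict_kwargs` entries are captured as `**kwargs` and forwarded to the underlying `DigitalTyphoon` dataset (standard jsonargparse/torchgeo routing; a sketch of the in-Python equivalent, not verbatim from this diff):

```python
from torchgeo.datamodules import DigitalTyphoonDataModule

dm = DigitalTyphoonDataModule(
    batch_size=1,
    split_by='typhoon_id',
    # the following kwargs are forwarded to DigitalTyphoon(...)
    root='tests/data/digital_typhoon',
    download=True,
    min_feature_value={'wind': 10},
    sequence_length=3,
)
```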
18 changes: 18 additions & 0 deletions tests/conf/digital_typhoon_time.yaml
@@ -0,0 +1,18 @@
model:
  class_path: RegressionTask
  init_args:
    model: "resnet18"
    num_outputs: 1
    in_channels: 3
    loss: "mse"
data:
  class_path: DigitalTyphoonDataModule
  init_args:
    batch_size: 1
    split_by: "time"
  dict_kwargs:
    root: "tests/data/digital_typhoon"
    download: true
    min_feature_value:
      wind: 10
    sequence_length: 3
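
The second config is identical except for `split_by: "time"`. The trainer tests feed these YAMLs through torchgeo's Lightning-based CLI; outside the test suite, an equivalent invocation might look like the following (entry point and argument handling assumed, not shown in this diff):

```python
from torchgeo.main import main

# Roughly `torchgeo fit --config tests/conf/digital_typhoon_time.yaml`
main(['fit', '--config', 'tests/conf/digital_typhoon_time.yaml'])
```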
Binary file added tests/data/digital_typhoon/WP.tar.gz
Binary file added tests/data/digital_typhoon/WP.tar.gzaa
Binary file added tests/data/digital_typhoon/WP.tar.gzab
26 changes: 26 additions & 0 deletions tests/data/digital_typhoon/WP/aux_data.csv
@@ -0,0 +1,26 @@
id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct
0,0.h5,1979,12,25,6,3,-55.81114066899345,76.6995939240727,973.8743108424701,44.98399850309952,66,71,75,137,25,95,1,1,1.h5,mask_40,89.87979469874404
0,1.h5,1979,12,25,7,3,-33.621634184914114,-25.860702927919903,903.8203398162416,6.3832427352565,230,28,61,111,72,4,1,0,2.h5,mask_40,55.86768840838465
0,2.h5,1979,12,25,8,3,72.02964248591297,-47.48138416430828,982.76724331446,0.027966770724696666,342,76,23,337,49,19,1,0,3.h5,mask_49,55.18449786430531
0,3.h5,1979,12,25,9,2,55.920575184851316,13.989913225833078,906.0181106433341,51.01642134825744,330,90,52,258,44,65,1,1,4.h5,mask_86,15.969129252036707
0,4.h5,1979,12,25,10,2,-43.28994147714503,-161.94483446959413,903.9366550400755,16.7093617045847,242,62,99,132,63,0,1,1,5.h5,mask_66,70.21971067939033
1,0.h5,1988,1,22,10,2,-33.37129190053344,-115.29637290040873,948.0758912152131,51.11399505734963,118,15,67,232,63,86,1,1,1.h5,mask_15,30.245077213336646
1,1.h5,1988,1,22,11,2,74.93228846926493,70.74999801636073,910.1992664115785,60.8348103266534,266,41,67,48,44,16,1,0,2.h5,mask_90,42.30390416164944
1,2.h5,1988,1,22,12,2,-27.931601464223597,-141.3019006863473,961.5531323907394,18.35497901874176,19,61,24,295,50,26,1,1,3.h5,mask_67,60.35785307941444
1,3.h5,1988,1,22,13,3,-27.166703710913154,-27.976214499674484,904.1165949703977,9.081723951290567,144,43,66,22,32,48,0,1,4.h5,mask_3,80.04417033291257
1,4.h5,1988,1,22,14,2,47.51657289770864,-138.58539565379158,950.9654977977864,86.18819130981862,175,75,89,42,19,70,0,1,5.h5,mask_96,0.44001778199053154
2,0.h5,1998,8,23,22,2,71.11037770397022,-170.05883586527145,902.757696015989,64.83605229043086,308,32,54,249,94,13,1,0,1.h5,mask_87,97.96789767456457
2,1.h5,1998,8,23,23,2,-45.9880469141837,-153.85203885662787,956.1578736191437,95.77226625568278,230,17,58,214,72,21,1,0,2.h5,mask_66,48.1513473689529
2,2.h5,1998,8,24,0,4,-88.778300647409,-78.43060469893915,958.764771469677,17.97662971655637,127,41,19,138,89,36,1,1,3.h5,mask_57,76.31799924098371
2,3.h5,1998,8,24,1,2,-49.56689955810804,-120.3389762632577,986.4933451650326,49.259894810485605,333,90,28,51,45,99,1,0,4.h5,mask_92,65.60333971250041
2,4.h5,1998,8,24,2,3,-52.55231579306487,80.06217230886841,997.4333837891787,48.25976623703225,63,7,13,71,55,58,1,1,5.h5,mask_73,50.634737551399034
3,0.h5,1997,4,24,16,4,-61.81374526076493,60.62026564332362,900.1093638487514,94.66595722320622,189,70,67,249,12,58,0,1,1.h5,mask_93,99.77561346276104
3,1.h5,1997,4,24,17,3,35.596382297289026,-117.20301531275722,925.1366339770796,34.46028512732848,55,55,74,11,0,49,1,1,2.h5,mask_11,5.726401727423658
3,2.h5,1997,4,24,18,1,68.16880747309938,30.42194122117013,955.7265683876137,96.55057639044118,217,22,60,6,18,9,1,1,3.h5,mask_63,58.982331802755375
3,3.h5,1997,4,24,19,3,-5.491619122910365,141.83240318855258,922.5486496962513,89.2199247408618,49,26,14,245,95,84,1,0,4.h5,mask_38,76.01607012923168
3,4.h5,1997,4,24,20,4,4.052162855787202,21.732867986138842,990.5791999912764,98.40094253121877,158,86,11,28,11,81,0,0,5.h5,mask_12,75.84036894650622
4,0.h5,1984,6,16,14,3,53.238650326925125,-54.63854263302531,934.2198641027621,18.697921579520305,212,16,42,91,90,56,1,1,1.h5,mask_72,78.93081269669048
4,1.h5,1984,6,16,15,2,-56.222689844694024,-6.8726887962189664,912.6113238303491,61.286246561868666,60,81,2,198,64,76,1,0,2.h5,mask_64,24.039173626000288
4,2.h5,1984,6,16,16,2,-4.285643464886363,95.66534210331434,962.0580147775602,86.01251389789185,281,81,5,228,18,94,0,0,3.h5,mask_66,89.89080488339964
4,3.h5,1984,6,16,17,2,89.15893201203946,124.94143678744513,997.342814284227,84.00590505469005,242,28,61,132,80,29,0,0,4.h5,mask_77,4.839048143310343
4,4.h5,1984,6,16,18,1,-46.31233638346047,21.77073986978661,932.8378121656477,26.18973887839292,294,76,57,252,99,27,1,0,5.h5,mask_65,89.74882055138497
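
The `min_feature_value: {wind: 10}` setting in the configs above plausibly filters rows of this table before sequences are sampled; a sketch of that filtering with pandas (the dataset's actual implementation is not part of this file):

```python
import pandas as pd

aux = pd.read_csv('tests/data/digital_typhoon/WP/aux_data.csv')

# Keep only rows satisfying min_feature_value: {'wind': 10}
filtered = aux[aux['wind'] >= 10]

# One group per typhoon id; fixed-length sequences are then sampled
# from consecutive frames within each group.
for typhoon_id, group in filtered.groupby('id'):
    print(typhoon_id, len(group))
```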
Binary file added tests/data/digital_typhoon/WP/image/0/0.h5
Binary file added tests/data/digital_typhoon/WP/image/0/1.h5
Binary file added tests/data/digital_typhoon/WP/image/0/2.h5
Binary file added tests/data/digital_typhoon/WP/image/0/3.h5
Binary file added tests/data/digital_typhoon/WP/image/0/4.h5
Binary file added tests/data/digital_typhoon/WP/image/1/0.h5
Binary file added tests/data/digital_typhoon/WP/image/1/1.h5
Binary file added tests/data/digital_typhoon/WP/image/1/2.h5
Binary file added tests/data/digital_typhoon/WP/image/1/3.h5
Binary file added tests/data/digital_typhoon/WP/image/1/4.h5
Binary file added tests/data/digital_typhoon/WP/image/2/0.h5
Binary file added tests/data/digital_typhoon/WP/image/2/1.h5
Binary file added tests/data/digital_typhoon/WP/image/2/2.h5
Binary file added tests/data/digital_typhoon/WP/image/2/3.h5
Binary file added tests/data/digital_typhoon/WP/image/2/4.h5
Binary file added tests/data/digital_typhoon/WP/image/3/0.h5
Binary file added tests/data/digital_typhoon/WP/image/3/1.h5
Binary file added tests/data/digital_typhoon/WP/image/3/2.h5
Binary file added tests/data/digital_typhoon/WP/image/3/3.h5
Binary file added tests/data/digital_typhoon/WP/image/3/4.h5
Binary file added tests/data/digital_typhoon/WP/image/4/0.h5
Binary file added tests/data/digital_typhoon/WP/image/4/1.h5
Binary file added tests/data/digital_typhoon/WP/image/4/2.h5
Binary file added tests/data/digital_typhoon/WP/image/4/3.h5
Binary file added tests/data/digital_typhoon/WP/image/4/4.h5
6 changes: 6 additions & 0 deletions tests/data/digital_typhoon/WP/metadata/0.csv
@@ -0,0 +1,6 @@
id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct
0,0.h5,1979,12,25,6,3,-55.81114066899345,76.6995939240727,973.8743108424701,44.98399850309952,66,71,75,137,25,95,1,1,1.h5,mask_40,89.87979469874404
0,1.h5,1979,12,25,7,3,-33.621634184914114,-25.860702927919903,903.8203398162416,6.3832427352565,230,28,61,111,72,4,1,0,2.h5,mask_40,55.86768840838465
0,2.h5,1979,12,25,8,3,72.02964248591297,-47.48138416430828,982.76724331446,0.027966770724696666,342,76,23,337,49,19,1,0,3.h5,mask_49,55.18449786430531
0,3.h5,1979,12,25,9,2,55.920575184851316,13.989913225833078,906.0181106433341,51.01642134825744,330,90,52,258,44,65,1,1,4.h5,mask_86,15.969129252036707
0,4.h5,1979,12,25,10,2,-43.28994147714503,-161.94483446959413,903.9366550400755,16.7093617045847,242,62,99,132,63,0,1,1,5.h5,mask_66,70.21971067939033
6 changes: 6 additions & 0 deletions tests/data/digital_typhoon/WP/metadata/1.csv
@@ -0,0 +1,6 @@
id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct
1,0.h5,1988,1,22,10,2,-33.37129190053344,-115.29637290040873,948.0758912152131,51.11399505734963,118,15,67,232,63,86,1,1,1.h5,mask_15,30.245077213336646
1,1.h5,1988,1,22,11,2,74.93228846926493,70.74999801636073,910.1992664115785,60.8348103266534,266,41,67,48,44,16,1,0,2.h5,mask_90,42.30390416164944
1,2.h5,1988,1,22,12,2,-27.931601464223597,-141.3019006863473,961.5531323907394,18.35497901874176,19,61,24,295,50,26,1,1,3.h5,mask_67,60.35785307941444
1,3.h5,1988,1,22,13,3,-27.166703710913154,-27.976214499674484,904.1165949703977,9.081723951290567,144,43,66,22,32,48,0,1,4.h5,mask_3,80.04417033291257
1,4.h5,1988,1,22,14,2,47.51657289770864,-138.58539565379158,950.9654977977864,86.18819130981862,175,75,89,42,19,70,0,1,5.h5,mask_96,0.44001778199053154
6 changes: 6 additions & 0 deletions tests/data/digital_typhoon/WP/metadata/2.csv
@@ -0,0 +1,6 @@
id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct
2,0.h5,1998,8,23,22,2,71.11037770397022,-170.05883586527145,902.757696015989,64.83605229043086,308,32,54,249,94,13,1,0,1.h5,mask_87,97.96789767456457
2,1.h5,1998,8,23,23,2,-45.9880469141837,-153.85203885662787,956.1578736191437,95.77226625568278,230,17,58,214,72,21,1,0,2.h5,mask_66,48.1513473689529
2,2.h5,1998,8,24,0,4,-88.778300647409,-78.43060469893915,958.764771469677,17.97662971655637,127,41,19,138,89,36,1,1,3.h5,mask_57,76.31799924098371
2,3.h5,1998,8,24,1,2,-49.56689955810804,-120.3389762632577,986.4933451650326,49.259894810485605,333,90,28,51,45,99,1,0,4.h5,mask_92,65.60333971250041
2,4.h5,1998,8,24,2,3,-52.55231579306487,80.06217230886841,997.4333837891787,48.25976623703225,63,7,13,71,55,58,1,1,5.h5,mask_73,50.634737551399034
6 changes: 6 additions & 0 deletions tests/data/digital_typhoon/WP/metadata/3.csv
@@ -0,0 +1,6 @@
id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct
3,0.h5,1997,4,24,16,4,-61.81374526076493,60.62026564332362,900.1093638487514,94.66595722320622,189,70,67,249,12,58,0,1,1.h5,mask_93,99.77561346276104
3,1.h5,1997,4,24,17,3,35.596382297289026,-117.20301531275722,925.1366339770796,34.46028512732848,55,55,74,11,0,49,1,1,2.h5,mask_11,5.726401727423658
3,2.h5,1997,4,24,18,1,68.16880747309938,30.42194122117013,955.7265683876137,96.55057639044118,217,22,60,6,18,9,1,1,3.h5,mask_63,58.982331802755375
3,3.h5,1997,4,24,19,3,-5.491619122910365,141.83240318855258,922.5486496962513,89.2199247408618,49,26,14,245,95,84,1,0,4.h5,mask_38,76.01607012923168
3,4.h5,1997,4,24,20,4,4.052162855787202,21.732867986138842,990.5791999912764,98.40094253121877,158,86,11,28,11,81,0,0,5.h5,mask_12,75.84036894650622
6 changes: 6 additions & 0 deletions tests/data/digital_typhoon/WP/metadata/4.csv
@@ -0,0 +1,6 @@
id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct
4,0.h5,1984,6,16,14,3,53.238650326925125,-54.63854263302531,934.2198641027621,18.697921579520305,212,16,42,91,90,56,1,1,1.h5,mask_72,78.93081269669048
4,1.h5,1984,6,16,15,2,-56.222689844694024,-6.8726887962189664,912.6113238303491,61.286246561868666,60,81,2,198,64,76,1,0,2.h5,mask_64,24.039173626000288
4,2.h5,1984,6,16,16,2,-4.285643464886363,95.66534210331434,962.0580147775602,86.01251389789185,281,81,5,228,18,94,0,0,3.h5,mask_66,89.89080488339964
4,3.h5,1984,6,16,17,2,89.15893201203946,124.94143678744513,997.342814284227,84.00590505469005,242,28,61,132,80,29,0,0,4.h5,mask_77,4.839048143310343
4,4.h5,1984,6,16,18,1,-46.31233638346047,21.77073986978661,932.8378121656477,26.18973887839292,294,76,57,252,99,27,1,0,5.h5,mask_65,89.74882055138497
111 changes: 111 additions & 0 deletions tests/data/digital_typhoon/data.py
@@ -0,0 +1,111 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil

import h5py
import numpy as np
import pandas as pd
from torchvision.datasets.utils import calculate_md5

# Define the root directory
root = 'WP'
IMAGE_SIZE = 32
NUM_TYPHOON_IDS = 5
NUM_IMAGES_PER_ID = 5
CHUNK_SIZE = 2**12

# If the root directory exists, remove it
if os.path.exists(root):
    shutil.rmtree(root)

# Create the 'image' and 'metadata' directories
os.makedirs(os.path.join(root, 'image'))
os.makedirs(os.path.join(root, 'metadata'))

# For each typhoon_id
all_dfs = []
for typhoon_id in range(NUM_TYPHOON_IDS):
    # Create a directory under 'root/image/typhoon_id/'
    os.makedirs(os.path.join(root, 'image', str(typhoon_id)), exist_ok=True)

    # Create dummy .h5 files
    image_paths_per_typhoon = []
    for image_id in range(NUM_IMAGES_PER_ID):
        image_file_name = f'{image_id}.h5'
        with h5py.File(
            os.path.join(root, 'image', str(typhoon_id), image_file_name), 'w'
        ) as hf:
            hf.create_dataset('Infrared', data=np.random.rand(IMAGE_SIZE, IMAGE_SIZE))
        image_paths_per_typhoon.append(image_file_name)

    start_time = pd.Timestamp(
        year=np.random.randint(1978, 2022),
        month=np.random.randint(1, 13),
        day=np.random.randint(1, 29),
        hour=np.random.randint(0, 24),
    )
    times = pd.date_range(start=start_time, periods=NUM_IMAGES_PER_ID, freq='H')
    df = pd.DataFrame(
        {
            'id': np.repeat(typhoon_id, NUM_IMAGES_PER_ID),
            'image_path': image_paths_per_typhoon,
            'year': times.year,
            'month': times.month,
            'day': times.day,
            'hour': times.hour,
            'grade': np.random.randint(1, 5, NUM_IMAGES_PER_ID),
            'lat': np.random.uniform(-90, 90, NUM_IMAGES_PER_ID),
            'lng': np.random.uniform(-180, 180, NUM_IMAGES_PER_ID),
            'pressure': np.random.uniform(900, 1000, NUM_IMAGES_PER_ID),
            'wind': np.random.uniform(0, 100, NUM_IMAGES_PER_ID),
            'dir50': np.random.randint(0, 360, NUM_IMAGES_PER_ID),
            'long50': np.random.randint(0, 100, NUM_IMAGES_PER_ID),
            'short50': np.random.randint(0, 100, NUM_IMAGES_PER_ID),
            'dir30': np.random.randint(0, 360, NUM_IMAGES_PER_ID),
            'long30': np.random.randint(0, 100, NUM_IMAGES_PER_ID),
            'short30': np.random.randint(0, 100, NUM_IMAGES_PER_ID),
            'landfall': np.random.randint(0, 2, NUM_IMAGES_PER_ID),
            'intp': np.random.randint(0, 2, NUM_IMAGES_PER_ID),
            'file_1': [f'{idx}.h5' for idx in range(1, NUM_IMAGES_PER_ID + 1)],
            'mask_1': [
                'mask_' + str(i) for i in np.random.randint(1, 100, NUM_IMAGES_PER_ID)
            ],
            'mask_1_pct': np.random.uniform(0, 100, NUM_IMAGES_PER_ID),
        }
    )

    # Save the DataFrame to the corresponding typhoon id as metadata
    df.to_csv(os.path.join(root, 'metadata', f'{typhoon_id}.csv'), index=False)

    all_dfs.append(df)

# Save the aux_data.csv
aux_data = pd.concat(all_dfs)
aux_data.to_csv(os.path.join(root, 'aux_data.csv'), index=False)


# Create tarball
shutil.make_archive(root, 'gztar', '.', root)

# simulate multiple tar files
path = f'{root}.tar.gz'
paths = []
with open(path, 'rb') as f:
    # Write the entire tarball to gzaa
    split = f'{path}aa'
    with open(split, 'wb') as g:
        g.write(f.read())
    paths.append(split)

# Create gzab as a copy of gzaa
shutil.copy2(f'{path}aa', f'{path}ab')
paths.append(f'{path}ab')


# Calculate the md5sum of each tar part
for path in paths:
    print(f'{path}: {calculate_md5(path)}')
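
The `.gzaa`/`.gzab` parts simulate the multi-part archives the real dataset ships as; a minimal sketch of reassembly and extraction under that assumption (the dataset class's actual download/extract logic is not shown in this file, and in the real data the parts are true byte-wise splits rather than the duplicate copies generated above):

```python
import glob
import tarfile

# Concatenate the alphabetically sorted parts back into a single tarball
with open('WP.tar.gz', 'wb') as out:
    for part in sorted(glob.glob('WP.tar.gza?')):
        with open(part, 'rb') as f:
            out.write(f.read())

# Extract the reassembled gzipped tar
with tarfile.open('WP.tar.gz', 'r:gz') as tar:
    tar.extractall('.')
```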
70 changes: 70 additions & 0 deletions tests/datamodules/test_digital_typhoon.py
@@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Test Digital Typhoon Datamodule."""

import os

import pytest

from torchgeo.datamodules import DigitalTyphoonDataModule
from torchgeo.datasets.digital_typhoon import DigitalTyphoon, _SampleSequenceDict

pytest.importorskip('h5py', minversion='3.6')


class TestDigitalTyphoonDataModule:
    def test_invalid_param_config(self) -> None:
        with pytest.raises(AssertionError, match='Please choose from'):
            DigitalTyphoonDataModule(
                root=os.path.join('tests', 'data', 'digital_typhoon'),
                split_by='invalid',
                batch_size=2,
                num_workers=0,
            )

    @pytest.mark.parametrize('split_by', ['time', 'typhoon_id'])
    def test_split_dataset(self, split_by: str) -> None:
        dm = DigitalTyphoonDataModule(
            root=os.path.join('tests', 'data', 'digital_typhoon'),
            split_by=split_by,
            batch_size=2,
            num_workers=0,
        )
        dataset = DigitalTyphoon(root=os.path.join('tests', 'data', 'digital_typhoon'))
        train_indices, val_indices = dm._split_dataset(dataset.sample_sequences)
        train_sequences, val_sequences = (
            [dataset.sample_sequences[i] for i in train_indices],
            [dataset.sample_sequences[i] for i in val_indices],
        )

        if split_by == 'time':

            def find_max_time_per_id(
                split_sequences: list[_SampleSequenceDict],
            ) -> dict[str, int]:
                # Find the maximum value of each id in train_sequences
                max_values: dict[str, int] = {}
                for seq in split_sequences:
                    id: str = str(seq['id'])
                    value: int = max(seq['seq_id'])
                    if id not in max_values or value > max_values[id]:
                        max_values[id] = value
                return max_values

            train_max_values = find_max_time_per_id(train_sequences)
            val_max_values = find_max_time_per_id(val_sequences)
            # Assert that each max value in train_max_values is lower
            # than in val_max_values for each key id
            for id, max_value in train_max_values.items():
                assert (
                    id not in val_max_values or max_value < val_max_values[id]
                ), f'Max value for id {id} in train is not lower than in validation.'
        else:
            train_ids = {seq['id'] for seq in train_sequences}
            val_ids = {seq['id'] for seq in val_sequences}

            # Assert that the intersection between train_ids and val_ids is empty
            assert (
                len(train_ids & val_ids) == 0
            ), 'Train and validation datasets have overlapping ids.'
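
For reference, `_SampleSequenceDict` (imported above) is consistent with the following shape, inferred from its usage in this test rather than copied from the dataset module:

```python
from typing import TypedDict


class _SampleSequenceDict(TypedDict):
    """Sequence record as used above (inferred, not verbatim)."""

    id: int            # typhoon identifier, e.g. 0..4 in the test data
    seq_id: list[int]  # indices of the consecutive frames in this sequence
```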
