forked from microsoft/torchgeo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Digital typhoon dataset (microsoft#1748)
* analysis task dataset * implement sequence sampling * add outline datamodule * add datamodule with two way splitting capabilities * add plotting function * download and verify * add unit tests but they fail * fix tests * fix style * trainer testing yaml * test split logic * fix tests * fix tests2 * found bug * try to fix mypy * h5py error docs * fix docs * fix one mypy error * mypy on test file * fix coverage * fix tests for trainers * fix mypy * try typed dict * try to fix docs * fix pytest * linters * suggested changes and normalization procedure * regression target normalization * update dataset splitting * fix test * quotes * strings * ruff * quotes * ruff format on all * docs * lazy import * h5py * h5py datamodule * typo * tests * review * pass tests * fix tests * list -> tuple * mypy fix * rename * tests * Remove Analysis * min pandas 2.2.0 * resolve tests --------- Co-authored-by: Adam J. Stewart <[email protected]>
- Loading branch information
1 parent
e8e309f
commit b9a09f5
Showing
47 changed files
with
954 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
model:
  class_path: RegressionTask
  init_args:
    model: "resnet18"
    num_outputs: 1
    in_channels: 3
    loss: "mse"
data:
  class_path: DigitalTyphoonDataModule
  init_args:
    batch_size: 1
    split_by: "typhoon_id"
  dict_kwargs:
    root: "tests/data/digital_typhoon"
    download: true
    min_feature_value:
      wind: 10
    sequence_length: 3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
model:
  class_path: RegressionTask
  init_args:
    model: "resnet18"
    num_outputs: 1
    in_channels: 3
    loss: "mse"
data:
  class_path: DigitalTyphoonDataModule
  init_args:
    batch_size: 1
    split_by: "time"
  dict_kwargs:
    root: "tests/data/digital_typhoon"
    download: true
    min_feature_value:
      wind: 10
    sequence_length: 3
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct | ||
0,0.h5,1979,12,25,6,3,-55.81114066899345,76.6995939240727,973.8743108424701,44.98399850309952,66,71,75,137,25,95,1,1,1.h5,mask_40,89.87979469874404 | ||
0,1.h5,1979,12,25,7,3,-33.621634184914114,-25.860702927919903,903.8203398162416,6.3832427352565,230,28,61,111,72,4,1,0,2.h5,mask_40,55.86768840838465 | ||
0,2.h5,1979,12,25,8,3,72.02964248591297,-47.48138416430828,982.76724331446,0.027966770724696666,342,76,23,337,49,19,1,0,3.h5,mask_49,55.18449786430531 | ||
0,3.h5,1979,12,25,9,2,55.920575184851316,13.989913225833078,906.0181106433341,51.01642134825744,330,90,52,258,44,65,1,1,4.h5,mask_86,15.969129252036707 | ||
0,4.h5,1979,12,25,10,2,-43.28994147714503,-161.94483446959413,903.9366550400755,16.7093617045847,242,62,99,132,63,0,1,1,5.h5,mask_66,70.21971067939033 | ||
1,0.h5,1988,1,22,10,2,-33.37129190053344,-115.29637290040873,948.0758912152131,51.11399505734963,118,15,67,232,63,86,1,1,1.h5,mask_15,30.245077213336646 | ||
1,1.h5,1988,1,22,11,2,74.93228846926493,70.74999801636073,910.1992664115785,60.8348103266534,266,41,67,48,44,16,1,0,2.h5,mask_90,42.30390416164944 | ||
1,2.h5,1988,1,22,12,2,-27.931601464223597,-141.3019006863473,961.5531323907394,18.35497901874176,19,61,24,295,50,26,1,1,3.h5,mask_67,60.35785307941444 | ||
1,3.h5,1988,1,22,13,3,-27.166703710913154,-27.976214499674484,904.1165949703977,9.081723951290567,144,43,66,22,32,48,0,1,4.h5,mask_3,80.04417033291257 | ||
1,4.h5,1988,1,22,14,2,47.51657289770864,-138.58539565379158,950.9654977977864,86.18819130981862,175,75,89,42,19,70,0,1,5.h5,mask_96,0.44001778199053154 | ||
2,0.h5,1998,8,23,22,2,71.11037770397022,-170.05883586527145,902.757696015989,64.83605229043086,308,32,54,249,94,13,1,0,1.h5,mask_87,97.96789767456457 | ||
2,1.h5,1998,8,23,23,2,-45.9880469141837,-153.85203885662787,956.1578736191437,95.77226625568278,230,17,58,214,72,21,1,0,2.h5,mask_66,48.1513473689529 | ||
2,2.h5,1998,8,24,0,4,-88.778300647409,-78.43060469893915,958.764771469677,17.97662971655637,127,41,19,138,89,36,1,1,3.h5,mask_57,76.31799924098371 | ||
2,3.h5,1998,8,24,1,2,-49.56689955810804,-120.3389762632577,986.4933451650326,49.259894810485605,333,90,28,51,45,99,1,0,4.h5,mask_92,65.60333971250041 | ||
2,4.h5,1998,8,24,2,3,-52.55231579306487,80.06217230886841,997.4333837891787,48.25976623703225,63,7,13,71,55,58,1,1,5.h5,mask_73,50.634737551399034 | ||
3,0.h5,1997,4,24,16,4,-61.81374526076493,60.62026564332362,900.1093638487514,94.66595722320622,189,70,67,249,12,58,0,1,1.h5,mask_93,99.77561346276104 | ||
3,1.h5,1997,4,24,17,3,35.596382297289026,-117.20301531275722,925.1366339770796,34.46028512732848,55,55,74,11,0,49,1,1,2.h5,mask_11,5.726401727423658 | ||
3,2.h5,1997,4,24,18,1,68.16880747309938,30.42194122117013,955.7265683876137,96.55057639044118,217,22,60,6,18,9,1,1,3.h5,mask_63,58.982331802755375 | ||
3,3.h5,1997,4,24,19,3,-5.491619122910365,141.83240318855258,922.5486496962513,89.2199247408618,49,26,14,245,95,84,1,0,4.h5,mask_38,76.01607012923168 | ||
3,4.h5,1997,4,24,20,4,4.052162855787202,21.732867986138842,990.5791999912764,98.40094253121877,158,86,11,28,11,81,0,0,5.h5,mask_12,75.84036894650622 | ||
4,0.h5,1984,6,16,14,3,53.238650326925125,-54.63854263302531,934.2198641027621,18.697921579520305,212,16,42,91,90,56,1,1,1.h5,mask_72,78.93081269669048 | ||
4,1.h5,1984,6,16,15,2,-56.222689844694024,-6.8726887962189664,912.6113238303491,61.286246561868666,60,81,2,198,64,76,1,0,2.h5,mask_64,24.039173626000288 | ||
4,2.h5,1984,6,16,16,2,-4.285643464886363,95.66534210331434,962.0580147775602,86.01251389789185,281,81,5,228,18,94,0,0,3.h5,mask_66,89.89080488339964 | ||
4,3.h5,1984,6,16,17,2,89.15893201203946,124.94143678744513,997.342814284227,84.00590505469005,242,28,61,132,80,29,0,0,4.h5,mask_77,4.839048143310343 | ||
4,4.h5,1984,6,16,18,1,-46.31233638346047,21.77073986978661,932.8378121656477,26.18973887839292,294,76,57,252,99,27,1,0,5.h5,mask_65,89.74882055138497 |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct | ||
0,0.h5,1979,12,25,6,3,-55.81114066899345,76.6995939240727,973.8743108424701,44.98399850309952,66,71,75,137,25,95,1,1,1.h5,mask_40,89.87979469874404 | ||
0,1.h5,1979,12,25,7,3,-33.621634184914114,-25.860702927919903,903.8203398162416,6.3832427352565,230,28,61,111,72,4,1,0,2.h5,mask_40,55.86768840838465 | ||
0,2.h5,1979,12,25,8,3,72.02964248591297,-47.48138416430828,982.76724331446,0.027966770724696666,342,76,23,337,49,19,1,0,3.h5,mask_49,55.18449786430531 | ||
0,3.h5,1979,12,25,9,2,55.920575184851316,13.989913225833078,906.0181106433341,51.01642134825744,330,90,52,258,44,65,1,1,4.h5,mask_86,15.969129252036707 | ||
0,4.h5,1979,12,25,10,2,-43.28994147714503,-161.94483446959413,903.9366550400755,16.7093617045847,242,62,99,132,63,0,1,1,5.h5,mask_66,70.21971067939033 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct | ||
1,0.h5,1988,1,22,10,2,-33.37129190053344,-115.29637290040873,948.0758912152131,51.11399505734963,118,15,67,232,63,86,1,1,1.h5,mask_15,30.245077213336646 | ||
1,1.h5,1988,1,22,11,2,74.93228846926493,70.74999801636073,910.1992664115785,60.8348103266534,266,41,67,48,44,16,1,0,2.h5,mask_90,42.30390416164944 | ||
1,2.h5,1988,1,22,12,2,-27.931601464223597,-141.3019006863473,961.5531323907394,18.35497901874176,19,61,24,295,50,26,1,1,3.h5,mask_67,60.35785307941444 | ||
1,3.h5,1988,1,22,13,3,-27.166703710913154,-27.976214499674484,904.1165949703977,9.081723951290567,144,43,66,22,32,48,0,1,4.h5,mask_3,80.04417033291257 | ||
1,4.h5,1988,1,22,14,2,47.51657289770864,-138.58539565379158,950.9654977977864,86.18819130981862,175,75,89,42,19,70,0,1,5.h5,mask_96,0.44001778199053154 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct | ||
2,0.h5,1998,8,23,22,2,71.11037770397022,-170.05883586527145,902.757696015989,64.83605229043086,308,32,54,249,94,13,1,0,1.h5,mask_87,97.96789767456457 | ||
2,1.h5,1998,8,23,23,2,-45.9880469141837,-153.85203885662787,956.1578736191437,95.77226625568278,230,17,58,214,72,21,1,0,2.h5,mask_66,48.1513473689529 | ||
2,2.h5,1998,8,24,0,4,-88.778300647409,-78.43060469893915,958.764771469677,17.97662971655637,127,41,19,138,89,36,1,1,3.h5,mask_57,76.31799924098371 | ||
2,3.h5,1998,8,24,1,2,-49.56689955810804,-120.3389762632577,986.4933451650326,49.259894810485605,333,90,28,51,45,99,1,0,4.h5,mask_92,65.60333971250041 | ||
2,4.h5,1998,8,24,2,3,-52.55231579306487,80.06217230886841,997.4333837891787,48.25976623703225,63,7,13,71,55,58,1,1,5.h5,mask_73,50.634737551399034 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct | ||
3,0.h5,1997,4,24,16,4,-61.81374526076493,60.62026564332362,900.1093638487514,94.66595722320622,189,70,67,249,12,58,0,1,1.h5,mask_93,99.77561346276104 | ||
3,1.h5,1997,4,24,17,3,35.596382297289026,-117.20301531275722,925.1366339770796,34.46028512732848,55,55,74,11,0,49,1,1,2.h5,mask_11,5.726401727423658 | ||
3,2.h5,1997,4,24,18,1,68.16880747309938,30.42194122117013,955.7265683876137,96.55057639044118,217,22,60,6,18,9,1,1,3.h5,mask_63,58.982331802755375 | ||
3,3.h5,1997,4,24,19,3,-5.491619122910365,141.83240318855258,922.5486496962513,89.2199247408618,49,26,14,245,95,84,1,0,4.h5,mask_38,76.01607012923168 | ||
3,4.h5,1997,4,24,20,4,4.052162855787202,21.732867986138842,990.5791999912764,98.40094253121877,158,86,11,28,11,81,0,0,5.h5,mask_12,75.84036894650622 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct | ||
4,0.h5,1984,6,16,14,3,53.238650326925125,-54.63854263302531,934.2198641027621,18.697921579520305,212,16,42,91,90,56,1,1,1.h5,mask_72,78.93081269669048 | ||
4,1.h5,1984,6,16,15,2,-56.222689844694024,-6.8726887962189664,912.6113238303491,61.286246561868666,60,81,2,198,64,76,1,0,2.h5,mask_64,24.039173626000288 | ||
4,2.h5,1984,6,16,16,2,-4.285643464886363,95.66534210331434,962.0580147775602,86.01251389789185,281,81,5,228,18,94,0,0,3.h5,mask_66,89.89080488339964 | ||
4,3.h5,1984,6,16,17,2,89.15893201203946,124.94143678744513,997.342814284227,84.00590505469005,242,28,61,132,80,29,0,0,4.h5,mask_77,4.839048143310343 | ||
4,4.h5,1984,6,16,18,1,-46.31233638346047,21.77073986978661,932.8378121656477,26.18973887839292,294,76,57,252,99,27,1,0,5.h5,mask_65,89.74882055138497 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
import shutil | ||
|
||
import h5py | ||
import numpy as np | ||
import pandas as pd | ||
from torchvision.datasets.utils import calculate_md5 | ||
|
||
# Generate a tiny dummy copy of the Digital Typhoon "WP" dataset for unit tests:
# per-typhoon HDF5 images, per-typhoon metadata CSVs, an aggregated aux_data.csv,
# and a gzipped tarball that is duplicated to mimic a multi-part download.

# Root directory of the dummy dataset (also used as the tarball base name)
root = 'WP'
IMAGE_SIZE = 32
NUM_TYPHOON_IDS = 5
# Backward-compatible alias for the original, misspelled constant name
NUM_TYHOON_IDS = NUM_TYPHOON_IDS
NUM_IMAGES_PER_ID = 5
CHUNK_SIZE = 2**12  # NOTE(review): not referenced anywhere in this script

# Seed the global NumPy RNG so that re-running the script reproduces the same
# fixtures (and therefore the same md5 sums printed at the end)
np.random.seed(0)

# Start from a clean slate: drop any previously generated dataset
if os.path.exists(root):
    shutil.rmtree(root)

# Create the 'image' and 'metadata' directories
os.makedirs(os.path.join(root, 'image'))
os.makedirs(os.path.join(root, 'metadata'))

all_dfs = []
for typhoon_id in range(NUM_TYPHOON_IDS):
    # Images of one typhoon live under 'root/image/<typhoon_id>/'
    image_dir = os.path.join(root, 'image', str(typhoon_id))
    os.makedirs(image_dir, exist_ok=True)

    # Create dummy .h5 files, each holding one random 'Infrared' array
    image_paths_per_typhoon = []
    for image_id in range(NUM_IMAGES_PER_ID):
        image_file_name = f'{image_id}.h5'
        with h5py.File(os.path.join(image_dir, image_file_name), 'w') as hf:
            hf.create_dataset('Infrared', data=np.random.rand(IMAGE_SIZE, IMAGE_SIZE))
        image_paths_per_typhoon.append(image_file_name)

    # Random start time; day is capped at 28 so every month is valid
    start_time = pd.Timestamp(
        year=np.random.randint(1978, 2022),
        month=np.random.randint(1, 13),
        day=np.random.randint(1, 29),
        hour=np.random.randint(0, 24),
    )
    # One image per hour; lowercase 'h' because the uppercase 'H' alias is
    # deprecated as of pandas 2.2
    times = pd.date_range(start=start_time, periods=NUM_IMAGES_PER_ID, freq='h')
    df = pd.DataFrame(
        {
            'id': np.repeat(typhoon_id, NUM_IMAGES_PER_ID),
            'image_path': image_paths_per_typhoon,
            'year': times.year,
            'month': times.month,
            'day': times.day,
            'hour': times.hour,
            'grade': np.random.randint(1, 5, NUM_IMAGES_PER_ID),
            'lat': np.random.uniform(-90, 90, NUM_IMAGES_PER_ID),
            'lng': np.random.uniform(-180, 180, NUM_IMAGES_PER_ID),
            'pressure': np.random.uniform(900, 1000, NUM_IMAGES_PER_ID),
            'wind': np.random.uniform(0, 100, NUM_IMAGES_PER_ID),
            'dir50': np.random.randint(0, 360, NUM_IMAGES_PER_ID),
            'long50': np.random.randint(0, 100, NUM_IMAGES_PER_ID),
            'short50': np.random.randint(0, 100, NUM_IMAGES_PER_ID),
            'dir30': np.random.randint(0, 360, NUM_IMAGES_PER_ID),
            'long30': np.random.randint(0, 100, NUM_IMAGES_PER_ID),
            'short30': np.random.randint(0, 100, NUM_IMAGES_PER_ID),
            'landfall': np.random.randint(0, 2, NUM_IMAGES_PER_ID),
            'intp': np.random.randint(0, 2, NUM_IMAGES_PER_ID),
            'file_1': [f'{idx}.h5' for idx in range(1, NUM_IMAGES_PER_ID + 1)],
            'mask_1': [
                'mask_' + str(i) for i in np.random.randint(1, 100, NUM_IMAGES_PER_ID)
            ],
            'mask_1_pct': np.random.uniform(0, 100, NUM_IMAGES_PER_ID),
        }
    )

    # Save the DataFrame as the metadata file of the corresponding typhoon id
    df.to_csv(os.path.join(root, 'metadata', f'{typhoon_id}.csv'), index=False)

    all_dfs.append(df)

# Save the aux_data.csv, the concatenation of all per-typhoon metadata
aux_data = pd.concat(all_dfs)
aux_data.to_csv(os.path.join(root, 'aux_data.csv'), index=False)

# Create tarball
shutil.make_archive(root, 'gztar', '.', root)

# Simulate multiple tar files: both "parts" are full copies of the tarball
path = f'{root}.tar.gz'
paths = [f'{path}aa', f'{path}ab']
for split_path in paths:
    shutil.copyfile(path, split_path)

# Calculate the md5sum of each simulated tar part
for split_path in paths:
    print(f'{split_path}: {calculate_md5(split_path)}')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
"""Test Digital Typhoon Datamodule.""" | ||
|
||
import os | ||
|
||
import pytest | ||
|
||
from torchgeo.datamodules import DigitalTyphoonDataModule | ||
from torchgeo.datasets.digital_typhoon import DigitalTyphoon, _SampleSequenceDict | ||
|
||
pytest.importorskip('h5py', minversion='3.6') | ||
|
||
|
||
class TestDigitalTyphoonDataModule:
    """Tests for the configuration checks and split logic of the datamodule."""

    @staticmethod
    def _max_seq_id_per_typhoon(
        split_sequences: list[_SampleSequenceDict],
    ) -> dict[str, int]:
        """Return the largest ``seq_id`` entry seen for each typhoon id.

        Args:
            split_sequences: sample sequences belonging to one split

        Returns:
            mapping from stringified typhoon id to its maximum ``seq_id`` value
        """
        max_values: dict[str, int] = {}
        for seq in split_sequences:
            typhoon_id: str = str(seq['id'])
            value: int = max(seq['seq_id'])
            if typhoon_id not in max_values or value > max_values[typhoon_id]:
                max_values[typhoon_id] = value
        return max_values

    def test_invalid_param_config(self) -> None:
        """An unknown ``split_by`` value must be rejected at construction."""
        with pytest.raises(AssertionError, match='Please choose from'):
            DigitalTyphoonDataModule(
                root=os.path.join('tests', 'data', 'digital_typhoon'),
                split_by='invalid',
                batch_size=2,
                num_workers=0,
            )

    @pytest.mark.parametrize('split_by', ['time', 'typhoon_id'])
    def test_split_dataset(self, split_by: str) -> None:
        """Check that train and validation splits respect the ``split_by`` mode.

        Args:
            split_by: splitting strategy passed to the datamodule
        """
        dm = DigitalTyphoonDataModule(
            root=os.path.join('tests', 'data', 'digital_typhoon'),
            split_by=split_by,
            batch_size=2,
            num_workers=0,
        )
        dataset = DigitalTyphoon(root=os.path.join('tests', 'data', 'digital_typhoon'))
        train_indices, val_indices = dm._split_dataset(dataset.sample_sequences)
        train_sequences = [dataset.sample_sequences[i] for i in train_indices]
        val_sequences = [dataset.sample_sequences[i] for i in val_indices]

        if split_by == 'time':
            train_max_values = self._max_seq_id_per_typhoon(train_sequences)
            val_max_values = self._max_seq_id_per_typhoon(val_sequences)
            # Assert that each max value in train_max_values is lower
            # than in val_max_values for each typhoon id present in both
            for typhoon_id, max_value in train_max_values.items():
                assert (
                    typhoon_id not in val_max_values
                    or max_value < val_max_values[typhoon_id]
                ), f'Max value for id {typhoon_id} in train is not lower than in validation.'
        else:
            train_ids = {seq['id'] for seq in train_sequences}
            val_ids = {seq['id'] for seq in val_sequences}

            # Assert that the intersection between train_ids and val_ids is empty
            assert not (
                train_ids & val_ids
            ), 'Train and validation datasets have overlapping ids.'
Oops, something went wrong.