Skip to content

Commit

Permalink
add fractions parameter to split function
Browse files Browse the repository at this point in the history
  • Loading branch information
namsaraeva committed May 22, 2024
1 parent bd76c36 commit 1ebcc44
Showing 1 changed file with 31 additions and 17 deletions.
48 changes: 31 additions & 17 deletions src/sparcscore/ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def combine_datasets_balanced(list_of_datasets, class_labels, train_per_class, v
return train_dataset, val_dataset, test_dataset


def split_dataset_regression(list_of_datasets, train_size, test_size, val_size, seed=None):
def split_dataset_regression(list_of_datasets, train_size, test_size, val_size, fractions=None, seed=None):
"""
Split a dataset into train, test, and validation set.
Expand All @@ -119,6 +119,8 @@ def split_dataset_regression(list_of_datasets, train_size, test_size, val_size,
Number of samples in the test set.
val_size : int
Number of samples in the validation set.
fractions : list of float
Fractions of the dataset to be used for train, test, and validation set. Should sum up to 1.
Returns
-------
Expand All @@ -132,26 +134,38 @@ def split_dataset_regression(list_of_datasets, train_size, test_size, val_size,
train_dataset = []
test_dataset = []
val_dataset = []

for dataset in list_of_datasets:
residual_size = len(dataset) - train_size - test_size - val_size
if fractions is not None:
if seed is not None:
gen = torch.Generator()
gen.manual_seed(seed)
train, test, val, _ = torch.utils.data.random_split(dataset, fractions, generator=gen)
else:
train, test, val, _ = torch.utils.data.random_split(dataset, fractions)

if residual_size < 0:
raise ValueError(
f"Dataset with length {len(dataset)} is too small to be split into test set of size {test_size}, "
f"train set of size {train_size}, and validation set of size {val_size}. "
)
train_dataset.append(train)
test_dataset.append(test)
val_dataset.append(val)

if fractions is None:
residual_size = len(dataset) - train_size - test_size - val_size
if residual_size < 0:
raise ValueError(
f"Dataset with length {len(dataset)} is too small to be split into test set of size {test_size}, "
f"train set of size {train_size}, and validation set of size {val_size}. "
)

if seed is not None:
gen = torch.Generator()
gen.manual_seed(seed)
train, test, val, _ = torch.utils.data.random_split(dataset, [train_size, test_size, val_size, residual_size], generator=gen)
else:
train, test, val, _ = torch.utils.data.random_split(dataset, [train_size, test_size, val_size, residual_size])
if seed is not None:
gen = torch.Generator()
gen.manual_seed(seed)
train, test, val, _ = torch.utils.data.random_split(dataset, [train_size, test_size, val_size, residual_size], generator=gen)
else:
train, test, val, _ = torch.utils.data.random_split(dataset, [train_size, test_size, val_size, residual_size])

train_dataset.append(train)
test_dataset.append(test)
val_dataset.append(val)
train_dataset.append(train)
test_dataset.append(test)
val_dataset.append(val)

train_dataset = torch.utils.data.ConcatDataset(train_dataset)
test_dataset = torch.utils.data.ConcatDataset(test_dataset)
Expand Down

0 comments on commit 1ebcc44

Please sign in to comment.