Skip to content

Commit

Permalink
Fix a bug with the definition of max_shufflable_size (#164)
Browse files Browse the repository at this point in the history
This commit fixes the issue by setting `max_shufflable_size` to the min of
the train and test sizes when there is no validation set, or the min of the
test and val sizes otherwise.
  • Loading branch information
JacksonBurns authored Nov 27, 2023
2 parents 2499f41 + 73fa378 commit 7df4ce1
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
2 changes: 1 addition & 1 deletion astartes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# convenience import to enable 'from astartes import train_test_split'
from .main import train_test_split, train_val_test_split

__version__ = "1.1.4"
__version__ = "1.1.5"

# DO NOT do this:
# from .molecules import train_test_split_molecules
Expand Down
11 changes: 9 additions & 2 deletions astartes/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,22 @@ def _extrapolative_sampling(
calls: return_helper
"""
# calculate "goal" splitting sizes
n_test_samples = floor(len(sampler_instance.X) * test_size)
n_train_samples = floor(len(sampler_instance.X) * train_size)
n_val_samples = floor(len(sampler_instance.X) * val_size)
n_test_samples = floor(len(sampler_instance.X) * test_size)

if val_size == 0:
max_shufflable_size = min(n_train_samples, n_test_samples)
else:
# typically, the test set and val set are smaller than the training set
max_shufflable_size = min(n_test_samples, n_val_samples)
# unlike interpolative, the final size of the training set cannot be known here
# since it will vary based on cluster_lengths; n_train_samples is only a goal

# largest clusters must go into largest set, but smaller ones can optionally
# be shuffled
cluster_counter = sampler_instance.get_sorted_cluster_counter(
max_shufflable_size=min(n_test_samples, n_val_samples)
max_shufflable_size=max_shufflable_size
if random_state is not None
else None
)
Expand Down
34 changes: 34 additions & 0 deletions test/functional/test_astartes.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,40 @@ def test_return_indices_with_validation(self):
),
)

def test_extrapolative_shuffling(self):
    """Changing random_state must change the split produced by an extrapolative sampler.

    Two kmeans-based splits are requested with identical settings except for
    the seed; every pair of corresponding returned arrays is required to differ.
    """
    # Settings shared by both calls; only random_state varies between them.
    # NOTE(review): test_size > train_size looks deliberate (train is the
    # smaller set here) -- confirm against the max_shufflable_size logic.
    shared_kwargs = dict(
        labels=self.labels,
        test_size=0.7,
        train_size=0.3,
        sampler="kmeans",
        return_indices=True,
        hopts=dict(
            n_clusters=6,
        ),
    )
    split_seed_42 = train_test_split(
        self.X,
        self.y,
        random_state=42,
        **shared_kwargs,
    )
    split_seed_41 = train_test_split(
        self.X,
        self.y,
        random_state=41,
        **shared_kwargs,
    )
    # Compare the returned arrays position by position.
    for from_seed_42, from_seed_41 in zip(split_seed_42, split_seed_41):
        arrays_match = np.array_equal(from_seed_42, from_seed_41)
        self.assertFalse(
            arrays_match,
            "random_state did not result in different splits",
        )


if __name__ == "__main__":
unittest.main()

0 comments on commit 7df4ce1

Please sign in to comment.