From 67f41d909eece796354162619620bbf987e09dcd Mon Sep 17 00:00:00 2001 From: OrigamiDream Date: Tue, 30 Nov 2021 08:09:47 +0900 Subject: [PATCH 1/3] Add support for negative subscripting for TimeseriesGenerator --- keras_preprocessing/sequence.py | 2 ++ tests/sequence_test.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/keras_preprocessing/sequence.py b/keras_preprocessing/sequence.py index 74660cef..85adbfa1 100644 --- a/keras_preprocessing/sequence.py +++ b/keras_preprocessing/sequence.py @@ -356,6 +356,8 @@ def __len__(self): self.batch_size * self.stride) // (self.batch_size * self.stride) def __getitem__(self, index): + if index < 0: + index = len(self) + index if self.shuffle: rows = np.random.randint( self.start_index, self.end_index + 1, size=self.batch_size) diff --git a/tests/sequence_test.py b/tests/sequence_test.py index 246ca664..95da2931 100644 --- a/tests/sequence_test.py +++ b/tests/sequence_test.py @@ -140,6 +140,21 @@ def test_TimeseriesGenerator_serde(): assert (data_gen.targets == recovered_gen.targets).all() +def test_TimeseriesGenerator_negative_subscript(): + data = np.array([[i] for i in range(50)]) + targets = np.array([[i] for i in range(50)]) + + data_gen = sequence.TimeseriesGenerator(data, targets, + length=10, + sampling_rate=2, + batch_size=2) + assert len(data_gen) == 20 + assert (np.allclose(data_gen[19][0], data_gen[-1][0])) + assert (np.allclose(data_gen[19][1], data_gen[-1][1])) + assert (np.allclose(data_gen[18][0], data_gen[-2][0])) + assert (np.allclose(data_gen[18][1], data_gen[-2][1])) + + def test_TimeseriesGenerator(): data = np.array([[i] for i in range(50)]) targets = np.array([[i] for i in range(50)]) From bce7a3c6af231a21018599ef7c01b0cc048d8e9e Mon Sep 17 00:00:00 2001 From: OrigamiDream Date: Tue, 30 Nov 2021 08:13:18 +0900 Subject: [PATCH 2/3] Add test cases --- tests/sequence_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/sequence_test.py b/tests/sequence_test.py index 95da2931..0df51389 100644 --- a/tests/sequence_test.py +++ b/tests/sequence_test.py @@ -154,6 +154,11 @@ def test_TimeseriesGenerator_negative_subscript(): assert (np.allclose(data_gen[18][0], data_gen[-2][0])) assert (np.allclose(data_gen[18][1], data_gen[-2][1])) + size = len(data_gen) + for i in range(1, size + 1): + assert (np.allclose(data_gen[size - i][0], data_gen[-i][0])) + assert (np.allclose(data_gen[size - i][1], data_gen[-i][1])) + def test_TimeseriesGenerator(): data = np.array([[i] for i in range(50)]) From d23ef06c0d89e80a696b0eaf75affc3eae49e331 Mon Sep 17 00:00:00 2001 From: OrigamiDream Date: Tue, 30 Nov 2021 08:43:34 +0900 Subject: [PATCH 3/3] Add negative subscripting support for Iterator --- keras_preprocessing/image/iterator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras_preprocessing/image/iterator.py b/keras_preprocessing/image/iterator.py index c62b1d3a..818fef17 100644 --- a/keras_preprocessing/image/iterator.py +++ b/keras_preprocessing/image/iterator.py @@ -51,6 +51,8 @@ def __getitem__(self, idx): 'but the Sequence ' 'has length {length}'.format(idx=idx, length=len(self))) + if idx < 0: + idx = len(self) + idx if self.seed is not None: np.random.seed(self.seed + self.total_batches_seen) self.total_batches_seen += 1