Skip to content

Commit

Permalink
Fixing ExponentialBS rounding errors for small gamma (#13)
Browse files Browse the repository at this point in the history
* Previous behavior: when the exponential factor `gamma` used by `ExponentialBS` is close to 1, the batch size would remain constant.
* Fixed this by maintaining a float value for the batch size as an internal state.
* Updated tests.
* Added new plots.
* Bumping version.
  • Loading branch information
ancestor-mithril authored Jun 5, 2024
1 parent 2bed4ae commit 11be404
Show file tree
Hide file tree
Showing 29 changed files with 14 additions and 8 deletions.
8 changes: 7 additions & 1 deletion bs_scheduler/batch_size_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ def __init__(self, dataloader: DataLoader, gamma: float, batch_size_manager: Uni
assert gamma > 0.0
# Gamma is expected to be greater than 1.0 for batch size growth. It can be lower than 1.0 for batch size decay.
self.gamma: float = gamma
self.float_bs: Union[float, None] = None
super().__init__(dataloader, batch_size_manager, max_batch_size, min_batch_size, verbose)

def get_new_bs(self) -> int:
Expand All @@ -674,7 +675,12 @@ def get_new_bs(self) -> int:
if self.last_epoch == 0:
return self.batch_size

return rint(self.batch_size * self.gamma)
if self.float_bs is None or rint(self.float_bs) != self.batch_size:
# Using rint instead of int because otherwise we will increase the BS faster
self.float_bs = self.batch_size

self.float_bs *= self.gamma
return rint(self.float_bs)


class SequentialBS(BSScheduler):
Expand Down
Binary file modified docs/img/plots/ChainedBSScheduler.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/ChainedScheduler.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/ConstantBS.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/ConstantLR.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/CosineAnnealingBS.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/CosineAnnealingBSWithWarmRestarts.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/CosineAnnealingLR.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/CosineAnnealingWarmRestarts.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/CyclicBS-exp_range.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/CyclicBS-triangular2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/CyclicLR-exp_range.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/CyclicLR-triangular2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/ExponentialBS.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/ExponentialLR.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/IncreaseBSOnPlateau.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/LinearBS.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/LinearLR.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/MultiStepBS.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/MultiStepLR.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/OneCycleBS.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/OneCycleLR.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/ReduceLROnPlateau.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/StepBS.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/plots/StepLR.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "bs_scheduler"
version = "0.4.1"
version = "0.4.2"
requires-python = ">=3.9"
description = "A PyTorch Dataloader compatible batch size scheduler library."
readme = "README.md"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ChainedBSScheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_dataloader_lengths(self):
n_epochs = 10

epoch_lengths = simulate_n_epochs(dataloader, scheduler, n_epochs)
expected_batch_sizes = [100, 110, 121, 133, 14, 15, 16, 18, 20, 22]
expected_batch_sizes = [100, 110, 121, 133, 14, 16, 17, 19, 21, 23]
expected_lengths = self.compute_epoch_lengths(expected_batch_sizes, len(self.dataset), drop_last=False)
self.assertEqual(epoch_lengths, expected_lengths)

Expand All @@ -37,7 +37,7 @@ def test_dataloader_batch_size(self):
n_epochs = 10

batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs)
expected_batch_sizes = [100, 110, 121, 133, 14, 15, 16, 18, 20, 22]
expected_batch_sizes = [100, 110, 121, 133, 14, 16, 17, 19, 21, 23]

self.assertEqual(batch_sizes, expected_batch_sizes)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_ExponentialBS.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def test_graphic(self):

base_batch_size = 10
dataloader = create_dataloader(self.dataset, batch_size=base_batch_size)
scheduler = ExponentialBS(dataloader, gamma=2, max_batch_size=100, verbose=False)
n_epochs = 10
scheduler = ExponentialBS(dataloader, gamma=1.05, max_batch_size=500, verbose=False)
n_epochs = 50

batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs)
plt.plot(batch_sizes)
Expand All @@ -64,7 +64,7 @@ def test_graphic(self):

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
learning_rates = []

def get_lr(optimizer):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_SequentialBS.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_dataloader_lengths(self):
n_epochs = 10

epoch_lengths = simulate_n_epochs(dataloader, scheduler, n_epochs)
expected_batch_sizes = [100] * 4 + [10, 11, 12, 13, 14, 15]
expected_batch_sizes = [100] * 4 + [10, 11, 12, 13, 15, 16]
expected_lengths = self.compute_epoch_lengths(expected_batch_sizes, len(self.dataset), drop_last=False)
self.assertEqual(epoch_lengths, expected_lengths)

Expand Down

0 comments on commit 11be404

Please sign in to comment.