From 73f1f4c3366b995c019a09bb7107f3bef6e291e8 Mon Sep 17 00:00:00 2001
From: Alexander Stante
Date: Sun, 6 Jan 2019 11:37:28 +0100
Subject: [PATCH 1/3] Add LearningRateMultiplier optimizer.

---
 keras_contrib/optimizers/__init__.py      |  1 +
 keras_contrib/optimizers/lr_multiplier.py | 65 +++++++++++++++++++
 .../optimizers/lr_multiplier_test.py      | 11 ++++
 3 files changed, 77 insertions(+)
 create mode 100644 keras_contrib/optimizers/lr_multiplier.py
 create mode 100644 tests/keras_contrib/optimizers/lr_multiplier_test.py

diff --git a/keras_contrib/optimizers/__init__.py b/keras_contrib/optimizers/__init__.py
index 055d73236..66bfe9c7d 100644
--- a/keras_contrib/optimizers/__init__.py
+++ b/keras_contrib/optimizers/__init__.py
@@ -1,6 +1,7 @@
 from .ftml import FTML
 from .padam import Padam
 from .yogi import Yogi
+from .lr_multiplier import LearningRateMultiplier
 
 # aliases
 ftml = FTML
diff --git a/keras_contrib/optimizers/lr_multiplier.py b/keras_contrib/optimizers/lr_multiplier.py
new file mode 100644
index 000000000..be67a4b2e
--- /dev/null
+++ b/keras_contrib/optimizers/lr_multiplier.py
@@ -0,0 +1,65 @@
+from keras.optimizers import Optimizer
+from keras.utils import get_custom_objects
+
+
+class LearningRateMultiplier(Optimizer):
+    """Optimizer wrapper for per-layer learning rates.
+
+    This wrapper is used to add per-layer learning rates by
+    providing per-layer factors which are multiplied with the
+    learning rate of the wrapped optimizer.
+
+    Note: This is a wrapper and does not implement any
+    optimization algorithm.
+
+    # Arguments
+        optimizer: An optimizer class to be wrapped.
+        lr_multipliers: Dictionary of per-layer factors. For
+            example `lr_multipliers={'conv_1/kernel': 0.5, 'conv_1/bias': 0.1}`.
+            If kernel and bias should use the same factor, the
+            user can specify `lr_multipliers={'conv_1': 0.5}`.
+        **kwargs: The arguments for instantiating the wrapped optimizer
+            class.
+    """
+    def __init__(self, optimizer, lr_multipliers=None, **kwargs):
+        self._class = optimizer
+        self._optimizer = optimizer(**kwargs)
+        self._lr_multipliers = lr_multipliers or {}
+
+    def _get_multiplier(self, param):
+        for k in self._lr_multipliers.keys():
+            if k in param.name:
+                return self._lr_multipliers[k]
+
+    def get_updates(self, loss, params):
+        mult_lr_params = {p: self._get_multiplier(p) for p in params if self._get_multiplier(p)}
+        base_lr_params = [p for p in params if self._get_multiplier(p) is None]
+
+        updates = []
+        base_lr = self._optimizer.lr
+        for param, multiplier in mult_lr_params.items():
+            self._optimizer.lr = base_lr * multiplier
+            updates.extend(self._optimizer.get_updates(loss, [param]))
+
+        self._optimizer.lr = base_lr
+        updates.extend(self._optimizer.get_updates(loss, base_lr_params))
+
+        return updates
+
+    def get_config(self):
+        config = {'optimizer': self._class,
+                  'lr_multipliers': self._lr_multipliers}
+        base_config = self._optimizer.get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    def __getattr__(self, name):
+        return getattr(self._optimizer, name)
+
+    def __setattr__(self, name, value):
+        if name.startswith('_'):
+            super().__setattr__(name, value)
+        else:
+            self._optimizer.__setattr__(name, value)
+
+
+get_custom_objects().update({'LearningRateMultiplier': LearningRateMultiplier})
diff --git a/tests/keras_contrib/optimizers/lr_multiplier_test.py b/tests/keras_contrib/optimizers/lr_multiplier_test.py
new file mode 100644
index 000000000..1355d7ca5
--- /dev/null
+++ b/tests/keras_contrib/optimizers/lr_multiplier_test.py
@@ -0,0 +1,11 @@
+from keras_contrib.tests import optimizers
+from keras_contrib.optimizers import LearningRateMultiplier
+from keras.optimizers import SGD, Adam
+from keras.callbacks import LearningRateScheduler
+
+def test_lr_multiplier():
+    mult={'dense':10}
+    optimizers._test_optimizer(LearningRateMultiplier(SGD, lr=0.01, momentum=0.9, nesterov=True), target=0.95)
+    optimizers._test_optimizer(LearningRateMultiplier(SGD, lr_multipliers=mult, lr=0.001,
+                               momentum=0.9, nesterov=True), target=0.95)
+

From cc6eaaf65a688d420c2691df827610822639cf25 Mon Sep 17 00:00:00 2001
From: Alexander Stante
Date: Mon, 7 Jan 2019 20:41:43 +0100
Subject: [PATCH 2/3] Fix pep8 violations in LearningRateMultiplier.
---
 keras_contrib/optimizers/lr_multiplier.py            |  3 ++-
 tests/keras_contrib/optimizers/lr_multiplier_test.py | 11 +++++++----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/keras_contrib/optimizers/lr_multiplier.py b/keras_contrib/optimizers/lr_multiplier.py
index be67a4b2e..38954bf7b 100644
--- a/keras_contrib/optimizers/lr_multiplier.py
+++ b/keras_contrib/optimizers/lr_multiplier.py
@@ -32,7 +32,8 @@ def _get_multiplier(self, param):
                 return self._lr_multipliers[k]
 
     def get_updates(self, loss, params):
-        mult_lr_params = {p: self._get_multiplier(p) for p in params if self._get_multiplier(p)}
+        mult_lr_params = {p: self._get_multiplier(p) for p in params
+                          if self._get_multiplier(p)}
         base_lr_params = [p for p in params if self._get_multiplier(p) is None]
 
         updates = []
diff --git a/tests/keras_contrib/optimizers/lr_multiplier_test.py b/tests/keras_contrib/optimizers/lr_multiplier_test.py
index 1355d7ca5..4d0a29bac 100644
--- a/tests/keras_contrib/optimizers/lr_multiplier_test.py
+++ b/tests/keras_contrib/optimizers/lr_multiplier_test.py
@@ -3,9 +3,12 @@
 from keras.optimizers import SGD, Adam
 from keras.callbacks import LearningRateScheduler
 
+
 def test_lr_multiplier():
-    mult={'dense':10}
-    optimizers._test_optimizer(LearningRateMultiplier(SGD, lr=0.01, momentum=0.9, nesterov=True), target=0.95)
-    optimizers._test_optimizer(LearningRateMultiplier(SGD, lr_multipliers=mult, lr=0.001,
-                               momentum=0.9, nesterov=True), target=0.95)
+    opt1 = LearningRateMultiplier(SGD, lr=0.01, momentum=0.9, nesterov=True)
+    optimizers._test_optimizer(opt1, target=0.95)
+    mult = {'dense': 10}
+    opt2 = LearningRateMultiplier(SGD, lr_multipliers=mult,
+                                  lr=0.001, momentum=0.9, nesterov=True)
+    optimizers._test_optimizer(opt2, target=0.95)
 

From 3d3cf9fbf4d58cf5c0bb42b407d62fc0047ff5d8 Mon Sep 17 00:00:00 2001
From: Alexander Stante
Date: Mon, 7 Jan 2019 21:22:08 +0100
Subject: [PATCH 3/3] Fix Python2.7 compatibility in LearningRateMultiplier.

---
 keras_contrib/optimizers/lr_multiplier.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras_contrib/optimizers/lr_multiplier.py b/keras_contrib/optimizers/lr_multiplier.py
index 38954bf7b..e5405b81e 100644
--- a/keras_contrib/optimizers/lr_multiplier.py
+++ b/keras_contrib/optimizers/lr_multiplier.py
@@ -58,7 +58,7 @@ def __getattr__(self, name):
 
     def __setattr__(self, name, value):
         if name.startswith('_'):
-            super().__setattr__(name, value)
+            super(LearningRateMultiplier, self).__setattr__(name, value)
        else:
             self._optimizer.__setattr__(name, value)
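
Below is a short usage sketch for reviewers; it is not part of the patch series. The toy model, the layer name 'dense_1', and the multiplier value 10.0 are assumptions made up for this example; only the LearningRateMultiplier API itself comes from the patches above.

# Usage sketch (illustrative only): wrap SGD so that the parameters of the
# layer named 'dense_1' train with 10x the base learning rate, while every
# other parameter keeps the base rate.
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras_contrib.optimizers import LearningRateMultiplier

# Toy model; the layer names are assumptions for this example.
model = Sequential([
    Dense(32, activation='relu', input_shape=(16,), name='dense_1'),
    Dense(1, activation='sigmoid', name='dense_2'),
])

# Multiplier keys are matched as substrings of the variable names, so
# 'dense_1' covers both 'dense_1/kernel' and 'dense_1/bias'.
opt = LearningRateMultiplier(SGD, lr_multipliers={'dense_1': 10.0},
                             lr=0.001, momentum=0.9, nesterov=True)
model.compile(optimizer=opt, loss='binary_crossentropy')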