Add LearningRateMultiplier wrapper for optimizers #396
base: master
The first change adds the new import to the optimizers package's __init__.py:
@@ -1,6 +1,7 @@
from .ftml import FTML
from .padam import Padam
from .yogi import Yogi
from .lr_multiplier import LearningRateMultiplier

# aliases
ftml = FTML
The second change adds the wrapper itself in the new module lr_multiplier.py:
@@ -0,0 +1,66 @@
from keras.optimizers import Optimizer
from keras.utils import get_custom_objects


class LearningRateMultiplier(Optimizer):
    """Optimizer wrapper for per layer learning rate.

    This wrapper is used to add per layer learning rates by
    providing per layer factors which are multiplied with the
    learning rate of the optimizer.

    Note: This is a wrapper and does not implement any
    optimization algorithm.

    # Arguments
        optimizer: An optimizer class to be wrapped.
Review comment: I think optimizer should be an optimizer instance, not an optimizer class. Let's minimize the hackiness.
        lr_multipliers: Dictionary of the per layer factors. For
            example `optimizer={'conv_1/kernel':0.5, 'conv_1/bias':0.1}`.
Review comment: Typo: the keyword is lr_multipliers.
            If for kernel and bias the same learning rate is used, the
            user can specify `optimizer={'conv_1':0.5}`.
        **kwargs: The arguments for instantiating the wrapped optimizer
            class.
    """
    def __init__(self, optimizer, lr_multipliers=None, **kwargs):
        self._class = optimizer
Review comment: I don't think underscores are needed.
        self._optimizer = optimizer(**kwargs)
        self._lr_multipliers = lr_multipliers or {}
Review comment: You should call …
    def _get_multiplier(self, param):
        for k in self._lr_multipliers.keys():
            if k in param.name:
                return self._lr_multipliers[k]

    def get_updates(self, loss, params):
        mult_lr_params = {p: self._get_multiplier(p) for p in params
                          if self._get_multiplier(p)}
        base_lr_params = [p for p in params if self._get_multiplier(p) is None]

        updates = []
        base_lr = self._optimizer.lr
        for param, multiplier in mult_lr_params.items():
            self._optimizer.lr = base_lr * multiplier
            updates.extend(self._optimizer.get_updates(loss, [param]))

        self._optimizer.lr = base_lr
        updates.extend(self._optimizer.get_updates(loss, base_lr_params))

        return updates

    def get_config(self):
        config = {'optimizer': self._class,
Review comment: Since optimizer will be an instance of the class optimizer, you should use the function …
                  'lr_multipliers': self._lr_multipliers}
        base_config = self._optimizer.get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def __getattr__(self, name):
        return getattr(self._optimizer, name)

    def __setattr__(self, name, value):
        if name.startswith('_'):
            super(LearningRateMultiplier, self).__setattr__(name, value)
        else:
            self._optimizer.__setattr__(name, value)
Review comment: I don't think … You'll likely have to have a …
get_custom_objects().update({'LearningRateMultiplier': LearningRateMultiplier})
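Pulling the review threads together, here is a hedged sketch, not the code in this PR, of what an instance-based variant could look like: it calls the parent constructor, drops the leading underscores, and serializes the wrapped instance with keras.optimizers.serialize / keras.optimizers.deserialize. The from_config classmethod and the optimizers.get call are assumptions about how this might be wired, not something the PR contains.

from keras import optimizers
from keras.optimizers import Optimizer


class LearningRateMultiplier(Optimizer):
    """Sketch of an instance-based variant of the wrapper (not the PR's code)."""

    def __init__(self, optimizer, lr_multipliers=None, **kwargs):
        # Call the parent constructor so clipnorm/clipvalue keep working.
        super(LearningRateMultiplier, self).__init__(**kwargs)
        # Accept a configured instance (or anything keras.optimizers.get accepts),
        # e.g. SGD(lr=0.01, momentum=0.9), instead of the class plus its kwargs.
        self.optimizer = optimizers.get(optimizer)
        self.lr_multipliers = lr_multipliers or {}

    def _get_multiplier(self, param):
        # Keys are matched as substrings of the variable name, so 'conv_1'
        # covers both 'conv_1/kernel' and 'conv_1/bias'.
        for key, multiplier in self.lr_multipliers.items():
            if key in param.name:
                return multiplier
        return None

    def get_updates(self, loss, params):
        # Same strategy as the PR: temporarily rescale the wrapped optimizer's
        # lr for the parameters that have a multiplier, then restore it.
        mult_lr_params = {p: self._get_multiplier(p) for p in params
                          if self._get_multiplier(p) is not None}
        base_lr_params = [p for p in params if self._get_multiplier(p) is None]

        updates = []
        base_lr = self.optimizer.lr
        for param, multiplier in mult_lr_params.items():
            self.optimizer.lr = base_lr * multiplier
            updates.extend(self.optimizer.get_updates(loss, [param]))
        self.optimizer.lr = base_lr
        updates.extend(self.optimizer.get_updates(loss, base_lr_params))
        return updates

    def get_config(self):
        # Serialize the wrapped instance instead of storing the class object.
        config = {'optimizer': optimizers.serialize(self.optimizer),
                  'lr_multipliers': self.lr_multipliers}
        base_config = super(LearningRateMultiplier, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        optimizer = optimizers.deserialize(config.pop('optimizer'))
        return cls(optimizer=optimizer, **config)

With this interface the wrapper would be constructed as LearningRateMultiplier(SGD(lr=0.01, momentum=0.9), lr_multipliers={'conv_1': 0.5}), i.e. with a configured instance rather than the class plus keyword arguments.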
The third change adds tests for the wrapper:

@@ -0,0 +1,14 @@
from keras_contrib.tests import optimizers
from keras_contrib.optimizers import LearningRateMultiplier
from keras.optimizers import SGD, Adam
from keras.callbacks import LearningRateScheduler

Review comment: Unused import.


def test_lr_multiplier():
    opt1 = LearningRateMultiplier(SGD, lr=0.01, momentum=0.9, nesterov=True)
    optimizers._test_optimizer(opt1, target=0.95)

    mult = {'dense': 10}
    opt2 = LearningRateMultiplier(SGD, lr_multipliers=mult,
Review comment: Can you make a second function …
                                  lr=0.001, momentum=0.9, nesterov=True)
    optimizers._test_optimizer(opt2, target=0.95)
Review comment: We'll also need a third test … and a fourth test with a more complex optimizer (Adam would be a good fit).
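One possible split along the lines asked for above; the function names are hypothetical, the sketch keeps the class-plus-kwargs interface the PR currently uses, and the target values are copied from the PR and may need tuning.

from keras_contrib.tests import optimizers
from keras_contrib.optimizers import LearningRateMultiplier
from keras.optimizers import SGD, Adam


def test_lr_multiplier_no_multipliers():
    # Without multipliers the wrapper should behave like the plain optimizer.
    opt = LearningRateMultiplier(SGD, lr=0.01, momentum=0.9, nesterov=True)
    optimizers._test_optimizer(opt, target=0.95)


def test_lr_multiplier_sgd():
    # Per-layer factor applied to the dense layers of the test model.
    opt = LearningRateMultiplier(SGD, lr_multipliers={'dense': 10},
                                 lr=0.001, momentum=0.9, nesterov=True)
    optimizers._test_optimizer(opt, target=0.95)


def test_lr_multiplier_adam():
    # The "more complex optimizer" case requested in review.
    opt = LearningRateMultiplier(Adam, lr_multipliers={'dense': 10},
                                 lr=0.0001)
    optimizers._test_optimizer(opt, target=0.95)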
Review comment: What about two examples? {'conv_1/kernel':0.5, 'conv_1/bias':0.1} … layer.name as the key of the dictionary.
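To make the key convention concrete, a hedged usage sketch: the model, the layer names, and the factor values are made up, and it assumes the class-plus-kwargs interface from this PR.

from keras.layers import Conv2D, Dense, Flatten, Input
from keras.models import Model
from keras.optimizers import SGD
from keras_contrib.optimizers import LearningRateMultiplier

inputs = Input(shape=(8, 8, 1))
x = Conv2D(4, (3, 3), name='conv_1')(inputs)
x = Flatten()(x)
outputs = Dense(2, name='dense_out')(x)
model = Model(inputs, outputs)

# The dictionary keys are matched against the weights' names, which start
# with layer.name: 'conv_1/kernel:0', 'conv_1/bias:0', 'dense_out/kernel:0', ...
multipliers = {'conv_1/kernel': 0.5,   # kernel and bias scaled separately
               'conv_1/bias': 0.1,
               'dense_out': 2.0}       # one factor for both kernel and bias
opt = LearningRateMultiplier(SGD, lr_multipliers=multipliers,
                             lr=0.01, momentum=0.9)
model.compile(loss='mse', optimizer=opt)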