TF 2.2 compatibility, usage changes, code cleanup

**FEATURES**:
 - Compatibility with TF 2.2 (other versions still compatible, but no longer tested)
 - `eta_t` now behaves deterministically, updating after `t_cur` (previously, behavior was semi-random)
 - Lots of code cleanup

**USAGE NOTES**:
 - `t_cur` should now be set to `-1` instead of `0` to restart `eta_t`
 - `t_cur` should now be set at `iters == total_iterations - 2`; explanation [here](https://github.com/OverLordGoldDragon/keras-adamw/blob/master/tests/test_optimizers.py#L53), with a minimal sketch after this list
 - `total_iterations` must now be `> 1`, instead of only `> 0`
 - `total_iterations <= 1` will force `weight_decays` and `lr_multipliers` to `None`
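
A minimal sketch of the new restart timing, assuming `model` is the README's LSTM example compiled with `AdamW(..., use_cosine_annealing=True, total_iterations=24)`; the dummy shapes below follow that example:

```python
import numpy as np
from keras import backend as K

total_iterations = 24  # assumed to match the optimizer's total_iterations
for epoch in range(3):
    for iteration in range(total_iterations):
        x = np.random.rand(10, 120, 4)        # dummy features
        y = np.random.randint(0, 2, (10, 1))  # dummy labels
        model.train_on_batch(x, y)
        if iteration == total_iterations - 2:
            # reset to -1 (not 0), two iterations before the boundary,
            # so eta_t restarts exactly when the new cycle begins
            K.set_value(model.optimizer.t_cur, -1)
```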

**FIXES**:
 - Optimizers will no longer zero layer penalties if weight decays cannot be applied (i.e. `total_iterations` is not `> 1`)
 - `eta_t` is now properly updated as a `tf.Variable`, instead of being an update `tf.Tensor`
 - Tests in the previous release did not actually run in Eager execution; they now do

**BREAKING**:
 - `utils_225tf.py` removed
 - `utils_common.py` removed
 - `optimizers_tfpy.py` removed
 - `utils.py` code is now that of `utils_225tf.py`
 - `utils_common.py` merged into `utils.py`; update imports accordingly (see the sketch after this list)
 - `self.batch_size` is now an `int`, instead of `tf.Variable`
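
In downstream code, imports that previously targeted the removed modules now come from `utils.py` (mirroring the `example.py` and `__init__.py` hunks below):

```python
# before (<= 1.23)
# from keras_adamw.utils_common import K_eval, get_weight_decays

# after (1.3)
from keras_adamw.utils import K_eval, get_weight_decays
```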

**MISC**:
 - `tests/`: `test_optimizers`, `test_optimizers_225`, `test_optimizers_225tf`, `test_optimizers_v2`, `test_optimizers_tfpy` removed
 - All tests now done in single file: `tests/test_optimizers.py`
 - `_update_t_cur_eta_t` and `_update_t_cur_eta_t_apply_lr_mult` added to `utils.py` (a rough sketch of the former follows this list)
 - Updated `examples.py` and related parts in README
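
As a rough illustration of the new flow (the real `_update_t_cur_eta_t` lives in `utils.py`; the body here sketches the SGDR cosine-annealing update and is not the library's exact code):

```python
import numpy as np
from keras import backend as K

def _update_t_cur_eta_t_sketch(optimizer):
    """Illustrative only: advance t_cur, then recompute eta_t as a
    Variable update rather than a standalone update Tensor."""
    optimizer.updates.append(K.update_add(optimizer.t_cur, 1))
    if optimizer.use_cosine_annealing:
        # eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t_cur / T))
        t_frac = K.cast(optimizer.t_cur, 'float32') / optimizer.total_iterations
        eta_t = optimizer.eta_min + 0.5 * (optimizer.eta_max - optimizer.eta_min) \
            * (1 + K.cos(np.pi * t_frac))
        optimizer.updates.append(K.update(optimizer.eta_t, eta_t))
```
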
OverLordGoldDragon authored May 25, 2020
1 parent c5f1e8f commit d55194b
Showing 28 changed files with 785 additions and 2,515 deletions.
9 changes: 4 additions & 5 deletions .travis.yml
@@ -5,16 +5,15 @@ python:
- "3.6"
env:
global:
- TF_PYTHON="0"
- TF_EAGER="0"
- TF_KERAS="0"
matrix:
- TF_VERSION="1.14.0" KERAS_VERSION="2.2.5"
- TF_VERSION="1.14.0" KERAS_VERSION="2.2.5" TF_KERAS="1"
- TF_VERSION="2.1.0" KERAS_VERSION="2.3.0" TF_EAGER="1"
- TF_VERSION="2.1.0" KERAS_VERSION="2.3.0"
- TF_VERSION="2.1.0" KERAS_VERSION="2.3.0" TF_KERAS="1" TF_EAGER="1"
- TF_VERSION="2.1.0" KERAS_VERSION="2.3.0" TF_KERAS="1"
- TF_VERSION="2.2.0" KERAS_VERSION="2.3.0" TF_EAGER="1"
- TF_VERSION="2.2.0" KERAS_VERSION="2.3.0"
- TF_VERSION="2.2.0" KERAS_VERSION="2.3.0" TF_KERAS="1" TF_EAGER="1"
- TF_VERSION="2.2.0" KERAS_VERSION="2.3.0" TF_KERAS="1"

notifications:
email: false
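
The matrix above toggles backend and execution mode per build. A hedged sketch of how flags like these are commonly consumed; the exact handling in keras-adamw's package and tests is not shown in this diff:

```python
import os
import tensorflow as tf

# illustrative only: read CI flags with safe defaults
TF_KERAS = os.environ.get("TF_KERAS", "0") == "1"
TF_EAGER = os.environ.get("TF_EAGER", "0") == "1"

if TF_KERAS:
    import tensorflow.keras.backend as K
else:
    import keras.backend as K

if not TF_EAGER:
    tf.compat.v1.disable_eager_execution()  # run in graph mode
```
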
8 changes: 5 additions & 3 deletions README.md
@@ -81,7 +81,7 @@ from keras_adamw import AdamW
ipt = Input(shape=(120, 4))
x = LSTM(60, activation='relu', name='lstm_1',
kernel_regularizer=l1(1e-4), recurrent_regularizer=l2(2e-4))(ipt)
out = Dense(1, activation='sigmoid', kernel_regularizer=l1_l2(1e-4))(x)
out = Dense(1, activation='sigmoid', kernel_regularizer=l1_l2(1e-4, 2e-4))(x)
model = Model(ipt, out)
```
```python
@@ -98,8 +98,9 @@ for epoch in range(3):
y = np.random.randint(0, 2, (10, 1)) # dummy labels
loss = model.train_on_batch(x, y)
print("Iter {} loss: {}".format(iteration + 1, "%.3f" % loss))
if iteration == (24 - 2):
K.set_value(model.optimizer.t_cur, -1) # WARM RESTART: reset cosine annealing argument
print("EPOCH {} COMPLETED\n".format(epoch + 1))
K.set_value(model.optimizer.t_cur, 0) # WARM RESTART: reset cosine annealing argument
```
<img src="https://user-images.githubusercontent.com/16495490/65729113-2063d400-e08b-11e9-8b6a-3a2ea1c62fdd.png" width="450">

@@ -112,7 +113,8 @@ for epoch in range(3):
- `total_iterations_wd` --> set to normalize over _all epochs_ (or other interval `!= total_iterations`) instead of per-WR when using WR; may _sometimes_ yield better results --_My note_

### Warm restarts
- Set `t_cur = 0` to restart schedule multiplier (see _Example_). Can be done at compilation or during training. Non-`0` is also valid, and will start `eta_t` at another point on the cosine curve. Details in A-2,3
- Set `t_cur = -1` to restart schedule multiplier (see _Example_). Can be done at compilation or during training. Non-`-1` is also valid, and will start `eta_t` at another point on the cosine curve. Details in A-2,3
- `t_cur` should be set at `iter == total_iterations - 2`; explanation [here](https://github.com/OverLordGoldDragon/keras-adamw/blob/master/tests/test_optimizers.py#L53)
- Set `total_iterations` to the # of expected weight updates _for the given restart_ --_Authors_ (A-1,2)
- `eta_min=0, eta_max=1` are tunable hyperparameters; e.g., an exponential schedule can be used for `eta_max`. If unsure, the defaults were shown to work well in the paper. --_Authors_
- **[Save/load](https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model) optimizer state**; WR relies on using the optimizer's update history for effective transitions --_Authors_ (A-2)
7 changes: 4 additions & 3 deletions example.py
@@ -6,13 +6,13 @@
import matplotlib.pyplot as plt

from keras_adamw import AdamW
from keras_adamw.utils_common import K_eval
from keras_adamw.utils import K_eval


ipt = Input(shape=(120,4))
x = LSTM(60, activation='relu', name='lstm_1',
kernel_regularizer=l1(1e-4), recurrent_regularizer=l2(2e-4))(ipt)
out = Dense(1, activation='sigmoid', kernel_regularizer=l1_l2(1e-4))(x)
out = Dense(1, activation='sigmoid', kernel_regularizer=l1_l2(1e-4, 2e-4))(x)
model = Model(ipt, out)

lr_multipliers = {'lstm_1': 0.5}
@@ -29,8 +29,9 @@
loss = model.train_on_batch(x, y)
eta_history.append(K_eval(model.optimizer.eta_t, K))
print("Iter {} loss: {}".format(iteration + 1, "%.3f" % loss))
if iteration == (24 - 2):
K.set_value(model.optimizer.t_cur, -1) # WARM RESTART
print("EPOCH {} COMPLETED\n".format(epoch + 1))
K.set_value(model.optimizer.t_cur, 0) # WARM RESTART

plt.plot(eta_history, linewidth=2)
plt.xlim(0, len(eta_history))
4 changes: 2 additions & 2 deletions keras_adamw/README.md
@@ -3,8 +3,8 @@
```python
TensorFlow 1.14.0 + Keras 2.2.5 + 'keras' >> optimizers_225.py + utils.py
TensorFlow 1.14.0 + Keras 2.2.5 + 'tf.keras' >> optimizers_225tf.py + utils_225tf.py
TensorFlow 2.0.0 + Keras 2.3.0 + 'keras' >> optimizers.py + utils.py
TensorFlow 2.0.0 + Keras 2.3.0 + 'tf.keras' >> optimizers_v2.py + utils.py
TensorFlow 2+ + Keras 2.3.0 + 'keras' >> optimizers.py + utils.py
TensorFlow 2+ + Keras 2.3.0 + 'tf.keras' >> optimizers_v2.py + utils.py
```

[__init__.py](https://github.com/OverLordGoldDragon/keras-adamw/blob/master/keras_adamw/__init__.py) takes care of making the correct selection, but
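
The `__init__.py` hunk below shows only the tail of the selection; a hedged sketch of what the full version/backend dispatch typically looks like (the conditions here are assumptions, not copied from the file):

```python
import os
import tensorflow as tf

TF_KERAS = os.environ.get("TF_KERAS", "0") == "1"
TF_2 = tf.__version__.startswith("2")

# illustrative selection only; see keras_adamw/__init__.py for the real logic
if TF_2:
    if TF_KERAS:
        from .optimizers_v2 import AdamW, NadamW, SGDW
    else:
        from .optimizers import AdamW, NadamW, SGDW
else:
    if TF_KERAS:
        from .optimizers_225tf import AdamW, NadamW, SGDW
    else:
        from .optimizers_225 import AdamW, NadamW, SGDW
```
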
6 changes: 3 additions & 3 deletions keras_adamw/__init__.py
@@ -25,8 +25,8 @@
else:
from .optimizers_225 import AdamW, NadamW, SGDW

from .utils_common import get_weight_decays, fill_dict_in_order
from .utils_common import reset_seeds, K_eval
from .utils import get_weight_decays, fill_dict_in_order
from .utils import reset_seeds, K_eval


__version__ = '1.23'
__version__ = '1.3'
84 changes: 39 additions & 45 deletions keras_adamw/optimizers.py
@@ -2,10 +2,9 @@
from keras import backend as K
from keras.legacy import interfaces
from keras.optimizers import Optimizer
from .utils_common import _init_weight_decays, _check_args
from .utils_common import K_eval as KE
from .utils import _apply_weight_decays, _compute_eta_t
from .utils import _apply_lr_multiplier
from .utils import _init_weight_decays, _apply_weight_decays, _check_args
from .utils import _apply_lr_multiplier, _update_t_cur_eta_t
from .utils import K_eval as KE


def K_eval(x):
@@ -72,7 +71,9 @@ def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
use_cosine_annealing=False, lr_multipliers=None,
weight_decays=None, init_verbose=True,
eta_min=0, eta_max=1, t_cur=0, **kwargs):
weight_decays = _init_weight_decays(model, zero_penalties, weight_decays)
if total_iterations > 1:
weight_decays = _init_weight_decays(model, zero_penalties,
weight_decays)

self.initial_decay = kwargs.pop('decay', 0.0)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
@@ -86,13 +87,12 @@ def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(self.initial_decay, name='decay')
self.batch_size = K.variable(batch_size, dtype='int64',
name='batch_size')
self.eta_min = K.constant(eta_min, name='eta_min')
self.eta_max = K.constant(eta_max, name='eta_max')
self.eta_t = K.variable(eta_t, dtype='float32', name='eta_t')
self.t_cur = K.variable(t_cur, dtype='int64', name='t_cur')

self.batch_size = batch_size
self.total_iterations = total_iterations
self.total_iterations_wd = total_iterations_wd or total_iterations
self.amsgrad = amsgrad
@@ -101,15 +101,15 @@ def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
self.init_verbose = init_verbose
self.use_cosine_annealing = use_cosine_annealing

_check_args(self, total_iterations, use_cosine_annealing, weight_decays)
self._init_lr = learning_rate # to print lr_mult setup
self._init_notified = False
_check_args(total_iterations, use_cosine_annealing, self.weight_decays)

@interfaces.legacy_get_updates_support
@K.symbolic
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]
self.updates.append(K.update_add(self.t_cur, 1))

lr = self.learning_rate
if self.initial_decay > 0:
@@ -139,12 +139,6 @@ def get_updates(self, loss, params):
for i in range(len(params))]
self.weights = [self.iterations] + ms + vs + vhats

total_iterations = self.total_iterations
# Cosine annealing
if self.use_cosine_annealing and total_iterations != 0:
self.eta_t = _compute_eta_t(self)
self.lr_t = lr_t * self.eta_t # for external tracking

for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
# Learning rate multipliers
if self.lr_multipliers is not None:
@@ -163,7 +157,7 @@ def get_updates(self, loss, params):
self.updates.append(K.update(v, v_t))

# Weight decays
if p.name in self.weight_decays.keys() and total_iterations != 0:
if p.name in self.weight_decays.keys():
p_t = _apply_weight_decays(self, p, p_t)
new_p = p_t

@@ -173,6 +167,10 @@ def get_updates(self, loss, params):

self.updates.append(K.update(p, new_p))

# Cosine annealing
_update_t_cur_eta_t(self)
self.lr_t = lr_t * self.eta_t # for external tracking

self._init_notified = True
return self.updates

@@ -182,7 +180,7 @@ def get_config(self):
'beta_1': float(K_eval(self.beta_1)),
'beta_2': float(K_eval(self.beta_2)),
'decay': float(K_eval(self.decay)),
'batch_size': int(K_eval(self.batch_size)),
'batch_size': int(self.batch_size),
'total_iterations': int(self.total_iterations),
'weight_decays': self.weight_decays,
'lr_multipliers': self.lr_multipliers,
@@ -261,7 +259,9 @@ def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999,
use_cosine_annealing=False, lr_multipliers=None,
weight_decays=None, init_verbose=True,
eta_min=0, eta_max=1, t_cur=0, **kwargs):
weight_decays = _init_weight_decays(model, zero_penalties, weight_decays)
if total_iterations > 1:
weight_decays = _init_weight_decays(model, zero_penalties,
weight_decays)

self.schedule_decay = kwargs.pop('schedule_decay', 0.004)
self.epsilon = kwargs.pop('epsilon', K.epsilon())
@@ -275,29 +275,28 @@ def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999,
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.batch_size = K.variable(batch_size, dtype='int64',
name='batch_size')
self.eta_min = K.constant(eta_min, name='eta_min')
self.eta_max = K.constant(eta_max, name='eta_max')
self.eta_t = K.variable(eta_t, dtype='float32', name='eta_t')
self.t_cur = K.variable(t_cur, dtype='int64', name='t_cur')

self.batch_size = batch_size
self.total_iterations = total_iterations
self.total_iterations_wd = total_iterations_wd or total_iterations
self.lr_multipliers = lr_multipliers
self.weight_decays = weight_decays or {}
self.use_cosine_annealing = use_cosine_annealing
self.init_verbose = init_verbose

_check_args(self, total_iterations, use_cosine_annealing, weight_decays)
self._init_lr = learning_rate # to print lr_mult setup
self._init_notified = False
_check_args(total_iterations, use_cosine_annealing, self.weight_decays)

@interfaces.legacy_get_updates_support
@K.symbolic
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]
self.updates.append(K.update_add(self.t_cur, 1))

t = K.cast(self.iterations, K.floatx()) + 1

@@ -318,12 +317,6 @@ def get_updates(self, loss, params):

self.weights = [self.iterations, self.m_schedule] + ms + vs

total_iterations = self.total_iterations
# Cosine annealing
if self.use_cosine_annealing and total_iterations != 0:
self.eta_t = _compute_eta_t(self)
self.lr_t = self.learning_rate * self.eta_t # for external tracking

for p, g, m, v in zip(params, grads, ms, vs):
# Learning rate multipliers
lr_t = self.learning_rate
Expand All @@ -345,16 +338,19 @@ def get_updates(self, loss, params):
K.sqrt(v_t_prime) + self.epsilon)

# Weight decays
if p.name in self.weight_decays.keys() and total_iterations != 0:
if p.name in self.weight_decays.keys():
p_t = _apply_weight_decays(self, p, p_t)
new_p = p_t

# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)

self.updates.append(K.update(p, new_p))

# Cosine annealing
_update_t_cur_eta_t(self)
self.lr_t = lr_t * self.eta_t # for external tracking

self._init_notified = True
return self.updates

Expand All @@ -374,7 +370,7 @@ def get_config(self):
'beta_2': float(K_eval(self.beta_2)),
'epsilon': self.epsilon,
'schedule_decay': self.schedule_decay,
'batch_size': int(K_eval(self.batch_size)),
'batch_size': int(self.batch_size),
'total_iterations': int(self.total_iterations),
'weight_decays': self.weight_decays,
'lr_multipliers': self.lr_multipliers,
@@ -448,7 +444,9 @@ def __init__(self, learning_rate=0.01, momentum=0., nesterov=False,
use_cosine_annealing=False, lr_multipliers=None,
weight_decays=None, init_verbose=True,
eta_min=0, eta_max=1, t_cur=0, **kwargs):
weight_decays = _init_weight_decays(model, zero_penalties, weight_decays)
if total_iterations > 1:
weight_decays = _init_weight_decays(model, zero_penalties,
weight_decays)

self.initial_decay = kwargs.pop('decay', 0.0)
learning_rate = kwargs.pop('lr', learning_rate)
@@ -460,13 +458,12 @@ def __init__(self, learning_rate=0.01, momentum=0., nesterov=False,
self.learning_rate = K.variable(learning_rate, name='learning_rate')
self.momentum = K.variable(momentum, name='momentum')
self.decay = K.variable(self.initial_decay, name='decay')
self.batch_size = K.variable(batch_size, dtype='int64',
name='batch_size')
self.eta_min = K.constant(eta_min, name='eta_min')
self.eta_max = K.constant(eta_max, name='eta_max')
self.eta_t = K.variable(eta_t, dtype='float32', name='eta_t')
self.t_cur = K.variable(t_cur, dtype='int64', name='t_cur')

self.batch_size = batch_size
self.total_iterations = total_iterations
self.total_iterations_wd = total_iterations_wd or total_iterations
self.nesterov = nesterov
Expand All @@ -475,15 +472,15 @@ def __init__(self, learning_rate=0.01, momentum=0., nesterov=False,
self.init_verbose = init_verbose
self.use_cosine_annealing = use_cosine_annealing

_check_args(self, total_iterations, use_cosine_annealing, weight_decays)
self._init_lr = learning_rate # to print lr_mult setup
self._init_notified = False
_check_args(total_iterations, use_cosine_annealing, self.weight_decays)

@interfaces.legacy_get_updates_support
@K.symbolic
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]
self.updates.append(K.update_add(self.t_cur, 1))

lr = self.learning_rate
if self.initial_decay > 0:
@@ -495,12 +492,6 @@ def get_updates(self, loss, params):
for (i, shape) in enumerate(shapes)]
self.weights = [self.iterations] + moments

total_iterations = self.total_iterations
# Cosine annealing
if self.use_cosine_annealing and total_iterations != 0:
self.eta_t = _compute_eta_t(self)
self.lr_t = lr * self.eta_t # for external tracking

for p, g, m in zip(params, grads, moments):
# Learning rate multipliers
lr_t = self.learning_rate
@@ -516,16 +507,19 @@ def get_updates(self, loss, params):
p_t = p + v

# Weight decays
if p.name in self.weight_decays.keys() and total_iterations != 0:
if p.name in self.weight_decays.keys():
p_t = _apply_weight_decays(self, p, p_t)
new_p = p_t

# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)

self.updates.append(K.update(p, new_p))

# Cosine annealing
_update_t_cur_eta_t(self)
self.lr_t = lr_t * self.eta_t # for external tracking

self._init_notified = True
return self.updates

@@ -535,7 +529,7 @@ def get_config(self):
'momentum': float(K_eval(self.momentum)),
'decay': float(K_eval(self.decay)),
'nesterov': self.nesterov,
'batch_size': int(K_eval(self.batch_size)),
'batch_size': int(self.batch_size),
'total_iterations': int(self.total_iterations),
'weight_decays': self.weight_decays,
'lr_multipliers': self.lr_multipliers,
(diff for the remaining changed files not shown)
