Fix for multi-device loss bug in current init2winit.
init2winit loss computations previously followed the pattern of computing `loss_numerator / loss_denominator` per device, followed by `loss = lax.pmean(per_device_loss)`.

This introduces bugs in many cases, for example when the amount of padding differs across devices, or when devices hold unequal numbers of examples.
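
As a toy illustration (hypothetical numbers, not taken from this change): suppose device 0 holds 4 valid examples with summed loss 8.0 and device 1 holds 1 valid example with summed loss 1.0. The mean of per-device ratios differs from the true mean over all examples whenever the per-device denominators differ.

```python
import numpy as np

# Hypothetical per-device sums.
per_device_numerator = np.array([8.0, 1.0])    # summed loss per device
per_device_denominator = np.array([4.0, 1.0])  # valid-example count per device

# Old pattern: per-device ratio, then mean of ratios -> (2.0 + 1.0) / 2 = 1.5
mean_of_ratios = np.mean(per_device_numerator / per_device_denominator)

# Correct mean over all 5 examples -> 9.0 / 5 = 1.8
global_ratio = per_device_numerator.sum() / per_device_denominator.sum()
```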

Most codebases follow the safer pattern of returning `sum(loss_numerator)` and `sum(loss_denominator)` per device, followed by `lax.psum(per_device_loss_numerator) / (lax.psum(per_device_loss_denominator) + 1e-9)`, which is strictly the better implementation and avoids these unforeseen bugs.
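
A minimal sketch of the two aggregation patterns, assuming hypothetical per-device arrays `per_example_loss` and `weights`, and that both functions run inside `jax.pmap(..., axis_name='batch')`; this is an illustration, not the init2winit API:

```python
import jax
import jax.numpy as jnp


def old_pattern(per_example_loss, weights):
  # Per-device ratio, then mean of ratios across devices: biased whenever
  # the per-device denominators (weights) differ.
  per_device_loss = jnp.sum(per_example_loss * weights) / jnp.sum(weights)
  return jax.lax.pmean(per_device_loss, axis_name='batch')


def new_pattern(per_example_loss, weights):
  # Sum numerator and denominator separately, psum both, and divide once
  # globally; the 1e-9 guards against an all-padding (empty) global batch.
  numerator = jnp.sum(per_example_loss * weights)
  denominator = jnp.sum(weights)
  numerator, denominator = jax.lax.psum(
      (numerator, denominator), axis_name='batch')
  return numerator / (denominator + 1e-9)
```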

PiperOrigin-RevId: 591130444
sourabh2k15 authored and copybara-github committed Dec 15, 2023
1 parent 7ab2933 commit 0731833
Showing 14 changed files with 211 additions and 94 deletions.
16 changes: 13 additions & 3 deletions hessian/hessian_callback.py
@@ -36,10 +36,20 @@ def set_up_hessian_eval(model, params, batch_stats, dataset,
# However, we need to provide batch_stats for the model.training_cost API.
# The copy is needed b/c the trainer will modify underlying arrays.
batch_stats = jax.tree_map(lambda x: x[:][0], batch_stats)
def batch_loss(module, batch_rng):
def batch_loss(params, batch_rng):
batch, rng = batch_rng
return model.training_cost(
module, batch, batch_stats=batch_stats, dropout_rng=rng)[0]

apply_kwargs = {'train': True}
apply_kwargs['mutable'] = ['batch_stats']
apply_kwargs['rngs'] = {'dropout': rng}

logits, _ = model.apply_on_batch(
params, batch_stats, batch, **apply_kwargs)
weights = batch.get('weights')
loss_numerator, loss_denominator = model.loss_fn(
logits, batch['targets'], weights)

return (loss_numerator / loss_denominator)

def batch_output(module, batch_rng):
batch, rng = batch_rng
15 changes: 13 additions & 2 deletions hessian/run_lanczos.py
@@ -102,8 +102,19 @@ def eval_checkpoints(
# hessian eval.
def batch_loss(params, batch_rng):
batch, rng = batch_rng
return model.training_cost(
params, batch, batch_stats=unreplicated_batch_stats, dropout_rng=rng)[0]

apply_kwargs = {'train': True}
apply_kwargs['mutable'] = ['batch_stats']
apply_kwargs['rngs'] = {'dropout': rng}

logits, _ = model.apply_on_batch(
params, unreplicated_batch_stats, batch, **apply_kwargs)
weights = batch.get('weights')
loss_numerator, loss_denominator = model.loss_fn(
logits, batch['targets'], weights)

return (loss_numerator / loss_denominator)

batch_stats = jax_utils.replicate(unreplicated_batch_stats)

if jax.process_index() == 0:
4 changes: 3 additions & 1 deletion init2winit/init_lib/meta_init.py
@@ -190,7 +190,9 @@ def get_batch(rng_key)
def update(meta_params, optimizer_state, inputs, targets):
"""Update step."""
def params_to_loss(params):
return loss_fn(fprop({'params': params}, inputs, train=True), targets)
loss_value, loss_weight = loss_fn(
fprop({'params': params}, inputs, train=True), targets)
return loss_value / loss_weight

def _meta_loss(params):
return meta_loss(params_to_loss, params, normalized_params, hps.epsilon)
8 changes: 7 additions & 1 deletion init2winit/model_lib/base_model.py
@@ -293,8 +293,14 @@ def training_objective_fn(self, params, logits, targets, weights):
targets = model_utils.apply_label_smoothing(
targets, self.hps.get('label_smoothing'))

objective_value = self.loss_fn(logits, targets, weights)
objective_numerator, objective_denominator = self.loss_fn(
logits, targets, weights)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

# epsilon added to handle empty batch case if we encounter one.
objective_value = (objective_numerator / (objective_denominator + 1e-9))
if self.hps.get('l2_decay_factor'):
l2_loss = model_utils.l2_regularization(params,
self.hps.l2_decay_rank_threshold)
32 changes: 17 additions & 15 deletions init2winit/model_lib/conformer.py
@@ -780,15 +780,6 @@ def sequence_mask(self, lengths, maxlen):
c = jnp.less_equal(b, lengths[:, jnp.newaxis]).astype(lengths.dtype)
return c

def compute_loss(self, logits, logit_paddings, labels, label_paddings):
logprobs = nn.log_softmax(logits)
per_seq_loss = self.loss_fn(logprobs, logit_paddings, labels,
label_paddings)
normalizer = jnp.sum(1 - label_paddings)

normalized_loss = jnp.sum(per_seq_loss) / jnp.maximum(normalizer, 1)
return normalized_loss

def collapse_and_remove_blanks(self, labels, seq_length, blank_id: int = 0):
b, t = labels.shape
# Zap out blank
@@ -861,8 +852,13 @@ def evaluate_batch(self, params, batch_stats, batch):
labels = batch['targets']
label_paddings = batch['target_paddings']

normalized_loss = self.compute_loss(logits, logit_paddings, labels,
label_paddings)
(objective_numerator, objective_denominator) = self.loss_fn(
logits, logit_paddings, labels, label_paddings)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

normalized_loss = (objective_numerator / (objective_denominator))
hyps, hyp_paddings = self.greedy_decode(logits, logit_paddings)

return self.metrics_bundle.gather_from_model_output(
@@ -892,9 +888,15 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
labels = batch['targets']
label_paddings = batch['target_paddings']

normalized_loss = self.compute_loss(outputs, output_paddings, labels,
label_paddings)
return normalized_loss, new_batch_stats
(objective_numerator, objective_denominator) = self.loss_fn(
outputs, output_paddings, labels, label_paddings)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

# epsilon added to handle empty batch case if we encounter one.
objective_value = (objective_numerator / (objective_denominator + 1e-9))
return objective_value, new_batch_stats

def apply_on_batch(self, params, batch_stats, batch, **apply_kwargs):
"""Wrapper around flax_module.apply."""
@@ -939,7 +941,7 @@ def build_flax_module(self):
return module

def get_fake_inputs(self, hps):
"""Helper method solely for purpose of initalizing the model."""
"""Helper method solely for purpose of initializing the model."""
dummy_inputs = [
jnp.zeros((hps.batch_size, *x), dtype=hps.model_dtype)
for x in hps.input_shape
38 changes: 19 additions & 19 deletions init2winit/model_lib/deepspeech.py
@@ -885,15 +885,6 @@ def sequence_mask(self, lengths, maxlen):
c = jnp.less_equal(b, lengths[:, jnp.newaxis]).astype(lengths.dtype)
return c

def compute_loss(self, logits, logit_paddings, labels, label_paddings):
logprobs = nn.log_softmax(logits)
per_seq_loss = self.loss_fn(logprobs, logit_paddings, labels,
label_paddings)
normalizer = jnp.sum(1 - label_paddings)

normalized_loss = jnp.sum(per_seq_loss) / jnp.maximum(normalizer, 1)
return normalized_loss

def collapse_and_remove_blanks(self, labels, seq_length, blank_id: int = 0):
b, t = labels.shape
# Zap out blank.
@@ -955,20 +946,24 @@ def evaluate_batch(self, params, batch_stats, batch):
"""Evaluates cross_entopy on the given batch."""

logits, logit_paddings = self.flax_module.apply(
{
'params': params,
'batch_stats': batch_stats
},
{'params': params, 'batch_stats': batch_stats},
batch['inputs'],
batch['input_paddings'],
train=False,
mutable=False)
mutable=False,
)

labels = batch['targets']
label_paddings = batch['target_paddings']

normalized_loss = self.compute_loss(logits, logit_paddings, labels,
label_paddings)
(objective_numerator, objective_denominator) = self.loss_fn(
logits, logit_paddings, labels, label_paddings)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

# epsilon added to handle empty batch case if we encounter one.
normalized_loss = (objective_numerator / (objective_denominator + 1e-9))
hyps, hyp_paddings = self.greedy_decode(logits, logit_paddings)

return self.metrics_bundle.gather_from_model_output(
@@ -998,9 +993,14 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
labels = batch['targets']
label_paddings = batch['target_paddings']

normalized_loss = self.compute_loss(outputs, output_paddings, labels,
label_paddings)
return normalized_loss, new_batch_stats
(objective_numerator, objective_denominator) = self.loss_fn(
outputs, output_paddings, labels, label_paddings)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

objective_value = (objective_numerator / (objective_denominator))
return objective_value, new_batch_stats

def build_flax_module(self):
config = DeepspeechConfig(
49 changes: 29 additions & 20 deletions init2winit/model_lib/losses.py
@@ -107,13 +107,17 @@ def sigmoid_binary_cross_entropy(logits, targets, weights=None):
else:
normalization = weights.sum()

return jnp.sum(
unnormalized_sigmoid_binary_cross_entropy(logits, targets,
weights)) / normalization
return (
jnp.sum(
unnormalized_sigmoid_binary_cross_entropy(logits, targets, weights)
),
normalization,
)


def unnormalized_bi_tempered_sigmoid_binary_cross_entropy(
logits, targets, weights=None, t1=1.0, t2=1.0):
logits, targets, weights=None, t1=1.0, t2=1.0
):
"""Computes the bi-tempered sigmoid binary cross entropy per example.
Args:
@@ -172,7 +176,7 @@ def bi_tempered_sigmoid_binary_cross_entropy(hps,
hps.bi_tempered_loss_t2,
)

return jnp.sum(losses) / normalization
return jnp.sum(losses), normalization


def unnormalized_sigmoid_mean_squared_error(logits, targets, weights=None):
@@ -207,7 +211,7 @@ def sigmoid_mean_squared_error(logits, targets, weights=None):
unnormalized_sigmoid_mse = unnormalized_sigmoid_mean_squared_error(
logits, targets, weights)

return jnp.sum(unnormalized_sigmoid_mse) / normalization
return jnp.sum(unnormalized_sigmoid_mse), normalization


def rescaled_mean_squared_error(hps, logits, targets, weights=None):
@@ -247,7 +251,7 @@ def rescaled_mean_squared_error(hps, logits, targets, weights=None):
else:
normalization = targets.shape[0]

return jnp.sum(losses) / normalization
return jnp.sum(losses), normalization


def weighted_unnormalized_cross_entropy(logits, targets, weights=None):
@@ -286,15 +290,15 @@ def weighted_cross_entropy(logits, targets, weights=None):
else:
normalization = weights.sum()
unnormalized_cross_entropy = weighted_unnormalized_cross_entropy(
logits, targets, weights)
return jnp.sum(unnormalized_cross_entropy) / normalization
logits, targets, weights
)

return jnp.sum(unnormalized_cross_entropy), normalization


def weighted_unnormalized_bi_tempered_cross_entropy(logits,
targets,
weights=None,
t1=1.0,
t2=1.0):
def weighted_unnormalized_bi_tempered_cross_entropy(
logits, targets, weights=None, t1=1.0, t2=1.0
):
"""Compute weighted bi-tempered loss for log probs and targets.
This computes sum_(x,y) bi_tempered(x, y) for a single, potentially padded
@@ -339,12 +343,16 @@ def weighted_bi_tempered_cross_entropy(hps,
unnormalized_cross_entropy = weighted_unnormalized_bi_tempered_cross_entropy(
logits, targets, weights, hps.bi_tempered_loss_t1, hps.bi_tempered_loss_t2
)
return jnp.sum(unnormalized_cross_entropy) / normalization
return jnp.sum(unnormalized_cross_entropy), normalization


def ctc_loss(logits, logit_paddings, labels, label_paddings, blank_id=0):
return optax.ctc_loss(logits, logit_paddings, labels, label_paddings,
blank_id)
def ctc_loss(logits, logit_paddings, labels, label_paddings):
logprobs = nn.log_softmax(logits)
per_seq_loss = optax.ctc_loss(logprobs, logit_paddings, labels,
label_paddings)
normalizer = jnp.sum(1 - label_paddings)

return jnp.sum(per_seq_loss), jnp.maximum(normalizer, 1)


def weighted_unnormalized_mean_absolute_error(logits,
Expand Down Expand Up @@ -390,8 +398,9 @@ def weighted_mean_absolute_error(logits, targets, weights=None):
else:
normalization = weights.sum()
unnormalized_mean_absolute_error = weighted_unnormalized_mean_absolute_error(
logits, targets, weights)
return jnp.sum(unnormalized_mean_absolute_error) / normalization
logits, targets, weights
)
return jnp.sum(unnormalized_mean_absolute_error), normalization


# TODO(cheolmin): add mean_squared_error
36 changes: 29 additions & 7 deletions init2winit/model_lib/test_losses.py
@@ -161,6 +161,16 @@
]


# Starting cl/523831152 i2w loss functions return loss in 2 parts.
# See CL description for context.
def wrap_loss(loss_fn):
def wrapped_loss(logits, targets, weights=None):
loss_numerator, loss_denominator = loss_fn(logits, targets, weights)
return loss_numerator / loss_denominator

return wrapped_loss


class LossesTest(parameterized.TestCase):
"""Tests for losses.py."""

@@ -181,6 +191,8 @@ def test_output_activation_fn_registry(self):
def test_classification_losses(self, loss_name):
for data in CLASSIFICATION_TEST_DATA:
loss_fn = losses.get_loss_fn(loss_name, data['hps'])
loss_fn = wrap_loss(loss_fn)

self.assertAlmostEqual(
loss_fn(data['logits'], data['one_hot_targets'], data['weights']),
data[loss_name],
Expand All @@ -190,6 +202,7 @@ def test_classification_losses(self, loss_name):
def test_regression_losses(self, loss_name):
for data in RECONSTRUCTION_TEST_DATA:
loss_fn = losses.get_loss_fn(loss_name, data['hps'])
loss_fn = wrap_loss(loss_fn)
self.assertAlmostEqual(
loss_fn(data['logits'], data['targets'], data['weights']),
data[loss_name],
@@ -203,7 +216,9 @@ def test_cross_entropy_loss_fn(self):
'bi_tempered_cross_entropy')
]:
sigmoid_binary_ce_fn = losses.get_loss_fn(binary_loss_name, data['hps'])
sigmoid_binary_ce_fn = wrap_loss(sigmoid_binary_ce_fn)
ce_fn = losses.get_loss_fn(loss_name, data['hps'])
ce_fn = wrap_loss(ce_fn)
self.assertAlmostEqual(
sigmoid_binary_ce_fn(
np.array([[logits[0] - logits[1]] for logits in data['logits']
@@ -220,6 +235,7 @@ def test_sigmoid_cross_entropy_per_label_weights(self):
'bi_tempered_sigmoid_binary_cross_entropy']:
sigmoid_binary_ce_fn = losses.get_loss_fn(
binary_loss_name, HPS_1)
sigmoid_binary_ce_fn = wrap_loss(sigmoid_binary_ce_fn)
logits = np.arange(15).reshape(3, 5)
targets = np.arange(15, 30).reshape(3, 5)
targets = targets / np.max(targets)
@@ -244,29 +260,35 @@ def test_sigmoid_cross_entropy_per_label_weights(self):
testcase_name='2_char_no_repeat_no_blank',
logits=np.array([[[0.1, 0.8, 0.1], [0.1, 0.1, 0.8]]]),
labels=np.array([[1, 2]]),
result=1.379453),
result=1.379453,
),
dict(
testcase_name='1_char_1_blank',
logits=np.array([[[0.2, 0.8], [0.8, 0.2]]]),
labels=np.array([[1, 0]]),
result=0.874976))
result=0.874976,
),
)
def test_ctc_loss(self, logits, labels, result):
"""Tests the CTC loss computation."""
ctc_loss = losses.get_loss_fn('ctc')
loss_value = ctc_loss(logits, np.zeros(logits.shape[:2]), labels,
np.zeros(labels.shape))

loss_value, _ = ctc_loss(
logits, np.zeros(logits.shape[:2]), labels, np.zeros(labels.shape)
)
self.assertAlmostEqual(loss_value, jax.numpy.array([result]), places=6)

@parameterized.named_parameters(
dict(
testcase_name='single_batch',
logits=np.array([[0.1, 0.8, 0.1], [0.1, 0.1, 0.8]]),
targets=np.array([[1., 2., 3.], [4., 5., 6.]]),
result=3.166667))
targets=np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
result=3.166667,
)
)
def test_weighted_mean_absolute_error(self, logits, targets, result):
"""Tests computing MAE."""
mae = losses.get_loss_fn('mean_absolute_error')
mae = wrap_loss(mae)
loss_value = mae(logits, targets)

self.assertAlmostEqual(loss_value, jax.numpy.array([result]))