Fix for multi-device loss bug in current init2winit.
init2winit loss computations follow the pattern of computing `loss_numerator / loss_denominator` per device, followed by `loss = lax.pmean(per_device_loss)`.

This introduces bugs in many cases, for example when padding differs across devices or when devices receive unequal numbers of examples.

Most codebases seem to follow the safer pattern of returning `sum(loss_numerator)` and `sum(loss_denominator)` per device, followed by `lax.psum(per_device_loss_numerator) / (lax.psum(per_device_loss_denominator) + 1e-9)`, which is strictly the better implementation and avoids unforeseen bugs.
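
For illustration, here is a minimal sketch (not part of this commit) contrasting the two aggregation patterns under `pmap` with `axis_name='batch'`; the function and variable names are placeholders:

```python
import jax
import jax.numpy as jnp


def per_device_ratio_loss(per_example_loss, weights):
  # Old pattern: normalize on each device, then average the per-device
  # ratios. Biased when devices carry different amounts of padding or
  # different numbers of real examples.
  local_loss = jnp.sum(per_example_loss * weights) / jnp.sum(weights)
  return jax.lax.pmean(local_loss, axis_name='batch')


def global_ratio_loss(per_example_loss, weights):
  # New pattern: sum numerator and denominator across devices first, then
  # divide once, so every example contributes with equal weight.
  numerator = jax.lax.psum(
      jnp.sum(per_example_loss * weights), axis_name='batch')
  denominator = jax.lax.psum(jnp.sum(weights), axis_name='batch')
  return numerator / (denominator + 1e-9)
```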

PiperOrigin-RevId: 523831152
sourabh2k15 authored and copybara-github committed Dec 13, 2023
1 parent 7ab2933 commit cea0916
Showing 8 changed files with 91 additions and 60 deletions.
7 changes: 6 additions & 1 deletion init2winit/model_lib/base_model.py
@@ -293,8 +293,13 @@ def training_objective_fn(self, params, logits, targets, weights):
targets = model_utils.apply_label_smoothing(
targets, self.hps.get('label_smoothing'))

objective_value = self.loss_fn(logits, targets, weights)
objective_numerator, objective_denominator = self.loss_fn(
logits, targets, weights)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

objective_value = (objective_numerator / objective_denominator)
if self.hps.get('l2_decay_factor'):
l2_loss = model_utils.l2_regularization(params,
self.hps.l2_decay_rank_threshold)
29 changes: 15 additions & 14 deletions init2winit/model_lib/conformer.py
@@ -780,15 +780,6 @@ def sequence_mask(self, lengths, maxlen):
c = jnp.less_equal(b, lengths[:, jnp.newaxis]).astype(lengths.dtype)
return c

def compute_loss(self, logits, logit_paddings, labels, label_paddings):
logprobs = nn.log_softmax(logits)
per_seq_loss = self.loss_fn(logprobs, logit_paddings, labels,
label_paddings)
normalizer = jnp.sum(1 - label_paddings)

normalized_loss = jnp.sum(per_seq_loss) / jnp.maximum(normalizer, 1)
return normalized_loss

def collapse_and_remove_blanks(self, labels, seq_length, blank_id: int = 0):
b, t = labels.shape
# Zap out blank
@@ -861,8 +852,13 @@ def evaluate_batch(self, params, batch_stats, batch):
labels = batch['targets']
label_paddings = batch['target_paddings']

normalized_loss = self.compute_loss(logits, logit_paddings, labels,
label_paddings)
(objective_numerator, objective_denominator) = self.loss_fn(
logits, logit_paddings, labels, label_paddings)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

normalized_loss = (objective_numerator / (objective_denominator))
hyps, hyp_paddings = self.greedy_decode(logits, logit_paddings)

return self.metrics_bundle.gather_from_model_output(
@@ -892,9 +888,14 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
labels = batch['targets']
label_paddings = batch['target_paddings']

normalized_loss = self.compute_loss(outputs, output_paddings, labels,
label_paddings)
return normalized_loss, new_batch_stats
(objective_numerator, objective_denominator) = self.loss_fn(
outputs, output_paddings, labels, label_paddings)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

objective_value = (objective_numerator / (objective_denominator))
return objective_value, new_batch_stats

def apply_on_batch(self, params, batch_stats, batch, **apply_kwargs):
"""Wrapper around flax_module.apply."""
37 changes: 18 additions & 19 deletions init2winit/model_lib/deepspeech.py
@@ -885,15 +885,6 @@ def sequence_mask(self, lengths, maxlen):
c = jnp.less_equal(b, lengths[:, jnp.newaxis]).astype(lengths.dtype)
return c

def compute_loss(self, logits, logit_paddings, labels, label_paddings):
logprobs = nn.log_softmax(logits)
per_seq_loss = self.loss_fn(logprobs, logit_paddings, labels,
label_paddings)
normalizer = jnp.sum(1 - label_paddings)

normalized_loss = jnp.sum(per_seq_loss) / jnp.maximum(normalizer, 1)
return normalized_loss

def collapse_and_remove_blanks(self, labels, seq_length, blank_id: int = 0):
b, t = labels.shape
# Zap out blank.
@@ -955,20 +946,23 @@ def evaluate_batch(self, params, batch_stats, batch):
"""Evaluates cross_entopy on the given batch."""

logits, logit_paddings = self.flax_module.apply(
{
'params': params,
'batch_stats': batch_stats
},
{'params': params, 'batch_stats': batch_stats},
batch['inputs'],
batch['input_paddings'],
train=False,
mutable=False)
mutable=False,
)

labels = batch['targets']
label_paddings = batch['target_paddings']

normalized_loss = self.compute_loss(logits, logit_paddings, labels,
label_paddings)
(objective_numerator, objective_denominator) = self.loss_fn(
logits, logit_paddings, labels, label_paddings)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

normalized_loss = (objective_numerator / (objective_denominator))
hyps, hyp_paddings = self.greedy_decode(logits, logit_paddings)

return self.metrics_bundle.gather_from_model_output(
@@ -998,9 +992,14 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
labels = batch['targets']
label_paddings = batch['target_paddings']

normalized_loss = self.compute_loss(outputs, output_paddings, labels,
label_paddings)
return normalized_loss, new_batch_stats
(objective_numerator, objective_denominator) = self.loss_fn(
outputs, output_paddings, labels, label_paddings)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch')

objective_value = (objective_numerator / (objective_denominator))
return objective_value, new_batch_stats

def build_flax_module(self):
config = DeepspeechConfig(
49 changes: 29 additions & 20 deletions init2winit/model_lib/losses.py
@@ -107,13 +107,17 @@ def sigmoid_binary_cross_entropy(logits, targets, weights=None):
else:
normalization = weights.sum()

return jnp.sum(
unnormalized_sigmoid_binary_cross_entropy(logits, targets,
weights)) / normalization
return (
jnp.sum(
unnormalized_sigmoid_binary_cross_entropy(logits, targets, weights)
),
normalization,
)


def unnormalized_bi_tempered_sigmoid_binary_cross_entropy(
logits, targets, weights=None, t1=1.0, t2=1.0):
logits, targets, weights=None, t1=1.0, t2=1.0
):
"""Computes the bi-tempered sigmoid binary cross entropy per example.
Args:
@@ -172,7 +176,7 @@ def bi_tempered_sigmoid_binary_cross_entropy(hps,
hps.bi_tempered_loss_t2,
)

return jnp.sum(losses) / normalization
return jnp.sum(losses), normalization


def unnormalized_sigmoid_mean_squared_error(logits, targets, weights=None):
@@ -207,7 +211,7 @@ def sigmoid_mean_squared_error(logits, targets, weights=None):
unnormalized_sigmoid_mse = unnormalized_sigmoid_mean_squared_error(
logits, targets, weights)

return jnp.sum(unnormalized_sigmoid_mse) / normalization
return jnp.sum(unnormalized_sigmoid_mse), normalization


def rescaled_mean_squared_error(hps, logits, targets, weights=None):
@@ -247,7 +251,7 @@ def rescaled_mean_squared_error(hps, logits, targets, weights=None):
else:
normalization = targets.shape[0]

return jnp.sum(losses) / normalization
return jnp.sum(losses), normalization


def weighted_unnormalized_cross_entropy(logits, targets, weights=None):
@@ -286,15 +290,15 @@ def weighted_cross_entropy(logits, targets, weights=None):
else:
normalization = weights.sum()
unnormalized_cross_entropy = weighted_unnormalized_cross_entropy(
logits, targets, weights)
return jnp.sum(unnormalized_cross_entropy) / normalization
logits, targets, weights
)

return jnp.sum(unnormalized_cross_entropy), normalization


def weighted_unnormalized_bi_tempered_cross_entropy(logits,
targets,
weights=None,
t1=1.0,
t2=1.0):
def weighted_unnormalized_bi_tempered_cross_entropy(
logits, targets, weights=None, t1=1.0, t2=1.0
):
"""Compute weighted bi-tempered loss for log probs and targets.
This computes sum_(x,y) bi_tempered(x, y) for a single, potentially padded
@@ -339,12 +343,16 @@ def weighted_bi_tempered_cross_entropy(hps,
unnormalized_cross_entropy = weighted_unnormalized_bi_tempered_cross_entropy(
logits, targets, weights, hps.bi_tempered_loss_t1, hps.bi_tempered_loss_t2
)
return jnp.sum(unnormalized_cross_entropy) / normalization
return jnp.sum(unnormalized_cross_entropy), normalization


def ctc_loss(logits, logit_paddings, labels, label_paddings, blank_id=0):
return optax.ctc_loss(logits, logit_paddings, labels, label_paddings,
blank_id)
def ctc_loss(logits, logit_paddings, labels, label_paddings):
logprobs = nn.log_softmax(logits)
per_seq_loss = optax.ctc_loss(logprobs, logit_paddings, labels,
label_paddings)
normalizer = jnp.sum(1 - label_paddings)

return jnp.sum(per_seq_loss), jnp.maximum(normalizer, 1)


def weighted_unnormalized_mean_absolute_error(logits,
@@ -390,8 +398,9 @@ def weighted_mean_absolute_error(logits, targets, weights=None):
else:
normalization = weights.sum()
unnormalized_mean_absolute_error = weighted_unnormalized_mean_absolute_error(
logits, targets, weights)
return jnp.sum(unnormalized_mean_absolute_error) / normalization
logits, targets, weights
)
return jnp.sum(unnormalized_mean_absolute_error), normalization


# TODO(cheolmin): add mean_squared_error
8 changes: 7 additions & 1 deletion init2winit/model_lib/xformer_translate.py
@@ -1098,7 +1098,13 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
if self.hps.get('label_smoothing') is not None:
targets = model_utils.apply_label_smoothing(
targets, self.hps.get('label_smoothing'))
total_loss = self.loss_fn(logits, targets, weights)
(total_loss, total_weight) = self.loss_fn(
logits, targets, weights)

(total_loss, total_weight) = lax.psum(
(total_loss, total_weight), axis_name='batch')

total_loss = (total_loss / total_weight)

if self.hps.get('l2_decay_factor'):
l2_loss = model_utils.l2_regularization(
8 changes: 7 additions & 1 deletion init2winit/model_lib/xformer_translate_binary.py
@@ -1051,7 +1051,13 @@ def training_cost(self,
batch.get('targets_segmentation'),
train=False)
targets = lax.stop_gradient(softmax(targets))
total_loss = self.loss_fn(logits, targets, weights)
(total_loss, total_weight) = self.loss_fn(
logits, targets, weights)

(total_loss, total_weight) = lax.psum(
(total_loss, total_weight), axis_name='batch')

total_loss = (total_loss / total_weight)

if self.hps.get('l2_decay_factor'):
l2_loss = model_utils.l2_regularization(
8 changes: 7 additions & 1 deletion init2winit/model_lib/xformer_translate_mlc_variant.py
@@ -1119,7 +1119,13 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
if self.hps.get('label_smoothing') is not None:
targets = model_utils.apply_label_smoothing(
targets, self.hps.get('label_smoothing'))
total_loss = self.loss_fn(logits, targets, weights)
(total_loss, total_weight) = self.loss_fn(
logits, targets, weights)

(total_loss, total_weight) = lax.psum(
(total_loss, total_weight), axis_name='batch')

total_loss = (total_loss / total_weight)

if self.hps.get('l2_decay_factor'):
l2_loss = model_utils.l2_regularization(
5 changes: 2 additions & 3 deletions init2winit/trainer_lib/trainer.py
@@ -103,10 +103,9 @@ def opt_cost(params):

if axis_name is not None:
if optimizer_utils.requires_gradient_aggregation(optimizer_update_fn):
cost_value, grad = lax.pmean((cost_value, grad), axis_name=axis_name)
grad = lax.pmean((grad), axis_name=axis_name)
else:
# Skip gradient aggregation, only aggregate the cost value.
cost_value = lax.pmean(cost_value, axis_name=axis_name)
# Skip gradient aggregation as it'll be handled in gradient_accumulator.
if grad_clip:
# Calculating the gradient norm requires cross-device aggregation,
# performed, in this case, inside the optimizer. Calculating it again
