From a9f19665c112704804e5c6143800d47bc94c90ad Mon Sep 17 00:00:00 2001 From: Sourabh Medapati Date: Wed, 11 Sep 2024 23:26:54 -0700 Subject: [PATCH] upgrading init2winit from pmap to jit PiperOrigin-RevId: 673695362 --- init2winit/base_callback.py | 4 +- init2winit/checkpoint.py | 99 +++++------------ init2winit/dataset_lib/data_utils.py | 68 ++++++------ init2winit/dataset_lib/ogbg_molpcba.py | 9 +- init2winit/dataset_lib/test_ogbg_molpcba.py | 20 ++-- init2winit/model_lib/base_model.py | 8 +- init2winit/model_lib/conformer.py | 11 +- init2winit/model_lib/deepspeech.py | 18 +-- init2winit/model_lib/unet.py | 2 +- init2winit/model_lib/xformer_translate.py | 6 +- init2winit/mt_eval/inference.py | 42 +++---- init2winit/mt_eval/mt_callback.py | 45 +++++--- init2winit/trainer_lib/base_trainer.py | 115 ++++++++++++-------- init2winit/trainer_lib/trainer.py | 73 ++++++------- init2winit/trainer_lib/trainer_utils.py | 75 ++++--------- 15 files changed, 256 insertions(+), 339 deletions(-) diff --git a/init2winit/base_callback.py b/init2winit/base_callback.py index 09f435d7..65588751 100644 --- a/init2winit/base_callback.py +++ b/init2winit/base_callback.py @@ -20,7 +20,7 @@ callback_builder = callbacks.get_callback(config['callback_name']) callback = callback_builder(model, params, batch_stats, optimizer_state, - dataset, hps, config, train_dir, rng) + dataset, hps, config, train_dir, rng, mesh) callback_metrics = callback.run_eval(params, batch_stats, optimizer_state, global_step). @@ -39,7 +39,7 @@ class BaseCallBack: def __init__(self, model, params, batch_stats, optimizer_state, optimizer_update_fn, dataset, hps, callback_config, train_dir, - rng): + rng, mesh): """Defines the API for callback construction.""" pass diff --git a/init2winit/checkpoint.py b/init2winit/checkpoint.py index 46500fc1..fea3886d 100644 --- a/init2winit/checkpoint.py +++ b/init2winit/checkpoint.py @@ -18,14 +18,14 @@ This is useful for training neural networks with stax, where model parameters are nested numpy arrays. """ + import os import sys -from typing import Sequence from absl import flags from absl import logging -from flax import jax_utils from flax.training import checkpoints as flax_checkpoints +from init2winit.dataset_lib import data_utils as utils import jax FLAGS = flags.FLAGS @@ -44,47 +44,12 @@ def load_pytree(pytree_file, orbax_checkpointer=None): return None -def replicate_checkpoint( - latest, - pytree_keys: Sequence[str], - replicate=True): - """Restores from the provided checkpoint. - - Args: - latest: A dict representing the state of the - checkpoint we want to restore. - pytree_keys: A sequence of keys into `latest` that are pytrees, which will - be replicated if replicate=True. - replicate: If set, replicate the state across devices. - - Returns: - Tuple of (pytree, extra_dict) where pytree is a JAX pytree holding the - arrays that need to be replicated/unreplicated and extra_dict holds any - additional python state. We expect extra_dict to have the keys of - 'global_step', 'preemption_count', 'sum_train_cost', but old checkpoints - might be missing something. - """ - logging.info('Loaded model parameters from latest checkpoint.') - # Old checkpoints without 'sum_train_cost' can still be restored, but the - # train() function will break. Evals and curvature stuff should be fine, - # however. 
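The new `mesh` argument threaded through the callback and checkpoint APIs refers to the single flat device mesh that base_trainer.py constructs later in this patch. A minimal sketch of that construction, assuming every host participates in the same global job:

import jax
from jax.experimental import mesh_utils

# One flat axis named 'devices' spanning every device in the job; batches and
# replicated state are both expressed as shardings over this single axis.
mesh = jax.sharding.Mesh(
    mesh_utils.create_device_mesh((jax.device_count(),), devices=jax.devices()),
    axis_names=('devices',),
)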
- expected = ['global_step', 'preemption_count', 'sum_train_cost'] - if any(k not in latest for k in expected): - logging.warn('Checkpoint state missing keys, obtained %s expected %s', - list(latest.keys()), expected) - - pytree = {k: latest[k] for k in pytree_keys} - if replicate: - pytree = jax_utils.replicate(pytree) - extra_dict = {k: latest[k] for k in latest.keys() if k not in pytree_keys} - return pytree, extra_dict - - def replicate_and_maybe_restore_checkpoint( unreplicated_optimizer_state, unreplicated_params, unreplicated_batch_stats, unreplicated_training_metrics_state, + mesh, train_dir, external_checkpoint_path=None, orbax_checkpointer=None): @@ -104,6 +69,7 @@ def replicate_and_maybe_restore_checkpoint( unreplicated_params: unreplicated params unreplicated_batch_stats: unreplicated batch stats unreplicated_training_metrics_state: unreplicated metrics state + mesh: Mesh specification to use for sharding. train_dir: (str) The training directory where we will look for a checkpoint. external_checkpoint_path: (str) If this argument is set, then we will load the external checkpoint stored there. @@ -165,43 +131,34 @@ def replicate_and_maybe_restore_checkpoint( # Handle failure to load from external_checkpoint_path. if ckpt_to_return['global_step'] == -1: return ( - jax_utils.replicate(unreplicated_optimizer_state), - jax_utils.replicate(unreplicated_params), - jax_utils.replicate(unreplicated_batch_stats), - jax_utils.replicate(unreplicated_training_metrics_state), + utils.shard_pytree(unreplicated_optimizer_state, mesh), + utils.shard_pytree(unreplicated_params, mesh), + utils.shard_pytree(unreplicated_batch_stats, mesh), + utils.shard_pytree(unreplicated_training_metrics_state, mesh), 0, # global_step - jax_utils.replicate(0), # sum_train_cost + 0, # sum_train_cost 0, # preemption_count False) # is_restored else: # Else, don't restore from any checkpoint. 
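`utils.shard_pytree` stands in for the old `flax.jax_utils.replicate`: rather than stacking one copy of the state per local device, the host pytree is placed once under NamedSharding annotations derived from the mesh. A condensed sketch of the helper as added to data_utils.py in this patch:

import flax.linen as nn
import jax

def shard_pytree(pytree, mesh):
  # nn.get_sharding returns a pytree of NamedSharding annotations matching
  # `pytree`; device_put then materializes the arrays under those shardings.
  shardings = nn.get_sharding(pytree, mesh)
  return shardings, jax.device_put(pytree, shardings)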
return ( - jax_utils.replicate(unreplicated_optimizer_state), - jax_utils.replicate(unreplicated_params), - jax_utils.replicate(unreplicated_batch_stats), - jax_utils.replicate(unreplicated_training_metrics_state), + utils.shard_pytree(unreplicated_optimizer_state, mesh), + utils.shard_pytree(unreplicated_params, mesh), + utils.shard_pytree(unreplicated_batch_stats, mesh), + utils.shard_pytree(unreplicated_training_metrics_state, mesh), 0, # global_step - jax_utils.replicate(0), # sum_train_cost + 0, # sum_train_cost 0, # preemption_count False) # is_restored - pytree_dict, extra_state = replicate_checkpoint( - ckpt_to_return, - pytree_keys=[ - 'optimizer_state', - 'params', - 'batch_stats', - 'training_metrics_grabber', - 'sum_train_cost', - ]) return ( - pytree_dict['optimizer_state'], - pytree_dict['params'], - pytree_dict['batch_stats'], - pytree_dict['training_metrics_grabber'], - extra_state['global_step'], - pytree_dict['sum_train_cost'], - extra_state['preemption_count'], - is_restored) + utils.shard_pytree(ckpt_to_return['optimizer_state'], mesh), + utils.shard_pytree(ckpt_to_return['params'], mesh), + utils.shard_pytree(ckpt_to_return['batch_stats'], mesh), + utils.shard_pytree(ckpt_to_return['training_metrics_grabber'], mesh), + ckpt_to_return['global_step'], # global_step + ckpt_to_return['sum_train_cost'], + ckpt_to_return['preemption_count'], # preemption_count + is_restored) # is_restored def save_unreplicated_checkpoint( @@ -217,14 +174,12 @@ def save_unreplicated_checkpoint( max_to_keep=1): """Saves pytree, step, preemption_count, and sum_train_cost to train_dir.""" logging.info('Saving checkpoint to ckpt_%d', global_step) - unreplicated_optimizer_state = jax.device_get( - jax_utils.unreplicate(optimizer_state)) - unreplicated_params = jax.device_get(jax_utils.unreplicate(params)) - unreplicated_batch_stats = jax.device_get(jax_utils.unreplicate(batch_stats)) + unreplicated_optimizer_state = jax.device_get(optimizer_state) + unreplicated_params = jax.device_get(params) + unreplicated_batch_stats = jax.device_get(batch_stats) unreplicated_training_metrics_state = jax.device_get( - jax_utils.unreplicate(training_metrics_state)) - unreplicated_sum_train_cost = jax.device_get( - jax_utils.unreplicate(sum_train_cost)) + training_metrics_state) + unreplicated_sum_train_cost = jax.device_get(sum_train_cost) state = dict(global_step=global_step, preemption_count=preemption_count, sum_train_cost=unreplicated_sum_train_cost, diff --git a/init2winit/dataset_lib/data_utils.py b/init2winit/dataset_lib/data_utils.py index 53ecaf42..53fabc7b 100644 --- a/init2winit/dataset_lib/data_utils.py +++ b/init2winit/dataset_lib/data_utils.py @@ -16,10 +16,14 @@ """Common code used by different models.""" import collections + +import flax.linen as nn import jax from jax.nn import one_hot +from jax.sharding import PartitionSpec as P import numpy as np + Dataset = collections.namedtuple('Dataset', [ 'train_iterator_fn', 'eval_train_epoch', @@ -143,40 +147,6 @@ def zero_pad(ar, pad_axis): return padded_batch -def shard(batch, n_devices=None): - """Prepares the batch for pmap by adding a leading n_devices dimension. - - If all the entries are lists, assume they are already divided into n_devices - smaller arrays and stack them for pmapping. If all the entries are arrays, - assume they have leading dimension divisible by n_devices and reshape. - - Args: - batch: A dict of arrays or lists of arrays - n_devices: If None, this will be set to jax.local_device_count(). - - Returns: - Sharded data. 
- """ - if n_devices is None: - n_devices = jax.local_device_count() - - # TODO(mbadura): Specify a sharding function per dataset instead - # If entries in the batch dict are lists, then the data is already divided - # into n_devices chunks, so we need to stack them. - if all((isinstance(v, list) for v in batch.values())): - assert all(len(v) == n_devices for v in batch.values()) - # transpose a dict of lists to a list of dicts - shards = [{k: v[i] for (k, v) in batch.items()} for i in range(n_devices)] - return jax.tree.map(lambda *vals: np.stack(vals, axis=0), shards[0], - *shards[1:]) - - # Otherwise, the entries are arrays, so just reshape them. - def _shard_array(array): - return array.reshape((n_devices, -1) + array.shape[1:]) - - return jax.tree.map(_shard_array, batch) - - def tf_to_numpy(tfds_data): # Safe because we won't mutate. Avoids an extra copy from tfds. convert_data = lambda x: x._numpy() # pylint: disable=protected-access @@ -187,4 +157,32 @@ def tf_to_numpy(tfds_data): def convert_jax_to_tf_random_seed(jax_prng_key: jax.random.PRNGKey) -> int: tf_seed = jax.random.bits(jax_prng_key) return tf_seed - \ No newline at end of file + + +def make_global_array(local_data, mesh): + """Util to combine per-host batches into a global batch array. + + Args: + local_data: local data batch on host. + mesh: mesh specification to shard the data. + + Returns: + global_array: global data batch. + """ + global_shape = ( + local_data.shape[0] * jax.process_count(), + *local_data.shape[1:], + ) + sharding = jax.NamedSharding(mesh, P('devices')) + + global_array = jax.make_array_from_process_local_data( + sharding, local_data, global_shape + ) + return global_array + + +def shard_pytree(pytree, mesh): + shardings = nn.get_sharding(pytree, mesh) + pytree = jax.device_put(pytree, shardings) + + return shardings, pytree diff --git a/init2winit/dataset_lib/ogbg_molpcba.py b/init2winit/dataset_lib/ogbg_molpcba.py index edb5748b..9d7e9d64 100644 --- a/init2winit/dataset_lib/ogbg_molpcba.py +++ b/init2winit/dataset_lib/ogbg_molpcba.py @@ -208,8 +208,7 @@ def _get_batch_iterator(dataset_iter, num_shards: How many devices we should be able to shard the batch into. Yields: - Batch in the init2winit format. Each field is a list of num_shards separate - smaller batches. + Batch in the init2winit format. 
""" if not num_shards: @@ -252,9 +251,9 @@ def _get_batch_iterator(dataset_iter, if count == num_shards: yield { - 'inputs': graphs_shards, - 'targets': labels_shards, - 'weights': weights_shards + 'inputs': jraph.batch(graphs_shards), + 'targets': np.vstack(labels_shards), + 'weights': np.vstack(weights_shards) } count = 0 diff --git a/init2winit/dataset_lib/test_ogbg_molpcba.py b/init2winit/dataset_lib/test_ogbg_molpcba.py index 70d9aef7..6c047b56 100644 --- a/init2winit/dataset_lib/test_ogbg_molpcba.py +++ b/init2winit/dataset_lib/test_ogbg_molpcba.py @@ -117,10 +117,9 @@ def test_get_batch_pads_correctly(self): dataset = _get_dataset(jax.random.PRNGKey(0)) batch = next(dataset.valid_epoch()) - inputs = batch['inputs'][0] + inputs = batch['inputs'] # The first two graphs are in the first batch - self.assertLen(batch['inputs'], 1) self.assertNDArrayNear(inputs.n_node[:2], np.array(NUMS_NODES[:2]), 1e-3) # The graphs are padded to the right size @@ -130,9 +129,9 @@ def test_get_batch_pads_correctly(self): self.assertEqual(np.sum(inputs.n_edge), BATCH_SIZE * EDGES_SIZE_MULTIPLIER) # Weights are zero at NaN labels and in padded examples - self.assertNDArrayNear(batch['weights'][0], + self.assertNDArrayNear(batch['weights'], np.array([[1, 1], [0, 1], [0, 0]]), 1e-3) - self.assertFalse(np.any(np.isnan(batch['targets'][0]))) + self.assertFalse(np.any(np.isnan(batch['targets']))) def test_train_shuffle_is_deterministic(self): """Tests that shuffling of the train split is deterministic.""" @@ -144,19 +143,18 @@ def test_train_shuffle_is_deterministic(self): batch_same = next(dataset_same.train_iterator_fn()) batch_different = next(dataset_different.train_iterator_fn()) - self.assertAllClose(batch['inputs'][0], batch_same['inputs'][0]) - self.assertNotAllClose(batch['inputs'][0], batch_different['inputs'][0]) + self.assertAllClose(batch['inputs'], batch_same['inputs']) + self.assertNotAllClose(batch['inputs'], batch_different['inputs']) def test_add_virtual_node(self): """Tests that adding a virtual node works correctly.""" dataset = _get_dataset(jax.random.PRNGKey(0), {'add_virtual_node': True}) batch = next(dataset.valid_epoch()) - inputs = batch['inputs'][0] + inputs = batch['inputs'] num_nodes = np.array(NUMS_NODES[0]) num_edges = np.array(NUMS_EDGES[0]) - self.assertLen(batch['inputs'], 1) self.assertNDArrayNear( inputs.n_node[0], np.array(num_nodes + 1), 1e-3) self.assertNDArrayNear( @@ -173,11 +171,10 @@ def test_add_bidirectional_edges(self): jax.random.PRNGKey(0), {'add_bidirectional_edges': True}) batch = next(dataset.valid_epoch()) - inputs = batch['inputs'][0] + inputs = batch['inputs'] num_nodes = np.array(NUMS_NODES[0]) num_edges = np.array(NUMS_EDGES[0]) - self.assertLen(batch['inputs'], 1) self.assertNDArrayNear( inputs.n_node[0], np.array(num_nodes), 1e-3) self.assertNDArrayNear( @@ -188,11 +185,10 @@ def test_add_self_loops(self): dataset = _get_dataset(jax.random.PRNGKey(0), {'add_self_loops': True}) batch = next(dataset.valid_epoch()) - inputs = batch['inputs'][0] + inputs = batch['inputs'] num_nodes = np.array(NUMS_NODES[0]) num_edges = np.array(NUMS_EDGES[0]) - self.assertLen(batch['inputs'], 1) self.assertNDArrayNear( inputs.n_node[0], np.array(num_nodes), 1e-3) self.assertNDArrayNear( diff --git a/init2winit/model_lib/base_model.py b/init2winit/model_lib/base_model.py index ea398897..6e78f305 100644 --- a/init2winit/model_lib/base_model.py +++ b/init2winit/model_lib/base_model.py @@ -67,8 +67,8 @@ def _evaluate_batch(flax_module, params, batch_stats, batch, metrics_bundle, 
# We don't use CLU's `mask` argument here, we handle it ourselves through # `weights`. - return metrics_bundle.gather_from_model_output( - logits=logits, targets=targets, weights=weights, axis_name='batch') + return metrics_bundle.single_from_model_output( + logits=logits, targets=targets, weights=weights) class BaseModel(object): @@ -300,10 +300,6 @@ def training_objective_fn(self, params, logits, targets, weights): logits, targets, weights ) - (objective_numerator, objective_denominator) = jax.lax.psum( - (objective_numerator, objective_denominator), axis_name='batch' - ) - # epsilon added to handle empty batch case if we encounter one. objective_value = objective_numerator / (objective_denominator + 1e-9) if self.hps.get('l2_decay_factor'): diff --git a/init2winit/model_lib/conformer.py b/init2winit/model_lib/conformer.py index e2353dcd..66115654 100644 --- a/init2winit/model_lib/conformer.py +++ b/init2winit/model_lib/conformer.py @@ -855,19 +855,15 @@ def evaluate_batch(self, params, batch_stats, batch): (objective_numerator, objective_denominator) = self.loss_fn( logits, logit_paddings, labels, label_paddings) - (objective_numerator, objective_denominator) = jax.lax.psum( - (objective_numerator, objective_denominator), axis_name='batch') - normalized_loss = (objective_numerator / (objective_denominator)) hyps, hyp_paddings = self.greedy_decode(logits, logit_paddings) - return self.metrics_bundle.gather_from_model_output( + return self.metrics_bundle.single_from_model_output( normalized_loss=normalized_loss, hyps=hyps, hyp_paddings=hyp_paddings, targets=labels, - target_paddings=label_paddings, - axis_name='batch') + target_paddings=label_paddings) def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): """Return CTC loss.""" @@ -891,9 +887,6 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): (objective_numerator, objective_denominator) = self.loss_fn( outputs, output_paddings, labels, label_paddings) - (objective_numerator, objective_denominator) = jax.lax.psum( - (objective_numerator, objective_denominator), axis_name='batch') - # epsilon added to handle empty batch case if we encounter one. objective_value = (objective_numerator / (objective_denominator + 1e-9)) return objective_value, new_batch_stats diff --git a/init2winit/model_lib/deepspeech.py b/init2winit/model_lib/deepspeech.py index 0ba1c5a9..73626c41 100644 --- a/init2winit/model_lib/deepspeech.py +++ b/init2winit/model_lib/deepspeech.py @@ -405,10 +405,6 @@ def __call__(self, inputs, input_paddings=None, train=False): count_v = jnp.sum( jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) - if self.enable_synced_batchnorm: - sum_v = jax.lax.psum(sum_v, axis_name='batch') - count_v = jax.lax.psum(count_v, axis_name='batch') - count_v = jnp.maximum(count_v, 1.0) mean = sum_v / count_v @@ -417,9 +413,6 @@ def __call__(self, inputs, input_paddings=None, train=False): axis=reduce_over_dims, keepdims=True) - if self.enable_synced_batchnorm: - sum_vv = jax.lax.psum(sum_vv, axis_name='batch') - var = sum_vv / count_v self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean @@ -959,20 +952,16 @@ def evaluate_batch(self, params, batch_stats, batch): (objective_numerator, objective_denominator) = self.loss_fn( logits, logit_paddings, labels, label_paddings) - (objective_numerator, objective_denominator) = jax.lax.psum( - (objective_numerator, objective_denominator), axis_name='batch') - # epsilon added to handle empty batch case if we encounter one. 
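The dropped `lax.psum` calls are unnecessary under `jax.jit` because the batch arrives as a single global array sharded along the 'devices' axis, so a plain `jnp.sum` already reduces over the whole global batch. A small self-contained sketch of that behaviour, using a toy weighted loss and hypothetical names:

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)),
            axis_names=('devices',))

@jax.jit
def normalized_loss(per_example_loss, weights):
  # Under jit both sums already cover the full global batch, so no explicit
  # psum across a pmap axis name is needed.
  return jnp.sum(per_example_loss * weights) / (jnp.sum(weights) + 1e-9)

sharding = NamedSharding(mesh, P('devices'))
losses = jax.device_put(jnp.ones(8 * jax.device_count()), sharding)
weights = jax.device_put(jnp.ones(8 * jax.device_count()), sharding)
print(normalized_loss(losses, weights))  # ~1.0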
normalized_loss = (objective_numerator / (objective_denominator + 1e-9)) hyps, hyp_paddings = self.greedy_decode(logits, logit_paddings) - return self.metrics_bundle.gather_from_model_output( + return self.metrics_bundle.single_from_model_output( normalized_loss=normalized_loss, hyps=hyps, hyp_paddings=hyp_paddings, targets=labels, - target_paddings=label_paddings, - axis_name='batch') + target_paddings=label_paddings) def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): """Return CTC loss.""" @@ -996,9 +985,6 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): (objective_numerator, objective_denominator) = self.loss_fn( outputs, output_paddings, labels, label_paddings) - (objective_numerator, objective_denominator) = jax.lax.psum( - (objective_numerator, objective_denominator), axis_name='batch') - objective_value = (objective_numerator / (objective_denominator)) return objective_value, new_batch_stats diff --git a/init2winit/model_lib/unet.py b/init2winit/model_lib/unet.py index 950c6c4f..ecffab64 100644 --- a/init2winit/model_lib/unet.py +++ b/init2winit/model_lib/unet.py @@ -333,7 +333,7 @@ def evaluate_batch(self, params, batch_stats, batch): # We don't use CLU's `mask` argument here, we handle it ourselves through # `weights`. - return self.metrics_bundle.gather_from_model_output( + return self.metrics_bundle.single_from_model_output( logits=logits, targets=targets, weights=weights, diff --git a/init2winit/model_lib/xformer_translate.py b/init2winit/model_lib/xformer_translate.py index 7508ac0e..c6ae1ce8 100644 --- a/init2winit/model_lib/xformer_translate.py +++ b/init2winit/model_lib/xformer_translate.py @@ -1045,7 +1045,7 @@ def evaluate_batch(self, params, batch_stats, batch): targets = one_hot(batch['targets'], logits.shape[-1]) # Add log-perplexity metric. - return self.metrics_bundle.gather_from_model_output( + return self.metrics_bundle.single_from_model_output( logits=logits, targets=targets, weights=weights, axis_name='batch') def apply_on_batch(self, @@ -1101,8 +1101,8 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): (total_loss, total_weight) = self.loss_fn( logits, targets, weights) - (total_loss, total_weight) = lax.psum( - (total_loss, total_weight), axis_name='batch') + # (total_loss, total_weight) = lax.psum( + # (total_loss, total_weight), axis_name='batch') total_loss = (total_loss / total_weight) diff --git a/init2winit/mt_eval/inference.py b/init2winit/mt_eval/inference.py index 5f0ad410..1f12df6b 100644 --- a/init2winit/mt_eval/inference.py +++ b/init2winit/mt_eval/inference.py @@ -14,6 +14,7 @@ # limitations under the License. 
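Switching from `gather_from_model_output` (which all-gathered across the pmap 'batch' axis) to `single_from_model_output` means each call returns a plain per-batch metrics object that the eval loops below accumulate with `merge` and finish with `compute`. A minimal sketch with a hypothetical collection, assuming the standard `clu.metrics` API:

import flax
import jax.numpy as jnp
from clu import metrics

@flax.struct.dataclass
class EvalMetrics(metrics.Collection):
  loss: metrics.Average.from_output('loss')

accumulated = None
for batch_loss in (jnp.array(0.5), jnp.array(0.7)):
  update = EvalMetrics.single_from_model_output(loss=batch_loss)
  accumulated = update if accumulated is None else accumulated.merge(update)
print(accumulated.compute())  # {'loss': 0.6}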
r"""BLEU evaluator container class.""" + import copy import dataclasses import functools @@ -21,9 +22,7 @@ from typing import Any, Sequence from absl import logging -from flax import jax_utils -from flax.training import common_utils -from init2winit import utils +from init2winit.dataset_lib import data_utils as utils from init2winit.dataset_lib import mt_tokenizer from init2winit.mt_eval import decode from init2winit.mt_eval import eval_utils @@ -32,7 +31,6 @@ import numpy as np from tensorflow.io import gfile -glob = gfile.glob DEFAULT_EVAL_CONFIG = { 'eval_batch_size': 16, @@ -68,6 +66,7 @@ def __init__(self, *args, **kwargs): if kwargs['mode'] not in ['offline', 'online']: raise ValueError('BLEU score computation only support online or ' 'offline modes.') + self._mesh = kwargs['mesh'] if kwargs['mode'] == 'offline': self.init_offline_evaluator(*args) else: @@ -168,12 +167,12 @@ def initialize_model(self, model_cls, dataset_meta_data, dropout_rng, params = init_dict['params'] self.flax_module = model.flax_module self.params = params - self.pmapped_init_cache = jax.pmap( + self.init_cache = jax.jit( functools.partial( self.initialize_cache, max_length=self.max_length, params_rng=params_rng, - dropout_rng=dropout_rng), axis_name='gather') + dropout_rng=dropout_rng)) def initialize_cache(self, inputs, max_length, params_rng, dropout_rng): """Initialize a cache for a given input shape and max decode length.""" @@ -217,7 +216,7 @@ def build_predictor(self): eos_id=self.eos_id, beam_size=self.mt_eval_config.get('beam_size'), offset=self.mt_eval_config.get('scan_over_layers_offset', 0)) - self.pmapped_predictor = jax.pmap(decoder, static_broadcasted_argnums=()) + self.predictor = jax.jit(decoder) def translate_and_calculate_bleu(self): """Iterate over all checkpoints and calculate BLEU.""" @@ -233,7 +232,7 @@ def translate_and_calculate_bleu(self): params = eval_utils.average_checkpoints( checkpoint_paths=ckpt_paths, params=self.params) - params_replicated = jax_utils.replicate(params) + params_replicated = utils.replicate_pytree(params, self._mesh) decoding_output = self.translate_and_calculate_bleu_single_model( params_replicated, self.eval_split) logging.info('Sacre bleu score at step %d: %f', step, @@ -246,20 +245,23 @@ def translate_and_calculate_bleu_single_model(self, params, eval_split): self.build_predictor() decode_output = DecodingOutput() logging.info('Starting decoding..') + + make_global_array_fn = functools.partial( + utils.make_global_array, mesh=self._mesh + ) + for batch in self.get_ds_iter(eval_split): - pred_batch = common_utils.shard(batch) - cache = self.pmapped_init_cache(pred_batch['inputs']) - predicted = utils.data_gather( - self.pmapped_predictor(pred_batch, params, cache), - axis_name='gather') - inputs = utils.data_gather(pred_batch['inputs'], axis_name='gather') - targets = utils.data_gather(pred_batch['targets'], axis_name='gather') - weights = utils.data_gather(pred_batch['weights'], axis_name='gather') + pred_batch = jax.tree_util.tree_map(make_global_array_fn, batch) + cache = self.init_cache(pred_batch['inputs']) + predicted = self.predictor(pred_batch, params, cache) + inputs = pred_batch['inputs'] + targets = pred_batch['targets'] + weights = pred_batch['weights'] - predicted = utils.combine_gathered(np.array(predicted)) - inputs = utils.combine_gathered(np.array(inputs)) - targets = utils.combine_gathered(np.array(targets)) - weights = utils.combine_gathered(np.array(weights)) + predicted = np.array(predicted) + inputs = np.array(inputs) + targets = 
np.array(targets) + weights = np.array(weights) current_batch_size = int(weights[:, 0].sum()) if self.mt_eval_config.get('decoding_type') == 'beam_search': self.process_beam_search_output(inputs, targets, predicted, diff --git a/init2winit/mt_eval/mt_callback.py b/init2winit/mt_eval/mt_callback.py index 1473a554..a8341023 100644 --- a/init2winit/mt_eval/mt_callback.py +++ b/init2winit/mt_eval/mt_callback.py @@ -40,13 +40,13 @@ 'scan_over_layers_offset' equal to the length of that tuple. """ +import functools + from absl import logging from init2winit import base_callback from init2winit import utils -from init2winit.dataset_lib import data_utils from init2winit.dataset_lib import datasets from init2winit.model_lib import models - from init2winit.mt_eval import inference import jax from ml_collections.config_dict import config_dict @@ -74,7 +74,8 @@ def __init__(self, hps, callback_config, train_dir, - rng): + rng, + mesh): del optimizer_state del optimizer_update_fn del train_dir @@ -87,12 +88,14 @@ def __init__(self, self.callback_config = merged_callback_config self._validate_callback_config() - self.evaluate_batch_pmapped = jax.pmap( - model.evaluate_batch, axis_name='batch') + self.evaluate_batch_pmapped = jax.jit( + model.evaluate_batch, donate_argnums=(2,) + ) self.batch_stats = batch_stats dataset, dataset_metadata = self._get_dataset(hps, rng) self.dataset = dataset + self.mesh = mesh model_class = models.get_model(callback_config['model_name']) self.inference_manager = inference.InferenceManager( @@ -102,7 +105,8 @@ def __init__(self, dataset, dataset_metadata, self.callback_config, - mode='online') + mode='online', + mesh=mesh) def _validate_callback_config(self): assert all(key in self.callback_config for key in _REQUIRED_KEYS), ( @@ -137,7 +141,7 @@ def _evaluate(self, params, batch_stats, batch_iter, - evaluate_batch_pmapped): + evaluate_batch_jitted): """Compute aggregated metrics on the given data iterator. This function is taken as is from trainer.py to avoid circular dependency. @@ -148,19 +152,25 @@ def _evaluate(self, batch_stats: A dict of batch_stats. batch_iter: Generator which yields batches. Must support the API for b in batch_iter: - evaluate_batch_pmapped: A function with API - evaluate_batch_pmapped(params, batch_stats, batch). Returns a dictionary - mapping keys to the metric values across the sharded batch. + evaluate_batch_jitted: A function with API evaluate_batch_jitted(params, + batch_stats, batch). Returns a dictionary mapping keys to the metric + values across the sharded batch. Returns: A dictionary of aggregated metrics. The keys will match the keys returned - by evaluate_batch_pmapped. + by evaluate_batch_jitted. """ metrics = None + make_global_array_fn = functools.partial( + utils.make_global_array, mesh=self.mesh + ) + for batch in batch_iter: - batch = data_utils.shard(batch) - computed_metrics = evaluate_batch_pmapped( - params=params, batch_stats=batch_stats, batch=batch) + batch = utils.maybe_remove_leading_dimension(batch) + batch = jax.tree_util.tree_map(make_global_array_fn, batch) + computed_metrics = evaluate_batch_jitted( + params=params, batch_stats=batch_stats, batch=batch + ) if metrics is None: metrics = computed_metrics else: @@ -169,7 +179,7 @@ def _evaluate(self, # For data splits with no data (e.g. Imagenet no test set) no values # will appear for that split. 
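The `make_global_array` helper used in this eval loop (added to data_utils.py earlier in the patch) wraps `jax.make_array_from_process_local_data`: each host contributes its local batch and the result behaves as one array with the global batch as its leading dimension. A rough usage sketch with hypothetical shapes:

import functools
import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)),
            axis_names=('devices',))

def make_global_array(local_data, mesh):
  # Per-host batch -> one jax.Array whose leading dim is the global batch size.
  global_shape = (local_data.shape[0] * jax.process_count(),
                  *local_data.shape[1:])
  sharding = NamedSharding(mesh, P('devices'))
  return jax.make_array_from_process_local_data(sharding, local_data,
                                                global_shape)

local_bs = 8 * jax.local_device_count()
host_batch = {'inputs': np.zeros((local_bs, 16), np.float32),
              'targets': np.zeros((local_bs,), np.float32)}
global_batch = jax.tree_util.tree_map(
    functools.partial(make_global_array, mesh=mesh), host_batch)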
if metrics is not None: - metrics = metrics.unreplicate().compute() + metrics = metrics.compute() for key, val in metrics.items(): if np.isnan(val): raise utils.TrainingDivergedError('NaN detected in {}'.format(key)) @@ -191,7 +201,8 @@ def _merge_and_apply_prefix(self, d1, d2, prefix): d1[prefix+key] = d2[key] return d1 - def run_eval(self, params, batch_stats, optimizer_state, global_step): + def run_eval( + self, params, batch_stats, optimizer_state, global_step): """Runs the MT models to evals specified by MT model. Args: @@ -230,7 +241,7 @@ def run_eval(self, params, batch_stats, optimizer_state, global_step): self.inference_manager.translate_and_calculate_bleu_single_model( params, split_name)) split_metrics = self._evaluate(params, batch_stats, split_iter, - self.evaluate_batch_pmapped) + self.evaluate_batch_jitted) split_metrics['bleu_score'] = decoding_output.bleu_score metrics = self._merge_and_apply_prefix( diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index e2a0700a..60337f64 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -21,7 +21,6 @@ import time from absl import logging -from flax import jax_utils from init2winit import callbacks from init2winit import checkpoint from init2winit import schedules @@ -31,10 +30,13 @@ from init2winit.trainer_lib import trainer_utils from init2winit.training_metrics_grabber import make_training_metrics import jax +from jax.experimental import mesh_utils import numpy as np import optax import orbax.checkpoint as orbax_checkpoint +NamedSharding = jax.sharding.NamedSharding + class BaseTrainer(metaclass=abc.ABCMeta): """Abstract parent class for all trainers.""" @@ -200,13 +202,20 @@ def __init__( # During eval, we can donate the 'batch' buffer. We don't donate the # 'params' and 'batch_stats' buffers as we don't re-assign those values in # eval, we do that only in train. - self._evaluate_batch_pmapped = jax.pmap( - self._model.evaluate_batch, axis_name='batch', donate_argnums=(2,)) + self._evaluate_batch_jitted = jax.jit( + self._model.evaluate_batch, donate_argnums=(2,)) # Numpy array of range(0, local_device_count) to send to each device to be # folded into the RNG inside each train step to get a unique per-device RNG. self._local_device_indices = np.arange(jax.local_device_count()) + # Creates a 1-d mesh with all devices available globally. + mesh_shape = (jax.device_count(),) + self._mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, devices=jax.devices()), + axis_names=('devices',), + ) + def wait_until_orbax_checkpointer_finished(self): self._orbax_checkpointer.wait_until_finished() @@ -224,7 +233,7 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): data_rng: the jax PRNGKey used for dataset randomness. Should be *different* across hosts! trainer_update_fn: the function for updating the model. If None, this will - skip pmapping the update function. + skip jitting the update function. Returns: A long tuple of the following: @@ -232,10 +241,14 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): optimizer_update_fn: the optax update fn. metrics_update_fn: the optional metrics update fn. metrics_summary_fn: the optional metrics summary fn. - optimizer_state: the replicated optimizer state. - params: the replicated model parameters. - batch_stats: the replicated (optional) model batch statistics. - metrics_state: the replicated metric states. 
+ (optimizer_state_sharding, optimizer_state): the replicated optimizer + state and corresponding sharding annotations. + (params_sharding, params): the replicated model parameters and + corresponding sharding annotations. + (batch_stats_sharding, batch_stats): the replicated (optional) model + batch statistics and corresponding sharding annotations. + (metrics_state_sharding, metrics_state) : the replicated metric states + and corresponding sharding annotations. global_step: the global step to start training at. sum_train_cost: the sum of the train costs. preemption_count: the number of times training has been preempted. @@ -299,10 +312,10 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): unreplicated_batch_stats) ( - optimizer_state, - params, - batch_stats, - metrics_state, + (optimizer_state_sharding, optimizer_state), + (params_sharding, params), + (batch_stats_sharding, batch_stats), + (metrics_state_sharding, metrics_state), global_step, sum_train_cost, preemption_count, @@ -312,6 +325,7 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): unreplicated_params, unreplicated_batch_stats, unreplicated_metrics_state, + self._mesh, train_dir=self._train_dir, external_checkpoint_path=self._external_checkpoint_path, orbax_checkpointer=self._orbax_checkpointer, @@ -367,31 +381,32 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): grad_clip=self._hps.get('grad_clip'), optimizer_update_fn=optimizer_update_fn, metrics_update_fn=metrics_update_fn) - # in_axes = ( - # optimizer_state = 0, - # params = 0, - # batch_stats = 0, - # metrics_state = 0, - # batch = 0, - # step = None, - # lr = None, - # rng = None, - # local_device_index = 0, - # running_train_cost = 0, - # training_cost, - # grad_clip, - # optimizer_update_fn, - # metrics_state_update_fn) - # Also, we can donate buffers for 'optimizer', 'batch_stats', - # 'batch' and 'training_metrics_state' for update's pmapped computation. - update_pmapped = jax.pmap( + + # We donate optimizer_state, params and batch_stats in jitted computation. + # This helps reduce memory usage as outputs corresponding to these inputs + # arguments can re-use the memory. + update_jitted = jax.jit( update_fn, - axis_name='batch', - in_axes=(0, 0, 0, 0, 0, None, None, None, 0, 0), - donate_argnums=(0, 1, 2, 8), + donate_argnums=(0, 1, 2), + in_shardings=( + optimizer_state_sharding, + params_sharding, + batch_stats_sharding, + metrics_state_sharding, + NamedSharding(self._mesh, jax.sharding.PartitionSpec('devices')), + None, None, None, None + ), + out_shardings=( + optimizer_state_sharding, + params_sharding, + batch_stats_sharding, + None, + metrics_state_sharding, + None + ), ) else: - update_pmapped = None + update_jitted = None return ( lr_fn, @@ -399,14 +414,18 @@ def setup_and_maybe_restore(self, init_rng, data_rng, trainer_update_fn): metrics_update_fn, metrics_summary_fn, optimizer_state, + optimizer_state_sharding, params, + params_sharding, batch_stats, + batch_stats_sharding, metrics_state, + metrics_state_sharding, global_step, sum_train_cost, preemption_count, dataset, - update_pmapped) + update_jitted) def _setup_and_maybe_restore( self, init_rng, data_rng, callback_rng, trainer_update_fn): @@ -425,7 +444,7 @@ def _setup_and_maybe_restore( - initializing and maybe restoring self._sum_train_cost. - initializing and maybe restoring self._preemption_count. 
- setting self._dataset - - setting self._update_pmapped + - setting self._update_jitted - setting self._eval_callbacks Args: @@ -436,21 +455,25 @@ def _setup_and_maybe_restore( callback_rng: the jax PRNGKey used for eval callbacks. Should be *different* across hosts! trainer_update_fn: the function for updating the model. If None, this will - skip pmapping the update function. + skip jitting the update function. """ (self._lr_fn, self._optimizer_update_fn, self._metrics_update_fn, self._metrics_summary_fn, self._optimizer_state, + self._optimizer_state_sharding, self._params, + self._params_sharding, self._batch_stats, + self._batch_stats_sharding, self._metrics_state, + self._metrics_state_sharding, self._global_step, self._sum_train_cost, self._preemption_count, self._dataset, - self._update_pmapped) = self.setup_and_maybe_restore( + self._update_jitted) = self.setup_and_maybe_restore( init_rng, data_rng, trainer_update_fn) self._eval_callbacks = self._setup_eval_callbacks(callback_rng) @@ -482,7 +505,7 @@ def _setup_eval_callbacks(self, callback_rng): eval_callback = callbacks.get_callback(config['callback_name'])( self._model, self._params, self._batch_stats, self._optimizer_state, self._optimizer_update_fn, self._dataset, self._hps, config, - self._train_dir, rng) + self._train_dir, rng, self._mesh) eval_callbacks.append(eval_callback) return eval_callbacks @@ -546,9 +569,6 @@ def _eval( """ time_since_last_eval = time.time() - self._time_at_prev_eval_end - self._batch_stats = trainer_utils.maybe_sync_batchnorm_stats( - self._batch_stats - ) if self._eval_use_ema: if isinstance( @@ -574,7 +594,8 @@ def _eval( self._eval_num_batches, self._test_num_batches, self._eval_train_num_batches, - self._evaluate_batch_pmapped) + self._evaluate_batch_jitted, + self._mesh) self._run_eval_callbacks(report) if save: self._save(self._train_dir) @@ -583,10 +604,10 @@ def _eval( run_time = time.time() - self._time_at_prev_eval_end steps_per_sec = steps_since_last_eval / run_time - mean_train_cost = jax.lax.pmean(self._sum_train_cost, axis_name=[])[ - 0 - ].item() / max(1, self._global_step - self._prev_eval_step) - self._sum_train_cost = jax_utils.replicate(0.0) + mean_train_cost = self._sum_train_cost / max( + 1, self._global_step - self._prev_eval_step + ) + self._sum_train_cost = 0.0 epoch = self._global_step * self._hps.batch_size // self._hps.train_size overall_steps_per_sec = self._get_step_frequency( self._global_step, start_step, start_time) diff --git a/init2winit/trainer_lib/trainer.py b/init2winit/trainer_lib/trainer.py index e4ddbff3..65166423 100644 --- a/init2winit/trainer_lib/trainer.py +++ b/init2winit/trainer_lib/trainer.py @@ -14,20 +14,20 @@ # limitations under the License. """Standard trainer for the init2winit project.""" +import functools import itertools import time from absl import logging from init2winit import utils from init2winit.model_lib import model_utils -from init2winit.optimizer_lib import utils as optimizer_utils from init2winit.trainer_lib import base_trainer from init2winit.trainer_lib import trainer_utils import jax -from jax import lax import jax.numpy as jnp import optax + _GRAD_CLIP_EPS = 1e-6 @@ -40,16 +40,14 @@ def update( step, lr, rng, - local_device_index, running_train_cost, training_cost, grad_clip, optimizer_update_fn, - metrics_update_fn, - axis_name='batch'): + metrics_update_fn): """Single step of the training loop. 
- This function will later be pmapped so we keep it outside of the Trainer class + This function will later be jitted so we keep it outside of the Trainer class to avoid the temptation to introduce side-effects. Args: @@ -67,9 +65,6 @@ def update( lr: the floating point learning rate for this step. rng: the RNG used for calling the model. `step` and `local_device_index` will be folded into this to produce a unique per-device, per-step RNG. - local_device_index: an integer that is unique to this device amongst all - devices on this host, usually in the range [0, jax.local_device_count()]. - It is folded in to `rng` to produce a unique per-device, per-step RNG. running_train_cost: the cumulative train cost over some past number of train steps. Reset at evaluation time. training_cost: a function used to calculate the training objective that will @@ -80,7 +75,6 @@ def update( value g / ||g||_2 * grad_clip. If None, then no clipping will be applied. optimizer_update_fn: the optimizer update function. metrics_update_fn: the training metrics update function. - axis_name: axis_name used by pmap. Returns: A tuple of the new optimizer, the new batch stats, the scalar training cost, @@ -89,7 +83,6 @@ def update( # `jax.random.split` is very slow outside the train step, so instead we do a # `jax.random.fold_in` here. rng = jax.random.fold_in(rng, step) - rng = jax.random.fold_in(rng, local_device_index) optimizer_state = trainer_utils.inject_learning_rate(optimizer_state, lr) @@ -99,21 +92,8 @@ def opt_cost(params): grad_fn = jax.value_and_grad(opt_cost, has_aux=True) (cost_value, new_batch_stats), grad = grad_fn(params) - new_batch_stats = new_batch_stats.get('batch_stats', None) - if axis_name is not None: - if optimizer_utils.requires_gradient_aggregation(optimizer_update_fn): - grad = lax.pmean((grad), axis_name=axis_name) - else: - # Skip gradient aggregationas it'll be handled in gradient_accumulator. - if grad_clip: - # Calculating the gradient norm requires cross-device aggregation, - # performed, in this case, inside the optimizer. Calculating it again - # at this point may be inefficient. - raise NotImplementedError( - 'Gradient clipping is not supported when gradient aggregation is' - ' performed internally by the optimizer.' - ) + new_batch_stats = new_batch_stats.get('batch_stats', None) grad_norm = jnp.sqrt(model_utils.l2_regularization(grad, 0)) # TODO(znado): move to inside optax gradient clipping. @@ -131,6 +111,7 @@ def opt_cost(params): cost_fn=opt_cost, grad_fn=grad_fn, value=cost_value) + new_params = optax.apply_updates(params, model_updates) new_metrics_state = None @@ -183,9 +164,6 @@ def train(self): ) - train_iter = trainer_utils.prefetch_input_pipeline( - train_iter, self._hps.num_device_prefetches) - if self._data_selector: train_iter = self._data_selector( train_iter, @@ -207,25 +185,40 @@ def train(self): if self._global_step in self._checkpoint_steps: self._save(self._checkpoint_dir, max_to_keep=None) + make_global_array_fn = functools.partial( + utils.make_global_array, mesh=self._mesh + ) + for _ in range(start_step, self._num_train_steps): - with jax.profiler.StepTraceAnnotation('train', - step_num=self._global_step): + with jax.profiler.StepTraceAnnotation( + 'train', step_num=self._global_step + ): # NOTE(dsuo): to properly profile each step, we must include batch # creation in the StepTraceContext (as opposed to putting `train_iter` # directly in the top-level for loop). 
batch = next(train_iter) + batch = jax.tree_util.tree_map(make_global_array_fn, batch) lr = self._lr_fn(self._global_step) - # It looks like we are reusing an rng key, but we aren't. - # TODO(gdahl): Make it more obvious that passing rng is safe. - # TODO(gdahl,gilmer,znado): investigate possibly merging the member - # variable inputs/outputs of this function into a named tuple. - (self._optimizer_state, self._params, self._batch_stats, - self._sum_train_cost, - self._metrics_state, self._grad_norm) = self._update_pmapped( - self._optimizer_state, self._params, self._batch_stats, - self._metrics_state, batch, self._global_step, lr, rng, - self._local_device_indices, self._sum_train_cost) + + ( + self._optimizer_state, + self._params, + self._batch_stats, + self._sum_train_cost, + self._metrics_state, + self._grad_norm, + ) = self._update_jitted( + self._optimizer_state, + self._params, + self._batch_stats, + self._metrics_state, + batch, + self._global_step, + lr, + rng, + self._sum_train_cost, + ) self._global_step += 1 if self._global_step in self._checkpoint_steps: self._save(self._checkpoint_dir, max_to_keep=None) diff --git a/init2winit/trainer_lib/trainer_utils.py b/init2winit/trainer_lib/trainer_utils.py index ca3afa09..5b55b75b 100644 --- a/init2winit/trainer_lib/trainer_utils.py +++ b/init2winit/trainer_lib/trainer_utils.py @@ -14,14 +14,13 @@ # limitations under the License. """Utility functions related to training.""" + +import functools import time from absl import logging - from flax import jax_utils -from init2winit import utils -from init2winit.dataset_lib import data_utils -from init2winit.model_lib import model_utils +from init2winit.dataset_lib import data_utils as utils import jax import jax.numpy as jnp import numpy as np @@ -89,20 +88,6 @@ def maybe_log_training_metrics(metrics_state, prefix='metrics_state') -def maybe_sync_batchnorm_stats(batch_stats): - """Sync batch_stats across devices.""" - # We first check that batch_stats is used (pmap will throw an error if - # it's a non batch norm model). If batch norm is not used then - # batch_stats = None. Note that, in the case of using our implementation of - # virtual batch norm, this will also handle synchronizing the multiple moving - # averages on each device before doing a cross-host sync. - if batch_stats: - batch_stats = jax.pmap( - model_utils.sync_batchnorm_stats, axis_name='batch')( - batch_stats) - return batch_stats - - def should_eval(global_step, eval_frequency, eval_steps): on_step = eval_steps and global_step in eval_steps on_freq = (global_step % eval_frequency == 0) @@ -141,30 +126,12 @@ def check_for_early_stopping( ) -def prefetch_input_pipeline(ds, n_prefetch=0, devices=None): - """Modify input pipeline to prefetch from host to device. - - Args: - ds: tf.data pipeline - n_prefetch: number of items to prefetch - devices: devices to prefetch to - - Returns: - prefetching ds - - """ - it = iter(ds) - it = (data_utils.shard(x) for x in it) - if n_prefetch > 0: - it = jax_utils.prefetch_to_device(it, n_prefetch, devices=devices) - return it - - def evaluate( params, batch_stats, batch_iter, - evaluate_batch_pmapped): + evaluate_batch_jitted, + mesh): """Compute aggregated metrics on the given data iterator. WARNING: The caller is responsible for synchronizing the batch norm statistics @@ -184,25 +151,24 @@ def evaluate( {'batch_stats': batch_stats} into flax_module.apply(). batch_iter: Generator which yields batches. 
Must support the API for b in batch_iter: - evaluate_batch_pmapped: A function with API - evaluate_batch_pmapped(params, batch_stats, batch). Returns a dictionary - mapping keys to the metric values across the sharded batch. + evaluate_batch_jitted: A function with API evaluate_batch_jitted(params, + batch_stats, batch). Returns a dictionary mapping keys to the metric + values across the sharded batch. + mesh: Mesh specification to use for sharding. Returns: A dictionary of aggregated metrics. The keys will match the keys returned by - evaluate_batch_pmapped. + evaluate_batch_jitted. """ metrics = None + make_global_array_fn = functools.partial(utils.make_global_array, mesh=mesh) + for batch in batch_iter: - batch = data_utils.shard(batch) + batch = jax.tree_util.tree_map(make_global_array_fn, batch) # Returns a clu.metrics.Collection object. We assume that - # `evaluate_batch_pmpapped` calls CLU's `gather_from_model_outputs`, - # which includes an `all_gather` to replicate the values on all devices. - # We need to `unreplicate` before merging the results across batches to - # accommodate CollectingMetric, which concatenates the values across the - # leading dimension, so we need to remove the leading shard dimension first. - computed_metrics = evaluate_batch_pmapped( - params=params, batch_stats=batch_stats, batch=batch).unreplicate() + # `evaluate_batch_jitted` calls CLU's `single_from_model_outputs`. + computed_metrics = evaluate_batch_jitted( + params=params, batch_stats=batch_stats, batch=batch) if metrics is None: metrics = computed_metrics else: @@ -266,7 +232,7 @@ def fetch_learning_rate(optimizer_state): ) if all_equal: lr_array = lrs_with_path[0][1] - return lr_array[0] + return lr_array else: raise ValueError( 'All learning rates in the optimizer state must be the same.' @@ -284,7 +250,7 @@ def _merge_and_apply_prefix(d1, d2, prefix): @utils.timed def eval_metrics(params, batch_stats, dataset, eval_num_batches, test_num_batches, eval_train_num_batches, - evaluate_batch_pmapped): + evaluate_batch_jitted, mesh): """Evaluates the given network on the train, validation, and test sets. WARNING: we assume that `batch_stats` has already been synchronized across @@ -307,7 +273,8 @@ def eval_metrics(params, batch_stats, dataset, eval_num_batches, sets. Set to None to evaluate on the whole test set. eval_train_num_batches: (int) The batch size used for evaluating on train set. Set to None to evaluate on the whole training set. - evaluate_batch_pmapped: Computes the metrics on a sharded batch. + evaluate_batch_jitted: Computes the metrics on a sharded batch. + mesh: Mesh specification to use for sharding. Returns: A dictionary of all computed metrics. @@ -320,7 +287,7 @@ def eval_metrics(params, batch_stats, dataset, eval_num_batches, for split_iter, split_name in zip([train_iter, valid_iter, test_iter], ['train', 'valid', 'test']): split_metrics = evaluate(params, batch_stats, split_iter, - evaluate_batch_pmapped) + evaluate_batch_jitted, mesh) # Metrics are None if the dataset doesn't have that split if split_metrics is not None: metrics = _merge_and_apply_prefix(metrics, split_metrics,