From 5e78770ee58fa8256ed7af1c0bec92833ca09174 Mon Sep 17 00:00:00 2001 From: Michal Januszewski Date: Tue, 3 Sep 2024 06:27:14 -0700 Subject: [PATCH] Move JAX FFN training to the OSS repository. PiperOrigin-RevId: 670532613 --- ffn/jax/input_pipeline.py | 481 +++++++++++++++++++++++++++ ffn/jax/main.py | 43 +++ ffn/jax/tracker.py | 64 ++++ ffn/jax/train.py | 678 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 1266 insertions(+) create mode 100644 ffn/jax/input_pipeline.py create mode 100644 ffn/jax/main.py create mode 100644 ffn/jax/tracker.py create mode 100644 ffn/jax/train.py diff --git a/ffn/jax/input_pipeline.py b/ffn/jax/input_pipeline.py new file mode 100644 index 0000000..4ed99d2 --- /dev/null +++ b/ffn/jax/input_pipeline.py @@ -0,0 +1,481 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Input pipeline for FFN models.""" + +from concurrent import futures +import functools +import threading +from typing import Any, Callable + +from absl import logging +from connectomics.common import utils +from ffn.input import volume +from ffn.jax import tracker +from ffn.training import examples +from ffn.training import inputs +from ffn.training import model as ffn_model +import jax +import jmp +import ml_collections +import numpy as np +import tensorflow as tf + + +# TODO(mjanusz): Check if this is still required. +tf.config.threading.set_inter_op_parallelism_threads(128) + +Dataset = tf.data.Dataset + + +def _check_thread_success(future: futures.Future[Any], err_msg: str = ''): + e = future.exception() + + if e is None: + return + + logging.error(err_msg) + raise e + + +def load_examples( + config: ml_collections.ConfigDict, + rng: jax.Array, + load_shape: volume.TupleXYZ, +) -> tuple[grain.TfMixtureDataLoader | tf.data.Dataset, int]: + """Loads a single training example.""" + + if config.get('train_coords'): + train_coords = config.train_coords + # ConfigDict keys cannot contain dots, so we allow specifying the dictionary + # as an iterable of key-value pairs. + if isinstance(train_coords, tuple) or isinstance(train_coords, list): + train_coords = {k: v for k, v in train_coords} + + sampling = volume.SamplingConfig(vsi_coords=train_coords) + else: + train_coords = config.arrayrec_coords + if isinstance(train_coords, tuple) or isinstance(train_coords, list): + train_coords = {k: v for k, v in train_coords} + + sampling = volume.SamplingConfig(arrayrecord_coords=train_coords) + # The shape of the loaded data has to be symmetric so that if we apply + # random reflections, the center point remains in the center. 
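+  # E.g. an even load_shape of (32, 32, 32) becomes (33, 33, 33), while an
+  # odd shape such as (33, 33, 33) is left unchanged.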
+ effective_load_shape = tuple((np.array(load_shape) // 2 * 2 + 1).tolist()) + + cfg = volume.InputConfig( + sampling=sampling, + volumes={ + 'em': volume.VolumeConfig( + config.em_volumes, + load_shape=effective_load_shape, + filter_shape=effective_load_shape, + ), + # Note that it is actually valid for the patch to extend beyond the + # bounding box of the label volume, so we only check that the center + # voxel is within the labeled area. + 'seg': volume.VolumeConfig( + config.seg_volumes, + load_shape=effective_load_shape, + filter_shape=(1, 1, 1), + ), + 'oob': volume.VolumeConfig( + config.seg_volumes, load_shape=effective_load_shape, oob_mask=True + ), + }, + # TODO(mjanusz): Add support for rotation and simulated missing sections. + augmentation=volume.AugmentationConfig( + permutable_axes=config.permutable_axes, + reflectable_axes=config.reflectable_axes, + contrast_factor_range=config.contrast_factor_range, + brightness_factor_range=config.brightness_factor_range, + apply_adjustment_to=config.apply_adjustment_to, + ), + ) + + if config.loss_mask_volumes: + cfg.volumes['loss_mask'] = volume.VolumeConfig( + config.loss_mask_volumes, + load_shape=effective_load_shape, + default_value=1 if config.loss_mask_invert else 0, + ) + + if hasattr(config, 'loss_mass_relabel'): + cfg.volumes['loss_mask'].relabel_maps = config.loss_mask_relabel + + def _add_ffn_data(ex: volume.Example) -> volume.Example: + weights = tf.cast(ex['oob'], tf.float32) + + if config.loss_mask_volumes: + if config.loss_mask_invert: + loss_mask = tf.equal(ex['loss_mask'], 0) + else: + loss_mask = ex['loss_mask'] > 0 + + weights *= 1.0 - tf.cast(loss_mask, tf.float32) + + seg = ex['seg'] + center_val = seg[ + 0, # + seg.shape[1] // 2, # + seg.shape[2] // 2, # + seg.shape[3] // 2, # + 0, + ] + lom = tf.logical_and(seg > 0, tf.equal(seg, center_val)) + labels = inputs.soften_labels(lom) + + lx, ly, lz = load_shape + emt = tf.cast(ex['em'][:, :lz, :ly, :lx, :], tf.float32) + if config.image_clip_value_max > 0.0: + emt = tf.clip_by_value(emt, 0.0, config.image_clip_value_max) + + return dict( + ex, + weights=weights[:, :lz, :ly, :lx, :], + labels=labels[:, :lz, :ly, :lx, :], + patches=(emt - config.image_mean) / config.image_stddev, + ) + + batch_size = config.per_device_batch_size * jax.local_device_count() + + if cfg.sampling.vsi_coords: + num_examples = getattr(config, 'train_num_coords', 100_000_000) + ds = volume.load_and_augment_subvolumes(cfg, int(np.array(rng)[0])) + ds = ds.map(_add_ffn_data) + + options = tf.data.Options() + options.experimental_optimization.map_fusion = True + options.experimental_optimization.parallel_batch = True + options.experimental_optimization.map_parallelization = True + options.threading.private_threadpool_size = 256 + options.threading.max_intra_op_parallelism = 1 + # Temporary workaround. See b/179292577. + options.experimental_external_state_policy = ( + tf.data.experimental.ExternalStatePolicy.WARN + ) + options.experimental_deterministic = False + ds = ds.with_options(options) + ds = ds.batch(batch_size) + ds = ds.prefetch(config.tf_data_prefetch_size) + else: + ds, num_examples = volume.grain_load_and_augment_subvolumes( + cfg, np.array(rng), _add_ffn_data, batch_size + ) + return ds, num_examples + + +def create_dataset( + config: ml_collections.ConfigDict, + seed: jax.Array, + load_shape: volume.TupleXYZ, + data_service_address: str | None = None, +) -> tuple[Dataset, int]: + """Creates a dataset for training. + + Args: + config: Configuration to use. 
+ seed: PRNGKey for seeding operations in the training dataset. + load_shape: XYZ shape of the data patch to load from volumestore. + data_service_address: Unsupported. + + Returns: + Training dataset and the total number of examples. + """ + if data_service_address is not None: + raise NotImplementedError( + 'Support for tf.data service not implemented yet.' + ) + + ds, num_total_examples = load_examples(config, seed, load_shape) + return ds, num_total_examples + + +class BatchDictExampleIter(examples.BatchExampleIter): + """Replaces tuples with dicts.""" + + def __next__(self): + seeds, patches, labels, weights = super().__next__() + return {'seed': seeds, 'label': labels, 'patch': patches, 'weight': weights} + + +class MixingBatchExampleIter(BatchDictExampleIter): + """Like BatchDictExampleIter but with more examples in parallel. + + The total number of examples held in memory at a time is given by + num_batches * batch_size. A full batch is randomly selected out + of these examples at every training stap. This reduces correlations + between training batches in consecutive steps, and makes it possible + to prefetch data in the background. + """ + + # pylint: disable=super-init-not-called + def __init__( + self, + example_generator_fn: Callable[[], examples.ExampleGenerator], + eval_tracker: tracker.EvalTracker, + batch_size: int, + num_batches: int, + model_info: ffn_model.ModelInfo, + batch_prefetch: int = 16, + jmp_policy: jmp.Policy | None = None, + ): + """Constructor. + + Args: + example_generator_fn: function returning a generator of single training + examples + eval_tracker: FFN eval tracker object + batch_size: number of examples per batch + num_batches: number of batches to hold in memory + model_info: FFN model info + batch_prefetch: number of batches to prefetch + jmp_policy: Optional Jax policy for mixed precision training + """ + assert num_batches > 1 + self._eval_tracker = eval_tracker + self._seeds: list[np.ndarray] = [] + # List of indices of self._generators that generated the current + # batch. + self._current_idx: list[int] = [] + self._batch_size = batch_size + self._info = model_info + self._jmp_policy = jmp_policy + + # Loading of individual training examples. + self._generators = [ + example_generator_fn() for _ in range(batch_size * num_batches) + ] + self._tpe = futures.ThreadPoolExecutor( + max_workers=batch_size * batch_prefetch + ) + self._fs_lock = threading.Lock() + self._fs = set() + for i, gen in enumerate(self._generators): + self._fs.add(self._tpe.submit(lambda gen=gen, i=i: (i, next(gen)))) + + # Prefetching of complete batches. 
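+    # Each future in self._batch_fs holds one ready-to-consume batch;
+    # __next__ pops a completed future and immediately schedules a
+    # replacement via _generate_batch.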
+ self._batch_tpe = futures.ThreadPoolExecutor(max_workers=batch_prefetch) + self._batch_fs = set() + for i in range(batch_prefetch): + self._batch_fs.add(self._batch_tpe.submit(self._generate_batch)) + + self._seed_update_tpe = futures.ThreadPoolExecutor(max_workers=4) + + def _generate_batch(self): + """Returns a batch of training examples.""" + seeds, patches, labels, weights, batch_ids = [], [], [], [], [] + + while len(batch_ids) < self._batch_size: + with self._fs_lock: + for f in futures.as_completed(self._fs): + self._fs.remove(f) + i, (seed, patch, label, weight) = f.result() + seeds.append(seed) + patches.append(patch) + labels.append(label) + weights.append(weight) + batch_ids.append(i) + + if len(batch_ids) == self._batch_size: + break + + batched_seeds = np.concatenate(seeds) + batched_weights = np.concatenate(weights) + batched_labels = np.concatenate(labels) + batched_patches = np.concatenate(patches) + + if self._jmp_policy is not None: + batched_patches = self._jmp_policy.cast_to_compute(batched_patches) + batched_seeds = self._jmp_policy.cast_to_compute(batched_seeds) + + return ( + batch_ids, + seeds, + batched_seeds, + batched_weights, + batched_labels, + batched_patches, + ) + + def __next__(self): + # The time reported here indicates how long the training script had to + # wait to get a new batch of examples. It should ideally be ~0. + with utils.report_time('MixingBatchExampleIter'): + f = next(futures.as_completed(self._batch_fs)) + + self._batch_fs.remove(f) + self._batch_fs.add(self._batch_tpe.submit(self._generate_batch)) + + ( + self._current_idx, + self._seeds, + batched_seeds, + batched_weights, + batched_labels, + batched_patches, + ) = f.result() + self._eval_tracker.track_weights(batched_weights) + return { + 'seed': batched_seeds, + 'label': batched_labels, + 'patch': batched_patches, + 'weight': batched_weights, + } + + def update_seeds(self, batched_seeds: np.ndarray | jax.Array): + """Propagates data from `batched_seeds` back to the example generators.""" + + def _update( + seeds: list[np.ndarray], + batched_seeds: np.ndarray | jax.Array, + current: list[int], + ): + # Transfer data from device to host if using a JAX array. + batched_seeds = np.array(batched_seeds) + # Fold batch dimensions back to a single one. + batched_seeds = np.reshape( + batched_seeds, [-1] + list(batched_seeds.shape[-4:]) + ) + + dx = self._info.input_seed_size[0] - self._info.pred_mask_size[0] + dy = self._info.input_seed_size[1] - self._info.pred_mask_size[1] + dz = self._info.input_seed_size[2] - self._info.pred_mask_size[2] + + for i, _ in enumerate(current): + if dz == 0 and dy == 0 and dx == 0: + seeds[i][:] = batched_seeds[i, ...] + else: + seeds[i][ + :, # + dz // 2 : -(dz - dz // 2), # + dy // 2 : -(dy - dy // 2), # + dx // 2 : -(dx - dx // 2), # + :, + ] = batched_seeds[i, ...] + + with self._fs_lock: + for gen_idx in current: + gen = self._generators[gen_idx] + self._fs.add( + self._tpe.submit(lambda gen=gen, i=gen_idx: (i, next(gen))) + ) + + # Distribute data asynchronously. + update_future = self._seed_update_tpe.submit( + _update, self._seeds, batched_seeds, self._current_idx + ) + update_future.add_done_callback( + functools.partial( + _check_thread_success, err_msg='Error while updating seeds.' + ) + ) + + +class UnbatchIter: + """Fetches batches from a tf.data iterator and returns elements one by one. + + Iterating over tf.data appears to incur some overhead, so it's faster to + pull complete batches and unpack them here. 
+ + The input arrays expected to be shaped [b, z, y, x, c], while the output + arrays are going to have shape [z, y, x, c]. + """ + + def __init__(self, batch_iter: tf.data.Iterator): + self._batch_iter = batch_iter + self._batch = None + self._idx = 0 + self._lock = threading.Lock() + + def __iter__(self): + return self + + def __next__(self): + with self._lock: + if self._batch is None: + self._idx = 0 + + # The time reported here will reflect delays caused by the tf.data + # pipeline. This does not impact training speed as long as + # MixingBatchExampleIter time is close to 0. + with utils.report_time('tf_data_input'): + ex = next(self._batch_iter) + + # Convert from EagerTensor to numpy. + self._batch = ( + np.array(ex['patches']), + np.array(ex['labels']), + np.array(ex['weights']), + np.array(ex['coord']), + np.array(ex['volname']), + ) + + ret = [x[self._idx] for x in self._batch] + + self._idx += 1 + if self._batch[0].shape[0] == self._idx: + self._batch = None + + return ret + + +def get_batch_iter( + data_iter: tf.data.Iterator, + eval_tracker: tracker.EvalTracker, + policy_fn: examples.GetOffsets, + model_info: ffn_model.ModelInfo, + config: ml_collections.ConfigDict, # + seed_shape: tuple[int, int, int], + batch_size: int, + jmp_policy: jmp.Policy | None = None, +) -> BatchDictExampleIter: + """Creates an iterator over batches of training examples.""" + + # Pull single examples from the TF DS iterator. + unbatched_iter = UnbatchIter(data_iter) + + def _load_example(): + return next(unbatched_iter) + + # Pull training examples from the source (_load_example) and generate + # FFN training examples (when FOV movements are made, there will be more + # than 1 training example per data item loaded from the input). + def _make_example(): + return examples.get_example( + _load_example, + eval_tracker, + model_info, + policy_fn, + config.seed_pad, + seed_shape=seed_shape, + ) + + # Instantiate multiple generators (_make_examples), and batch their outputs + # as they become available. + if config.mix_num_batches == 1: + return BatchDictExampleIter( + _make_example, eval_tracker, batch_size, model_info + ) + else: + return MixingBatchExampleIter( + _make_example, + eval_tracker, + batch_size, + config.mix_num_batches, + model_info, + config.host_num_batch_prefetch, + jmp_policy=jmp_policy, + ) diff --git a/ffn/jax/main.py b/ffn/jax/main.py new file mode 100644 index 0000000..6e3a798 --- /dev/null +++ b/ffn/jax/main.py @@ -0,0 +1,43 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Main file for FFN models.""" + +from typing import Sequence + +from absl import app +from absl import flags + +from connectomics.jax import training +from ffn.jax import train +import jax + +FLAGS = flags.FLAGS + +training.define_training_flags() + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + training.prep_training() + train.train_and_evaluate(FLAGS.config, FLAGS.workdir, FLAGS.service_address) + + +if __name__ == '__main__': + # Provide access to --jax_backend_target and --jax_xla_backend flags. + jax.config.config_with_absl() + app.run(main) diff --git a/ffn/jax/tracker.py b/ffn/jax/tracker.py new file mode 100644 index 0000000..f4fc73c --- /dev/null +++ b/ffn/jax/tracker.py @@ -0,0 +1,64 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Pure numpy adaptation of the FFN tracker. + +This makes the training script independent from TF, other than for the +input pipeline. +""" + +from ffn.training import tracker +import numpy as np + + +class Variable: + """Variable keeping its value as a numpy array.""" + + def __init__(self, shape, dtype): + self._value = np.zeros(shape, dtype=dtype.as_numpy_dtype) + + @property + def tf_value(self): + return self._value + + @property + def from_tf(self): + return None + + @property + def value(self): + return self._value + + def to_tf(self, ops, feed_dict): + pass + + def reset(self): + self._value[:] = 0. + + +class EvalTracker(tracker.EvalTracker): + """Eval tracker using numpy variables.""" + + def _add_tf_var(self, name, shape, dtype): + v = Variable(shape, dtype) + setattr(self, name, v) + self._tf_vars.append(v) + return v + + def to_tf(self): + pass + + def from_tf(self): + pass diff --git a/ffn/jax/train.py b/ffn/jax/train.py new file mode 100644 index 0000000..29ee016 --- /dev/null +++ b/ffn/jax/train.py @@ -0,0 +1,678 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Training script for FFN models.""" + +import collections +import functools as ft +import os +import random +import time +from typing import Any, Sequence, TypeVar + +from absl import logging +from clu import metric_writers +from clu import metrics +from clu import parameter_overview +from connectomics.jax import training +from etils import epath +from ffn.jax import input_pipeline +from ffn.jax import tracker +from ffn.training import examples +from ffn.training import model as ffn_model +import flax +import flax.jax_utils as flax_utils +import flax.linen as nn +from flax.training import checkpoints as flax_checkpoints +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P +import jmp +import ml_collections +import numpy as np +import optax +import orbax.checkpoint as ocp +from scipy import special +from t5x.checkpoints import DatasetArgs +from t5x.checkpoints import DatasetCheckpointHandler +import tensorflow as tf + +from connectomics.jax.models import util as model_util + + +class TrainState(flax.struct.PyTreeNode): # pytype: disable=invalid-function-definition # dataclass_transform + step: int + opt_state: optax.OptState + params: flax.core.FrozenDict[str, Any] + batch_stats: Any + + +DataIterator = TypeVar( + 'DataIterator', tf.data.Iterator # +) + + +def create_train_state( + config: ml_collections.ConfigDict, + rng: jax.Array, + input_shape: Sequence[int], +) -> tuple[nn.Module, optax.Schedule, optax.GradientTransformation, TrainState]: + """Instantiates and initializes the model. + + Args: + config: Configuration for model. + rng: JAX PRNG Key. + input_shape: Shape of the inputs fed into the model. + + Returns: + The initialized TrainState with the optimizer. + """ + model = model_util.model_from_config(config) + rng = {'params': rng, 'dropout': jax.random.PRNGKey(1)} + variables = model.init(rng, jnp.ones(input_shape)) + params = variables['params'] + + parameter_overview.log_parameter_overview(params) + tx, lr = training.get_optimizer(config) + + return ( + model, + lr, + tx, + TrainState( + step=0, + opt_state=tx.init(params), + batch_stats=variables.get('batch_stats', None), + params=params, + ), + ) + + +@flax.struct.dataclass +class TrainMetrics(metrics.Collection): + loss: metrics.Average.from_output('loss') + loss_std: metrics.Std.from_output('loss') + learning_rate: metrics.LastValue.from_output('learning_rate') + + +def _updated_seed(seed: jnp.ndarray, update: jnp.ndarray) -> jnp.ndarray: + """Applies the additive `update` to `seed`.""" + dz = seed.shape[-4] - update.shape[-4] + dy = seed.shape[-3] - update.shape[-3] + dx = seed.shape[-2] - update.shape[-2] + + logging.log_first_n( + logging.INFO, 'Updating seed: %r with: %r', 1, seed.shape, update.shape + ) + + if dz == 0 and dy == 0 and dx == 0: + return seed + update + else: + raise ValueError( + 'Currently only models with same input and output shapes are supported.' 
+ ) + + return seed + jnp.pad( + update, + [ + [0, 0], # + [dz // 2, dz - dz // 2], # + [dy // 2, dy - dy // 2], + [dx // 2, dx - dx // 2], + [0, 0], + ], + ) + + +def train_step( + model: nn.Module, + state: TrainState, + schedule: optax.Schedule, + optimizer: optax.GradientTransformation, + batch: dict[str, jax.Array], # + config: ml_collections.ConfigDict, + dropout_rng: jax.Array, + jmp_policy: jmp.Policy | None = None, + loss_scale: jmp.LossScale = jmp.NoOpLossScale(), +) -> tuple[TrainState, metrics.Collection, jax.Array, jmp.LossScale | None]: + """Performs a single training step. + + Args: + model: Module to compute predictions. + state: Current training state. Updated training state will be returned. + schedule: optax learning rate schedule. + optimizer: optax optimizer. + batch: Training inputs for this step. + config: Configuration for model. + dropout_rng: RNG key for dropout. + jmp_policy: Jax policy for mixed precision training. + loss_scale: Loss scaling policy. + + Returns: + tuple of: updated state, dictionary with metrics, updated part of the + seed, updated loss scaling policy. + """ + step = state.step + 1 + dropout_rng = jax.random.fold_in(dropout_rng, step) + + def loss_fn(params): + variables = {'params': params} + if state.batch_stats is not None: + variables['batch_stats'] = state.batch_stats + + data = jnp.concatenate((batch['patch'], batch['seed']), axis=-1) + kwargs = {} + if 'transformer' in config.model_class or 'mixer' in config.model_class: + kwargs['train'] = True + + logits, new_variables = model.apply( + variables, data, mutable=True, rngs={'dropout': dropout_rng}, **kwargs + ) + + if config.additive_seed_update: + logits = _updated_seed(batch['seed'], logits) + + loss = optax.sigmoid_binary_cross_entropy(logits, batch['label']) + # NOTE: When using float16s, overflows occur so we anyways use float32 here. + loss = jnp.mean(loss * batch['weight'], dtype=jnp.float32) + + loss = loss_scale.scale(loss) + return loss, (new_variables.get('batch_stats', None), logits) + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + + params_copy = state.params + if jmp_policy is not None: + params_copy = jmp_policy.cast_to_compute(params_copy) + + (loss, (new_batch_stats, logits)), grad = grad_fn(params_copy) + + # Compute average gradient across multiple workers. + if jmp_policy is not None: + grad = jmp_policy.cast_to_param(grad) + + grad = loss_scale.unscale(grad) + updates, new_opt_state = optimizer.update(grad, state.opt_state, state.params) + new_params = optax.apply_updates(state.params, updates) + + # Dynamic loss scaling needs adjustment in order to actually be dynamic. + if config.skip_nonfinite_updates or config.dynamic_loss_scale: + grads_finite = jmp.all_finite(grad) + loss_scale = loss_scale.adjust(grads_finite) + + new_params, new_opt_state = jmp.select_tree( + grads_finite, + (new_params, new_opt_state), + (state.params, state.opt_state), + ) + + new_state = state.replace( # pytype: disable=attribute-error + step=step, + params=new_params, + opt_state=new_opt_state, + batch_stats=new_batch_stats, + ) + + lr = schedule(state.opt_state.count) # pytype: disable=attribute-error + metrics_update = TrainMetrics.single_from_model_output( + loss=loss, learning_rate=lr + ) + + return new_state, metrics_update, logits, loss_scale + + +def fov_moves(config: ml_collections.ConfigDict) -> int: + if config.fov_policy == 'max_pred_moves': + # Add one more move to get a better fill of the evaluation area. 
+ return config.fov_moves + 1 + else: + return config.fov_moves + + +def train_image_size( + info: ffn_model.ModelInfo, config: ml_collections.ConfigDict +) -> np.ndarray: + return np.array(info.input_image_size) + np.array( + info.deltas + ) * 2 * fov_moves(config) + + +def train_canvas_size( + info: ffn_model.ModelInfo, config: ml_collections.ConfigDict +) -> np.ndarray: + return np.array(info.input_seed_size) + np.array(info.deltas) * 2 * fov_moves( + config + ) + + +def train_eval_size( + info: ffn_model.ModelInfo, config: ml_collections.ConfigDict +) -> np.ndarray: + return np.array(info.pred_mask_size) + np.array(info.deltas) * 2 * fov_moves( + config + ) + + +def build_shifts( + config: ml_collections.ConfigDict, +) -> list[tuple[int, int, int]]: + """Builds a sequence of FOV shifts for the network.""" + shifts = [] + d = config.deltas + m = config.fov_moves + for dx in range(-m * d[0], m * d[0] + 1, max(d[0], 1)): + for dy in range(-m * d[1], m * d[1] + 1, max(d[1], 1)): + for dz in range(-m * d[2], m * d[2] + 1, max(d[2], 1)): + if dx == 0 and dy == 0 and dz == 0: + continue + shifts.append((dx, dy, dz)) + + if config.shuffle_fov_moves: + move_by_r = collections.defaultdict(list) + for x, y, z in shifts: + r = abs(x) + abs(y) + abs(z) + move_by_r[r].append((x, y, z)) + + # For multi-step moves, it is important to ensure that the + # locations closer to the center of the seed are covered + # before more distant ones.. + shifts = [] + for r, moves in sorted(move_by_r.items()): + random.shuffle(moves) + shifts.extend(moves) + + return shifts + + +def get_policy( + fov_shifts: list[tuple[int, int, int]], + info: ffn_model.ModelInfo, + config: ml_collections.ConfigDict, +) -> examples.GetOffsets: + """Returns a FOV movement policy function.""" + train_image_radius = train_image_size(info, config) // 2 + input_image_radius = np.array(info.input_image_size) // 2 + policy_map = { + 'fixed': ft.partial( + examples.fixed_offsets, + fov_shifts=fov_shifts, + threshold=special.logit(config.threshold), + ), + 'fixed_window': ft.partial( + examples.fixed_offsets_window, + fov_shifts=fov_shifts, + threshold=special.logit(config.threshold), + radius=8, + ), + 'max_pred_moves': ft.partial( + examples.max_pred_offsets, + max_radius=train_image_radius - input_image_radius, + threshold=special.logit(config.threshold), + ), + 'no_step': examples.no_offsets, + } + return policy_map[config.fov_policy] + + +def _get_tf_writer(writers) -> metric_writers.SummaryWriter | None: + # pylint:disable=protected-access + for writer in writers: + assert isinstance(writer, metric_writers.AsyncWriter) + if isinstance(writer._writer, metric_writers.SummaryWriter): + return writer._writer + # pylint:enable=protected-access + + +def _get_ocp_args(train_iter: DataIterator) -> DataIterator: + if isinstance(train_iter, tf.data.Iterator): + return DatasetArgs(train_iter) + + +def _make_ckpt_args(state, train_iter: DataIterator) -> ocp.args.CheckpointArgs: + return ocp.args.Composite( + train_state=ocp.args.StandardSave(state), + train_iter=_get_ocp_args(train_iter), + ) + + +def train_and_evaluate( + config: ml_collections.ConfigDict, + workdir: str, + data_service_address: str | None = None, +): + """Main training loop.""" + workdir = epath.Path(workdir) + workdir.mkdir(parents=True, exist_ok=True) + + rng = training.get_rng(config.seed) + + info = ffn_model.ModelInfo( + deltas=config.deltas, + pred_mask_size=config.fov_size, + input_seed_size=config.fov_size, + input_image_size=config.fov_size, + ) + + # Set up FFN FOV 
movement. + fov_shifts = build_shifts(config) + policy_fn = get_policy(fov_shifts, info, config) + + # Build input pipeline. + rng, data_rng = jax.random.split(rng) + data_seed = int( + jax.random.randint(data_rng, [], minval=0, maxval=np.iinfo(np.int32).max) + ) + random.seed(data_seed) + + train_ds, num_total_examples = input_pipeline.create_dataset( + config, + data_rng, + load_shape=tuple(train_image_size(info, config)), + data_service_address=data_service_address, + ) + train_iter = iter(train_ds) # pytype: disable=wrong-arg-types + + logging.info('train_elem_shape=%r', train_iter.element_spec['em'].shape) # pytype:disable=attribute-error + + # batch, z, y, x, (image, seed) + input_shape = [1] + np.array(info.input_image_size).tolist()[::-1] + [2] + + # Initialize model. + rng, model_rng = jax.random.split(rng) + model, schedule, optimizer, state = create_train_state( + config, model_rng, input_shape=input_shape + ) + rng, dropout_rng = jax.random.split(rng) + + item_handlers = {} + if isinstance(train_iter, tf.data.Iterator): + item_handlers = {'train_iter': DatasetCheckpointHandler('ckpt', True)} + + # Checkpointing init. + checkpoint_dir = epath.Path(workdir) / 'checkpoints' + checkpoint_manager = ocp.CheckpointManager( + checkpoint_dir, + item_names=('train_state', 'train_iter'), + item_handlers=item_handlers, + options=ocp.CheckpointManagerOptions( + save_interval_steps=config.checkpoint_every_steps + ), + ) + checkpointed_state = {'train_state': state, 'train_iter': train_iter} + latest_step = checkpoint_manager.latest_step() + # If an initial checkpoint is provided and the checkpointing library does not + # report a 'latest' checkpoint, then we are starting a new experiment. + # Otherwise an existing experiment is being resumed (e.g. after the training + # task being preempted) and the latest checkpoint should take precedence. + if config.init_from_cpoint and latest_step is None: + handler = ocp.StandardCheckpointHandler() + train_state_path = epath.Path(config.init_from_cpoint) / 'train_state' + train_iter_path = epath.Path(config.init_from_cpoint) / 'train_iter' + + if isinstance(train_iter, tf.data.Iterator): + iter_handler = item_handlers['train_iter'] + args = DatasetArgs(train_iter) + + checkpointed_state['train_state'] = handler.restore( + train_state_path, args=ocp.args.StandardRestore(state) + ) + checkpointed_state['train_iter'] = iter_handler.restore( + train_iter_path, args + ) + logging.info('Initializing training from %r', config.init_from_cpoint) + elif latest_step is not None: + checkpointed_state = checkpoint_manager.restore( + latest_step, + args=ocp.args.Composite( + train_state=ocp.args.StandardRestore(state), + train_iter=_get_ocp_args(train_iter), + ), + ) + logging.info('Restored checkpoint for step %d', latest_step) + + if latest_step is None: + logging.info('Starting training from scratch.') + # Save input config to CNS in addition to XM. + if jax.process_index() == 0: + with tf.io.gfile.GFile( + tf.io.gfile.join(workdir, 'config.json'), 'w' + ) as f: + f.write(config.to_json_best_effort() + '\n') + + # Data partitioning, if recovered from checkpoint, can be incompatible + # with the current setup. Avoid the problem by moving the state to the + # host. 
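+  # jax.tree.map(np.array, ...) converts every leaf of the restored train
+  # state into a host-resident numpy array, leaving placement/sharding to be
+  # re-established for the current device mesh on the first jitted step.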
+ state = jax.tree.map(np.array, checkpointed_state['train_state']) + train_iter = checkpointed_state['train_iter'] + initial_step = int(state.step) + 1 + + global_batch_size = config.per_device_batch_size * jax.device_count() + host_batch_size = config.per_device_batch_size * jax.local_device_count() + + # Upper bound. The real number will be lower as not all steps are + # taken for every example. + steps_per_epoch = ( + num_total_examples // global_batch_size * (len(fov_shifts) + 1) + ) + num_train_steps = steps_per_epoch * config.num_epochs + logging.info( + 'num_train_steps=%d, steps_per_epoch=%d', num_train_steps, steps_per_epoch + ) + + # Mixed precision settings. + jmp_policy = jmp.get_policy(config.mp_policy) if config.mp_policy else None + loss_scale = jmp.NoOpLossScale() + if config.loss_scale > 0: + if config.dynamic_loss_scale: + loss_scale = flax_utils.replicate( + jmp.DynamicLossScale(jnp.asarray(float(config.loss_scale))) + ) + else: + loss_scale = jmp.StaticLossScale(config.loss_scale) + + # Shard batch across devices. + mesh = Mesh(np.array(jax.devices()), ('batch',)) + batch_sharding = NamedSharding(mesh, P('batch')) + replicate_sharding = NamedSharding(mesh, P()) + logging.info('Device mesh: %r', mesh) + + def train_fn(state, batch, loss_scale): + return train_step( + model=model, + config=config, + schedule=schedule, + optimizer=optimizer, + jmp_policy=jmp_policy, + loss_scale=loss_scale, + dropout_rng=dropout_rng, + batch=batch, + state=state, + ) + + shard_in = ( + replicate_sharding, # state + batch_sharding, # data + replicate_sharding, # loss scale + ) + shard_out = ( + replicate_sharding, # state + replicate_sharding, # metrics + replicate_sharding, # logits + replicate_sharding, # loss scale + ) + p_train_step = jax.jit(train_fn, shard_in, shard_out) + + # Initialize summary writer. + writer = metric_writers.create_default_writer( + workdir, just_logging=jax.process_index() > 0 + ) + if initial_step == 1: + writer.write_hparams({ + k: v + for k, v in config.items() + if isinstance(v, (bool, float, int, str)) + }) + + logging.info('Starting training loop at step %d.', initial_step) + hooks = [] + report_progress = training.ReportProgress( + global_batch_size, num_train_steps=num_train_steps, writer=writer + ) + if jax.process_index() == 0: + hooks.append(report_progress) + + eval_shape_zyx = train_eval_size(info, config).tolist()[::-1] + eval_tracker = tracker.EvalTracker(eval_shape_zyx, fov_shifts) + + batch_iter = input_pipeline.get_batch_iter( + train_iter, + eval_tracker, + policy_fn, + info, + config, + seed_shape=tuple(train_canvas_size(info, config).tolist()[::-1]), + batch_size=host_batch_size, + jmp_policy=jmp_policy, + ) + + train_metrics = None + shutdown_request = False + timings = collections.defaultdict(list) + + def postprocess_batch(batch): + # Unpack batch dim into (device, batch). 
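+    # The host-local batch [b, z, y, x, c] is split evenly across the local
+    # devices and reassembled into a single globally sharded jax.Array.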
+ def _reshape(x): + x = np.asarray(x) + per_device_data = np.split(x, len(mesh.local_devices), axis=0) + + on_dev = jax.device_put(per_device_data, mesh.local_devices) + global_shape = ( + len(batch_sharding.device_set) * config.per_device_batch_size, + ) + per_device_data[0].shape[1:] + return jax.make_array_from_single_device_arrays( + global_shape, batch_sharding, on_dev + ) + + return jax.tree.map( + _reshape, + { + 'patch': batch['patch'], + 'seed': batch['seed'], + 'label': batch['label'], + 'weight': batch['weight'], + }, + ) + + with metric_writers.ensure_flushes(writer): + # Record a summary scalar to indicate the specific steps at which + # training restarts occurred. + writer.write_scalars(initial_step, {'start': 1}) + + for step in range(initial_step, num_train_steps + 1): + is_last_step = step == num_train_steps + + with jax.profiler.StepTraceAnnotation('train', step_num=step): + with report_progress.timed('input', wait_jax_async_dispatch=False): + with training.MeasureTime(timings, 'data_load'): + batch = next(batch_iter) + + batch = postprocess_batch(batch) + + with training.MeasureTime(timings, 'train_step'): + state, metrics_update, updated_seed, loss_scale = p_train_step( + state, batch, loss_scale + ) + + logging.log_first_n( + logging.INFO, 'Updated seed shape: %r', 1, updated_seed.shape + ) + + with training.MeasureTime(timings, 'metrics'): + with jax.spmd_mode('allow_all'): + train_metrics = ( + metrics_update + if train_metrics is None + else train_metrics.merge(metrics_update) + ) + + with training.MeasureTime(timings, 'update_seed'): + batch_iter.update_seeds(updated_seed) # pytype: disable=wrong-arg-types # jnp-type + + with training.MeasureTime(timings, 'admin'): + if checkpoint_manager.should_save(step) or is_last_step: + logging.info('Saving checkpoint at %d.', step) + train_state = jax.tree.map(np.array, state) + checkpoint_manager.save( + step, args=_make_ckpt_args(train_state, train_iter) + ) + + if checkpoint_manager.reached_preemption(step): + logging.warn('Interrupting training loop due to shutdown request.') + logging.flush() + shutdown_request = True + break + + for h in hooks: + h(step) + + if step % config.log_loss_every_steps == 0 or is_last_step: + with jax.spmd_mode('allow_all'): + scalars = train_metrics.compute() + for name, values in timings.items(): + scalars[f'time_{name}'] = float(np.mean(values)) + scalars[f'time_{name}/min'] = float(np.min(values)) + scalars[f'time_{name}/max'] = float(np.max(values)) + + timings = collections.defaultdict(list) + raws = [] + for summ in eval_tracker.get_summaries(): + if summ.HasField('simple_value'): + scalars[summ.tag] = summ.simple_value + else: + s = tf.compat.v1.summary.Summary() + s.value.append(summ) + raws.append(s.SerializeToString()) + + writer.write_scalars(step, scalars) + if jax.process_index() == 0: + # pylint:disable=protected-access + tfw = _get_tf_writer(writer._writers) + assert tfw is not None + # TODO(mjanusz): Find a cleaner and less brittle way of saving + # raw summaries. + with tfw._summary_writer.as_default(): + for s in raws: + tf.summary.experimental.write_raw_pb(s, step=step) + + # pylint:enable=protected-access + + train_metrics = None + eval_tracker.reset() + + checkpoint_manager.wait_until_finished() + logging.info('Finished training at step %d.', step) + + if shutdown_request: + # Allow time for other workers to finish checkpoint saving. 
Soon after + # the first worker is terminated, it will be detected that the clique + # is no longer complete, which will cause an immediate restart of the + # current process via std::quick_exit(42). + time.sleep(60) + + # This return code causes Borglet to restart the binary without changing + # the state of the task as seen by the Borgmaster. + os._exit(42) # pylint:disable=protected-access
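
Note on configuration: `train_and_evaluate` and the input pipeline read their settings from the `ml_collections.ConfigDict` passed in via `FLAGS.config` (see `training.define_training_flags()` in main.py); the patch itself does not ship a config file. The sketch below only illustrates how the fields referenced in `train.py` and `input_pipeline.py` fit together. The field names are taken from the code above, while every concrete value and placeholder is an illustrative assumption, not a shipped default.

# Hypothetical config sketch; field names from train.py / input_pipeline.py,
# all values are illustrative assumptions only.
import ml_collections


def get_config() -> ml_collections.ConfigDict:
  config = ml_collections.ConfigDict()

  # Model and FOV geometry (XYZ).
  config.model_class = 'convstack_3d'  # assumed; resolved via model_util.model_from_config
  config.fov_size = (33, 33, 33)
  config.deltas = (8, 8, 8)
  config.fov_moves = 1
  config.fov_policy = 'fixed'  # fixed | fixed_window | max_pred_moves | no_step
  config.shuffle_fov_moves = False
  config.threshold = 0.9
  config.seed_pad = 0.05  # assumed; value forwarded to examples.get_example
  config.additive_seed_update = False

  # Input volumes and augmentation (placeholders, not real specs).
  config.em_volumes = 'PLACEHOLDER_EM_VOLUME_SPEC'
  config.seg_volumes = 'PLACEHOLDER_SEG_VOLUME_SPEC'
  config.train_coords = (('PLACEHOLDER_VOLUME', 'PLACEHOLDER_COORD_SOURCE'),)
  config.loss_mask_volumes = ''
  config.loss_mask_invert = False
  config.image_mean = 128.0
  config.image_stddev = 33.0
  config.image_clip_value_max = 0.0
  config.permutable_axes = (0, 1)
  config.reflectable_axes = (0, 1, 2)
  config.contrast_factor_range = None
  config.brightness_factor_range = None
  config.apply_adjustment_to = ('em',)  # assumed
  config.tf_data_prefetch_size = 4

  # Batching and training loop.
  config.seed = 42
  config.per_device_batch_size = 8
  config.mix_num_batches = 8
  config.host_num_batch_prefetch = 16
  config.num_epochs = 1
  config.checkpoint_every_steps = 1000
  config.log_loss_every_steps = 100
  config.init_from_cpoint = ''

  # Mixed precision; empty policy string disables it (see jmp.get_policy).
  config.mp_policy = ''
  config.loss_scale = 0
  config.dynamic_loss_scale = False
  config.skip_nonfinite_updates = False

  return config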