upgrading init2winit from pmap to jit
PiperOrigin-RevId: 673695362
sourabh2k15 authored and copybara-github committed Dec 6, 2024
1 parent d008bcf commit a9f1966
Showing 15 changed files with 256 additions and 339 deletions.
4 changes: 2 additions & 2 deletions init2winit/base_callback.py
@@ -20,7 +20,7 @@
callback_builder = callbacks.get_callback(config['callback_name'])
callback = callback_builder(model, params, batch_stats, optimizer_state,
dataset, hps, config, train_dir, rng)
dataset, hps, config, train_dir, rng, mesh)
callback_metrics = callback.run_eval(params, batch_stats,
optimizer_state, global_step).
@@ -39,7 +39,7 @@ class BaseCallBack:

def __init__(self, model, params, batch_stats, optimizer_state,
optimizer_update_fn, dataset, hps, callback_config, train_dir,
rng):
rng, mesh):
"""Defines the API for callback construction."""
pass

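Note on the new mesh argument: callbacks now receive the jax.sharding.Mesh used for jit-based sharding. Below is a minimal sketch, not part of this commit, of how a caller might construct such a mesh; the single 'devices' axis name is an assumption based on the PartitionSpec used elsewhere in this change.

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Build a 1-D mesh over all available devices; 'devices' mirrors the axis name
# used by the data-sharding code in this commit (an assumption, not repo code).
device_grid = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(device_grid, axis_names=('devices',))

# The mesh is then threaded through to each callback:
# callback = callback_builder(model, params, batch_stats, optimizer_state,
#                             dataset, hps, config, train_dir, rng, mesh)
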
99 changes: 27 additions & 72 deletions init2winit/checkpoint.py
@@ -18,14 +18,14 @@
This is useful for training neural networks with stax, where model parameters
are nested numpy arrays.
"""

import os
import sys
from typing import Sequence

from absl import flags
from absl import logging
from flax import jax_utils
from flax.training import checkpoints as flax_checkpoints
from init2winit.dataset_lib import data_utils as utils
import jax

FLAGS = flags.FLAGS
@@ -44,47 +44,12 @@ def load_pytree(pytree_file, orbax_checkpointer=None):
return None


def replicate_checkpoint(
latest,
pytree_keys: Sequence[str],
replicate=True):
"""Restores from the provided checkpoint.
Args:
latest: A dict representing the state of the
checkpoint we want to restore.
pytree_keys: A sequence of keys into `latest` that are pytrees, which will
be replicated if replicate=True.
replicate: If set, replicate the state across devices.
Returns:
Tuple of (pytree, extra_dict) where pytree is a JAX pytree holding the
arrays that need to be replicated/unreplicated and extra_dict holds any
additional python state. We expect extra_dict to have the keys of
'global_step', 'preemption_count', 'sum_train_cost', but old checkpoints
might be missing something.
"""
logging.info('Loaded model parameters from latest checkpoint.')
# Old checkpoints without 'sum_train_cost' can still be restored, but the
# train() function will break. Evals and curvature stuff should be fine,
# however.
expected = ['global_step', 'preemption_count', 'sum_train_cost']
if any(k not in latest for k in expected):
logging.warn('Checkpoint state missing keys, obtained %s expected %s',
list(latest.keys()), expected)

pytree = {k: latest[k] for k in pytree_keys}
if replicate:
pytree = jax_utils.replicate(pytree)
extra_dict = {k: latest[k] for k in latest.keys() if k not in pytree_keys}
return pytree, extra_dict


def replicate_and_maybe_restore_checkpoint(
unreplicated_optimizer_state,
unreplicated_params,
unreplicated_batch_stats,
unreplicated_training_metrics_state,
mesh,
train_dir,
external_checkpoint_path=None,
orbax_checkpointer=None):
@@ -104,6 +69,7 @@ def replicate_and_maybe_restore_checkpoint(
unreplicated_params: unreplicated params
unreplicated_batch_stats: unreplicated batch stats
unreplicated_training_metrics_state: unreplicated metrics state
mesh: Mesh specification to use for sharding.
train_dir: (str) The training directory where we will look for a checkpoint.
external_checkpoint_path: (str) If this argument is set, then we will load
the external checkpoint stored there.
@@ -165,43 +131,34 @@ def replicate_and_maybe_restore_checkpoint(
# Handle failure to load from external_checkpoint_path.
if ckpt_to_return['global_step'] == -1:
return (
jax_utils.replicate(unreplicated_optimizer_state),
jax_utils.replicate(unreplicated_params),
jax_utils.replicate(unreplicated_batch_stats),
jax_utils.replicate(unreplicated_training_metrics_state),
utils.shard_pytree(unreplicated_optimizer_state, mesh),
utils.shard_pytree(unreplicated_params, mesh),
utils.shard_pytree(unreplicated_batch_stats, mesh),
utils.shard_pytree(unreplicated_training_metrics_state, mesh),
0, # global_step
jax_utils.replicate(0), # sum_train_cost
0, # sum_train_cost
0, # preemption_count
False) # is_restored
else: # Else, don't restore from any checkpoint.
return (
jax_utils.replicate(unreplicated_optimizer_state),
jax_utils.replicate(unreplicated_params),
jax_utils.replicate(unreplicated_batch_stats),
jax_utils.replicate(unreplicated_training_metrics_state),
utils.shard_pytree(unreplicated_optimizer_state, mesh),
utils.shard_pytree(unreplicated_params, mesh),
utils.shard_pytree(unreplicated_batch_stats, mesh),
utils.shard_pytree(unreplicated_training_metrics_state, mesh),
0, # global_step
jax_utils.replicate(0), # sum_train_cost
0, # sum_train_cost
0, # preemption_count
False) # is_restored

pytree_dict, extra_state = replicate_checkpoint(
ckpt_to_return,
pytree_keys=[
'optimizer_state',
'params',
'batch_stats',
'training_metrics_grabber',
'sum_train_cost',
])
return (
pytree_dict['optimizer_state'],
pytree_dict['params'],
pytree_dict['batch_stats'],
pytree_dict['training_metrics_grabber'],
extra_state['global_step'],
pytree_dict['sum_train_cost'],
extra_state['preemption_count'],
is_restored)
utils.shard_pytree(ckpt_to_return['optimizer_state'], mesh),
utils.shard_pytree(ckpt_to_return['params'], mesh),
utils.shard_pytree(ckpt_to_return['batch_stats'], mesh),
utils.shard_pytree(ckpt_to_return['training_metrics_grabber'], mesh),
ckpt_to_return['global_step'], # global_step
ckpt_to_return['sum_train_cost'],
ckpt_to_return['preemption_count'], # preemption_count
is_restored) # is_restored


def save_unreplicated_checkpoint(
@@ -217,14 +174,12 @@ def save_unreplicated_checkpoint(
max_to_keep=1):
"""Saves pytree, step, preemption_count, and sum_train_cost to train_dir."""
logging.info('Saving checkpoint to ckpt_%d', global_step)
unreplicated_optimizer_state = jax.device_get(
jax_utils.unreplicate(optimizer_state))
unreplicated_params = jax.device_get(jax_utils.unreplicate(params))
unreplicated_batch_stats = jax.device_get(jax_utils.unreplicate(batch_stats))
unreplicated_optimizer_state = jax.device_get(optimizer_state)
unreplicated_params = jax.device_get(params)
unreplicated_batch_stats = jax.device_get(batch_stats)
unreplicated_training_metrics_state = jax.device_get(
jax_utils.unreplicate(training_metrics_state))
unreplicated_sum_train_cost = jax.device_get(
jax_utils.unreplicate(sum_train_cost))
training_metrics_state)
unreplicated_sum_train_cost = jax.device_get(sum_train_cost)
state = dict(global_step=global_step,
preemption_count=preemption_count,
sum_train_cost=unreplicated_sum_train_cost,
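
Note on the replicate-to-shard change above: state that used to carry a leading device axis from jax_utils.replicate is now placed on the mesh with jax.device_put, so saving no longer needs jax_utils.unreplicate. A minimal sketch of that difference, assuming a one-axis mesh and fully replicated parameters (illustrative, not repository code):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ('devices',))
params = {'w': jnp.ones((8, 4))}

# Under jit-style sharding, replication is a NamedSharding with an empty
# PartitionSpec; the array keeps its logical (8, 4) shape.
replicated = jax.device_put(params, NamedSharding(mesh, P()))

# device_get returns plain host arrays directly, which is why
# save_unreplicated_checkpoint above drops the jax_utils.unreplicate calls.
host_params = jax.device_get(replicated)
assert host_params['w'].shape == (8, 4)
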
68 changes: 33 additions & 35 deletions init2winit/dataset_lib/data_utils.py
@@ -16,10 +16,14 @@
"""Common code used by different models."""

import collections

import flax.linen as nn
import jax
from jax.nn import one_hot
from jax.sharding import PartitionSpec as P
import numpy as np


Dataset = collections.namedtuple('Dataset', [
'train_iterator_fn',
'eval_train_epoch',
@@ -143,40 +147,6 @@ def zero_pad(ar, pad_axis):
return padded_batch


def shard(batch, n_devices=None):
"""Prepares the batch for pmap by adding a leading n_devices dimension.
If all the entries are lists, assume they are already divided into n_devices
smaller arrays and stack them for pmapping. If all the entries are arrays,
assume they have leading dimension divisible by n_devices and reshape.
Args:
batch: A dict of arrays or lists of arrays
n_devices: If None, this will be set to jax.local_device_count().
Returns:
Sharded data.
"""
if n_devices is None:
n_devices = jax.local_device_count()

# TODO(mbadura): Specify a sharding function per dataset instead
# If entries in the batch dict are lists, then the data is already divided
# into n_devices chunks, so we need to stack them.
if all((isinstance(v, list) for v in batch.values())):
assert all(len(v) == n_devices for v in batch.values())
# transpose a dict of lists to a list of dicts
shards = [{k: v[i] for (k, v) in batch.items()} for i in range(n_devices)]
return jax.tree.map(lambda *vals: np.stack(vals, axis=0), shards[0],
*shards[1:])

# Otherwise, the entries are arrays, so just reshape them.
def _shard_array(array):
return array.reshape((n_devices, -1) + array.shape[1:])

return jax.tree.map(_shard_array, batch)


def tf_to_numpy(tfds_data):
# Safe because we won't mutate. Avoids an extra copy from tfds.
convert_data = lambda x: x._numpy() # pylint: disable=protected-access
@@ -187,4 +157,32 @@ def tf_to_numpy(tfds_data):
def convert_jax_to_tf_random_seed(jax_prng_key: jax.random.PRNGKey) -> int:
tf_seed = jax.random.bits(jax_prng_key)
return tf_seed



def make_global_array(local_data, mesh):
"""Util to combine per-host batches into a global batch array.
Args:
local_data: local data batch on host.
mesh: mesh specification to shard the data.
Returns:
global_array: global data batch.
"""
global_shape = (
local_data.shape[0] * jax.process_count(),
*local_data.shape[1:],
)
sharding = jax.NamedSharding(mesh, P('devices'))

global_array = jax.make_array_from_process_local_data(
sharding, local_data, global_shape
)
return global_array


def shard_pytree(pytree, mesh):
shardings = nn.get_sharding(pytree, mesh)
pytree = jax.device_put(pytree, shardings)

return shardings, pytree
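
A usage sketch for make_global_array defined above, assuming a host-local numpy batch with illustrative shapes (batch size divisible by the device count). shard_pytree, by contrast, returns both the inferred shardings and the device_put pytree, so callers unpack a (shardings, pytree) pair.

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ('devices',))

local_batch = {
    'inputs': np.zeros((32, 16), dtype=np.float32),
    'targets': np.zeros((32,), dtype=np.int32),
}

# Each leaf becomes a jax.Array whose leading dimension is
# 32 * jax.process_count(), sharded along the 'devices' mesh axis.
global_batch = jax.tree.map(lambda x: make_global_array(x, mesh), local_batch)
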
9 changes: 4 additions & 5 deletions init2winit/dataset_lib/ogbg_molpcba.py
@@ -208,8 +208,7 @@ def _get_batch_iterator(dataset_iter,
num_shards: How many devices we should be able to shard the batch into.
Yields:
Batch in the init2winit format. Each field is a list of num_shards separate
smaller batches.
Batch in the init2winit format.
"""
if not num_shards:
@@ -252,9 +251,9 @@

if count == num_shards:
yield {
'inputs': graphs_shards,
'targets': labels_shards,
'weights': weights_shards
'inputs': jraph.batch(graphs_shards),
'targets': np.vstack(labels_shards),
'weights': np.vstack(weights_shards)
}

count = 0
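
The iterator above now yields one global batch (a single batched GraphsTuple plus stacked label and weight arrays) instead of a list of num_shards sub-batches. A small illustration of the jraph.batch call with made-up shapes:

import jraph
import numpy as np

g = jraph.GraphsTuple(
    nodes=np.ones((3, 4)), edges=np.ones((2, 4)),
    senders=np.array([0, 1]), receivers=np.array([1, 2]),
    n_node=np.array([3]), n_edge=np.array([2]),
    globals=np.zeros((1, 4)))

# jraph.batch concatenates the per-shard graphs into one GraphsTuple:
# node features are stacked along axis 0 and n_node becomes [3, 3].
merged = jraph.batch([g, g])
assert merged.nodes.shape == (6, 4)
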
20 changes: 8 additions & 12 deletions init2winit/dataset_lib/test_ogbg_molpcba.py
@@ -117,10 +117,9 @@ def test_get_batch_pads_correctly(self):
dataset = _get_dataset(jax.random.PRNGKey(0))

batch = next(dataset.valid_epoch())
inputs = batch['inputs'][0]
inputs = batch['inputs']

# The first two graphs are in the first batch
self.assertLen(batch['inputs'], 1)
self.assertNDArrayNear(inputs.n_node[:2], np.array(NUMS_NODES[:2]), 1e-3)

# The graphs are padded to the right size
@@ -130,9 +129,9 @@
self.assertEqual(np.sum(inputs.n_edge), BATCH_SIZE * EDGES_SIZE_MULTIPLIER)

# Weights are zero at NaN labels and in padded examples
self.assertNDArrayNear(batch['weights'][0],
self.assertNDArrayNear(batch['weights'],
np.array([[1, 1], [0, 1], [0, 0]]), 1e-3)
self.assertFalse(np.any(np.isnan(batch['targets'][0])))
self.assertFalse(np.any(np.isnan(batch['targets'])))

def test_train_shuffle_is_deterministic(self):
"""Tests that shuffling of the train split is deterministic."""
@@ -144,19 +143,18 @@ def test_train_shuffle_is_deterministic(self):
batch_same = next(dataset_same.train_iterator_fn())
batch_different = next(dataset_different.train_iterator_fn())

self.assertAllClose(batch['inputs'][0], batch_same['inputs'][0])
self.assertNotAllClose(batch['inputs'][0], batch_different['inputs'][0])
self.assertAllClose(batch['inputs'], batch_same['inputs'])
self.assertNotAllClose(batch['inputs'], batch_different['inputs'])

def test_add_virtual_node(self):
"""Tests that adding a virtual node works correctly."""
dataset = _get_dataset(jax.random.PRNGKey(0), {'add_virtual_node': True})

batch = next(dataset.valid_epoch())
inputs = batch['inputs'][0]
inputs = batch['inputs']
num_nodes = np.array(NUMS_NODES[0])
num_edges = np.array(NUMS_EDGES[0])

self.assertLen(batch['inputs'], 1)
self.assertNDArrayNear(
inputs.n_node[0], np.array(num_nodes + 1), 1e-3)
self.assertNDArrayNear(
@@ -173,11 +171,10 @@ def test_add_bidirectional_edges(self):
jax.random.PRNGKey(0), {'add_bidirectional_edges': True})

batch = next(dataset.valid_epoch())
inputs = batch['inputs'][0]
inputs = batch['inputs']
num_nodes = np.array(NUMS_NODES[0])
num_edges = np.array(NUMS_EDGES[0])

self.assertLen(batch['inputs'], 1)
self.assertNDArrayNear(
inputs.n_node[0], np.array(num_nodes), 1e-3)
self.assertNDArrayNear(
@@ -188,11 +185,10 @@ def test_add_self_loops(self):
dataset = _get_dataset(jax.random.PRNGKey(0), {'add_self_loops': True})

batch = next(dataset.valid_epoch())
inputs = batch['inputs'][0]
inputs = batch['inputs']
num_nodes = np.array(NUMS_NODES[0])
num_edges = np.array(NUMS_EDGES[0])

self.assertLen(batch['inputs'], 1)
self.assertNDArrayNear(
inputs.n_node[0], np.array(num_nodes), 1e-3)
self.assertNDArrayNear(
8 changes: 2 additions & 6 deletions init2winit/model_lib/base_model.py
@@ -67,8 +67,8 @@ def _evaluate_batch(flax_module, params, batch_stats, batch, metrics_bundle,

# We don't use CLU's `mask` argument here, we handle it ourselves through
# `weights`.
return metrics_bundle.gather_from_model_output(
logits=logits, targets=targets, weights=weights, axis_name='batch')
return metrics_bundle.single_from_model_output(
logits=logits, targets=targets, weights=weights)


class BaseModel(object):
@@ -300,10 +300,6 @@ def training_objective_fn(self, params, logits, targets, weights):
logits, targets, weights
)

(objective_numerator, objective_denominator) = jax.lax.psum(
(objective_numerator, objective_denominator), axis_name='batch'
)

# epsilon added to handle empty batch case if we encounter one.
objective_value = objective_numerator / (objective_denominator + 1e-9)
if self.hps.get('l2_decay_factor'):
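
With a jit-compiled step over a globally sharded batch, the metrics bundle no longer gathers across devices and the objective no longer needs an explicit psum over 'batch'; the sums already cover the full batch. A minimal CLU sketch of single_from_model_output, using an assumed stand-in Collection rather than the actual init2winit metrics bundle:

import flax
import jax.numpy as jnp
from clu import metrics

@flax.struct.dataclass
class EvalMetrics(metrics.Collection):
  loss: metrics.Average.from_output('loss')

# One call over the whole (global) batch; no axis_name='batch' gather needed.
bundle = EvalMetrics.single_from_model_output(loss=jnp.array([0.3, 0.5, 0.2, 0.4]))
print(bundle.compute())  # -> {'loss': 0.35}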