From 7e255d2b731bc6dababf036338824d0d367e1d84 Mon Sep 17 00:00:00 2001
From: Sourabh Medapati
Date: Tue, 24 Oct 2023 14:56:34 -0700
Subject: [PATCH] Internal

PiperOrigin-RevId: 576293918
---
 init2winit/callbacks.py                      |   4 +
 init2winit/full_batch_statistics_callback.py | 139 +++++++++++++++++++
 2 files changed, 143 insertions(+)
 create mode 100644 init2winit/full_batch_statistics_callback.py

diff --git a/init2winit/callbacks.py b/init2winit/callbacks.py
index 85fc33b7..e78c801c 100644
--- a/init2winit/callbacks.py
+++ b/init2winit/callbacks.py
@@ -15,6 +15,7 @@
 """Registry for the available callbacks."""
 
+from init2winit import full_batch_statistics_callback
 from init2winit.hessian import hessian_callback
 from init2winit.hessian import model_debugger_callback
 from init2winit.mt_eval import mt_callback
 
@@ -24,6 +25,9 @@
     'hessian': hessian_callback.HessianCallback,
     'mt': mt_callback.MTEvaluationCallback,
     'model_debugger': model_debugger_callback.ModelDebugCallback,
+    'full_batch_statistics': (
+        full_batch_statistics_callback.FullBatchStatisticsCallback
+    ),
 }
 
 
diff --git a/init2winit/full_batch_statistics_callback.py b/init2winit/full_batch_statistics_callback.py
new file mode 100644
index 00000000..dc7406bc
--- /dev/null
+++ b/init2winit/full_batch_statistics_callback.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright 2023 The init2winit Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Callback for computing full batch statistics given a set of params.
+"""
+
+
+import itertools
+import os
+
+import flax
+from init2winit import base_callback
+from init2winit import checkpoint
+from init2winit.dataset_lib import data_utils
+import jax
+import jax.numpy as jnp
+
+
+class FullBatchStatisticsCallback(base_callback.BaseCallBack):
+  """Computes per-parameter gradient mean and std over the training set."""
+
+  def __init__(self,
+               model,
+               params,
+               batch_stats,
+               optimizer_state,
+               optimizer_update_fn,
+               dataset,
+               hps,
+               callback_config,
+               train_dir,
+               rng):
+    del optimizer_state
+    del optimizer_update_fn
+    del batch_stats
+
+    self.dataset = dataset
+    self.model = model
+    self.hps = hps
+    self.callback_config = callback_config
+    self.rng = rng
+    self.save_path = os.path.join(train_dir, 'gradient_statistics/')
+
+    def update(params, batch, batch_stats, dropout_rng):
+      def opt_cost(params):
+        return self.model.training_cost(
+            params,
+            batch=batch,
+            batch_stats=batch_stats,
+            dropout_rng=dropout_rng,
+        )
+
+      grad_fn = jax.value_and_grad(opt_cost, has_aux=True)
+      _, grad = grad_fn(params)
+
+      # Average the per-device gradients so every device holds the gradient
+      # of the whole (sharded) batch.
+      grad = jax.lax.pmean(grad, axis_name='batch')
+      return grad
+
+    # pmap compiles `update`, so no separate jax.jit is needed.
+    self.pmapped_update = jax.pmap(
+        update, axis_name='batch', in_axes=(0, 0, 0, None))
+
+  def run_eval(self, params, batch_stats, optimizer_state, global_step):
+    """Computes gradient statistics from mini batches over full training data.
+    """
+    del optimizer_state
+    unreplicated_params = flax.jax_utils.unreplicate(params)
+
+    # Rebuild the train iterator and reset the counter on every call so that
+    # repeated evals never reuse an exhausted iterator or stale counts.
+    num_batches_in_training_epoch = self.hps.train_size // self.hps.batch_size
+    self.train_iter = itertools.islice(
+        self.dataset.train_iterator_fn(), num_batches_in_training_epoch
+    )
+    self.num_updates = 0
+
+    self.grad_mean = jax.tree_util.tree_map(
+        jnp.zeros_like, unreplicated_params)
+    self.grad_std = jax.tree_util.tree_map(
+        jnp.zeros_like, unreplicated_params)
+
+    for batch in self.train_iter:
+      sharded_batch = data_utils.shard(batch)
+      grads = self.pmapped_update(params, sharded_batch, batch_stats, self.rng)
+      grads = flax.jax_utils.unreplicate(grads)
+
+      # Accumulate sum(g) and sum(g**2); both are converted into the mean
+      # and std after the loop.
+      self.grad_mean = jax.tree_util.tree_map(
+          lambda g_sum, g: g_sum + g, self.grad_mean, grads
+      )
+
+      self.grad_std = jax.tree_util.tree_map(
+          lambda g_squared, g: g_squared + g**2, self.grad_std, grads
+      )
+
+      self.num_updates += 1
+
+    self.grad_mean = jax.tree_util.tree_map(
+        lambda g_sum: g_sum / self.num_updates, self.grad_mean
+    )
+    # std = sqrt(E[g**2] - E[g]**2), clamped at zero to guard against small
+    # negative values caused by floating-point error.
+    self.grad_std = jax.tree_util.tree_map(
+        lambda g_squared, g_mean: jnp.sqrt(  # pylint: disable=g-long-lambda
+            jnp.maximum(g_squared / self.num_updates - g_mean**2, 0.0)
+        ),
+        self.grad_std,
+        self.grad_mean,
+    )
+
+    state = dict(
+        grad_std=self.grad_std,
+        grad_mean=self.grad_mean,
+        step=global_step
+    )
+
+    checkpoint.save_checkpoint(
+        self.save_path,
+        step=global_step,
+        state=state,
+        prefix='measurement_',
+        max_to_keep=None)
+
+    return {}
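
Note on the statistics computed above: run_eval uses the one-pass
sum / sum-of-squares identity, mean = sum(g)/n and
std = sqrt(sum(g**2)/n - mean**2), applied leaf-by-leaf to the gradient
pytree. Below is a minimal, self-contained sketch of that accumulation;
the `grad_stream` data and every name in it are invented for illustration
and are not part of init2winit.

    import jax
    import jax.numpy as jnp

    # Stand-ins for the per-batch gradient pytrees that run_eval collects
    # from pmapped_update.
    keys = jax.random.split(jax.random.PRNGKey(0), 8)
    grad_stream = [
        {'w': jax.random.normal(k, (3,)), 'b': jax.random.normal(k, ())}
        for k in keys
    ]

    # Accumulate sum(g) and sum(g**2) leaf-by-leaf, as the callback does.
    zeros = jax.tree_util.tree_map(jnp.zeros_like, grad_stream[0])
    g_sum, g_sq_sum, n = zeros, zeros, 0
    for g in grad_stream:
      g_sum = jax.tree_util.tree_map(lambda s, x: s + x, g_sum, g)
      g_sq_sum = jax.tree_util.tree_map(lambda s, x: s + x**2, g_sq_sum, g)
      n += 1

    mean = jax.tree_util.tree_map(lambda s: s / n, g_sum)
    # std = sqrt(E[g**2] - E[g]**2), clamped at zero as in the callback.
    std = jax.tree_util.tree_map(
        lambda s, m: jnp.sqrt(jnp.maximum(s / n - m**2, 0.0)), g_sq_sum, mean)

A single pass suffices because the variance identity only needs the two
running sums; a Welford-style update would be more robust to cancellation
over very long streams, at the cost of an extra tree_map per batch.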