Skip to content

Commit

Permalink
Port training to TF2
Browse files Browse the repository at this point in the history
This patch ports the training class to TF2. Most of the class needs
little porting as the TrainingEpochStats class is already completely
eager. The partially_restore_from_checkpoint function did need updating
to use new TF APIs. Future patches will fix everything up to use proper
TF2 APIs, but for now this works. Some future work needs to be done when
it comes up in tests to restore the step value as TF2 no longer contains
global state.

Pull Request: google#283
  • Loading branch information
boomanaiden154 committed Jan 6, 2025
1 parent 867a7fd commit 4c4930e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 84 deletions.
3 changes: 0 additions & 3 deletions gematria/model/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,6 @@ gematria_py_test(
name = "training_test",
size = "small",
srcs = ["training_test.py"],
tags = [
"manual",
],
deps = [
":training",
"//gematria/testing/python:basic_blocks_with_throughput",
Expand Down
25 changes: 10 additions & 15 deletions gematria/model/python/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def batches(


def partially_restore_from_checkpoint(
checkpoint_file: str, load_global_step_from_ckpt: bool, sess: tf.Session
checkpoint_file: str, load_step_from_ckpt: bool, model: tf.Module
) -> None:
"""Partially restores a checkpoint to the current graph.
Expand All @@ -311,29 +311,24 @@ def partially_restore_from_checkpoint(
Args:
checkpoint_file: A checkpoint to partially restore from.
load_global_step_from_ckpt: If True, load global step value from the given
checkpoint file.
sess: A TensorFlow session to restore into.
load_step_from_ckpt: If True, load the step value from the given checkpoint
file.
model: The tf.Module object representing the model that the weights should
be restored into.
"""
reader = tf.train.load_checkpoint(checkpoint_file)
shapes = reader.get_variable_to_shape_map()
dtypes = reader.get_variable_to_dtype_map()

if load_global_step_from_ckpt:
logging.info(
'Loading global step from checkpoint file: %s', checkpoint_file
)
global_step = tf.train.get_global_step()
global_step.load(reader.get_tensor('global_step'), sess)
if load_step_from_ckpt:
raise NotImplementedError()

for variable in tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope=None
):
for variable in model.trainable_variables:
# All variable names should end with ':0'; this ':0' is not used in the
# checkpoint.
if not variable.name.endswith(':0'):
continue
variable_name = variable.name[:-2]
variable_name = variable.name[:-2] + '/.ATTRIBUTES/VARIABLE_VALUE'
if variable_name not in shapes:
logging.info('%s not found in the checkpoint', variable_name)
continue
Expand All @@ -354,7 +349,7 @@ def partially_restore_from_checkpoint(
)
continue
logging.info('Restoring %s', variable_name)
variable.load(reader.get_tensor(variable_name), sess)
variable.load(reader.get_tensor(variable_name))


def _as_list(values: Sequence[float]) -> list[float]:
Expand Down
105 changes: 39 additions & 66 deletions gematria/model/python/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from gematria.model.python import training
from gematria.testing.python import basic_blocks_with_throughput
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow as tf


class TrainingEpochStatsTest(tf.test.TestCase):
Expand Down Expand Up @@ -442,85 +442,58 @@ def test_both_limits(self, with_throughput):
self.assertSequenceEqual(batches, expected_batches)


class DummyModel(tf.Module):
  """A test model that contains some trainable variables.

  The model owns four variables, exposed as the attributes `var1` to `var4`
  and created with the TensorFlow names 'var1' to 'var4'. Every element of
  every variable is set to the same initial value, so a test can tell which
  model a restored value came from.
  """

  def __init__(
      self, initial_value: int, var1_spec, var2_spec, var3_spec, var4_spec
  ):
    """Creates the four variables of the model.

    Args:
      initial_value: The value used to fill every element of every variable.
      var1_spec: An object with `shape` and `dtype` attributes (e.g. a
        tf.TensorSpec) describing `var1`.
      var2_spec: Same as `var1_spec`, for `var2`.
      var3_spec: Same as `var1_spec`, for `var3`.
      var4_spec: Same as `var1_spec`, for `var4`.
    """
    # tf.Module sets up its own state (e.g. the module name and name scope) in
    # its constructor; subclasses are expected to call it.
    super().__init__()
    self.var1 = self._make_variable('var1', var1_spec, initial_value)
    self.var2 = self._make_variable('var2', var2_spec, initial_value)
    self.var3 = self._make_variable('var3', var3_spec, initial_value)
    self.var4 = self._make_variable('var4', var4_spec, initial_value)

  @staticmethod
  def _make_variable(name: str, spec, initial_value: int) -> tf.Variable:
    """Returns a variable of shape/dtype from `spec` filled with a constant."""
    # tf.fill infers the dtype from `initial_value`; the cast converts the
    # tensor to the dtype requested by the spec.
    return tf.Variable(
        tf.cast(tf.fill(spec.shape, initial_value), dtype=spec.dtype),
        name=name,
    )


class PartiallyRestoreFromCheckpointTest(tf.test.TestCase):

def test_partially_restore(self):
checkpoint_file = os.path.join(tf.test.get_temp_dir(), 'checkpoint')
v1_name = 'var_1'
checkpoint_folder = os.path.join(self.get_temp_dir(), 'checkpoint')
v1_spec = tf.TensorSpec((3,), dtype=tf.dtypes.int32)

v2_name = 'var_2'
v2_spec_a = tf.TensorSpec((2, 2), dtype=tf.dtypes.float32)
v2_spec_b = tf.TensorSpec((2, 2), dtype=tf.dtypes.int32)

v3_name = 'var_3'
v3_spec_a = tf.TensorSpec((1, 3), dtype=tf.dtypes.float32)
v3_spec_b = tf.TensorSpec((2, 1), dtype=tf.dtypes.float32)

v4_name = 'var_4'
v4_spec = tf.TensorSpec((3,), dtype=tf.dtypes.int32)
with tf.Graph().as_default():
initializer = tf.initializers.constant(1)
v1 = tf.get_variable(
name=v1_name,
shape=v1_spec.shape,
dtype=v1_spec.dtype,
initializer=initializer,
)
v2 = tf.get_variable(
name=v2_name,
shape=v2_spec_a.shape,
dtype=v2_spec_a.dtype,
initializer=initializer,
)
v3 = tf.get_variable(
name=v3_name,
shape=v3_spec_a.shape,
dtype=v3_spec_a.dtype,
initializer=initializer,
)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saved_filename = saver.save(sess, checkpoint_file, global_step=0)

with tf.Graph().as_default():
initializer = tf.initializers.constant(2)
v1 = tf.get_variable(
name=v1_name,
shape=v1_spec.shape,
dtype=v1_spec.dtype,
initializer=initializer,
)
v2 = tf.get_variable(
name=v2_name,
shape=v2_spec_b.shape,
dtype=v2_spec_b.dtype,
initializer=initializer,
)
v3 = tf.get_variable(
name=v3_name,
shape=v3_spec_b.shape,
dtype=v3_spec_b.dtype,
initializer=initializer,
)
v4 = tf.get_variable(
name=v4_name,
shape=v4_spec.shape,
dtype=v4_spec.dtype,
initializer=initializer,
)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
training.partially_restore_from_checkpoint(saved_filename, False, sess)

v1d, v2d, v3d, v4d = sess.run((v1, v2, v3, v4))
self.assertAllEqual(v1d, [1, 1, 1])
self.assertAllEqual(v2d, [[2, 2], [2, 2]])
self.assertAllEqual(v3d, [[2], [2]])
self.assertAllEqual(v4d, [2, 2, 2])
model_a = DummyModel(1, v1_spec, v2_spec_a, v3_spec_a, v4_spec)
checkpoint = tf.train.Checkpoint(model_a)
model_a_save_path = checkpoint.save(checkpoint_folder)

model_b = DummyModel(2, v1_spec, v2_spec_b, v3_spec_b, v4_spec)
training.partially_restore_from_checkpoint(
model_a_save_path, False, model_b
)

self.assertAllEqual(model_b.var1, [1, 1, 1])
self.assertAllEqual(model_b.var2, [[2, 2], [2, 2]])
self.assertAllEqual(model_b.var3, [[2], [2]])
self.assertAllEqual(model_b.var4, [1, 1, 1])


if __name__ == '__main__':
tf.disable_v2_behavior()
tf.test.main()

0 comments on commit 4c4930e

Please sign in to comment.