From 09bf1e3035af233ab2e334c72b1a735664dc8469 Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Sat, 4 Jan 2025 02:41:41 +0000 Subject: [PATCH] Port token_model to TF2 This patch ports token_model to TF2. This patch is relatively simple as most of the functionality provided by the token_model class is already eager. Pull Request: https://github.com/google/gematria/pull/282 --- gematria/model/python/BUILD.bazel | 3 --- gematria/model/python/token_model.py | 23 +++++++++++------------ gematria/model/python/token_model_test.py | 7 ++----- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/gematria/model/python/BUILD.bazel b/gematria/model/python/BUILD.bazel index 61c7b4e9..08d61f87 100644 --- a/gematria/model/python/BUILD.bazel +++ b/gematria/model/python/BUILD.bazel @@ -185,9 +185,6 @@ gematria_py_test( size = "small", timeout = "moderate", srcs = ["token_model_test.py"], - tags = [ - "manual", - ], deps = [ ":oov_token_behavior", ":token_model", diff --git a/gematria/model/python/token_model.py b/gematria/model/python/token_model.py index d5338578..17313d76 100644 --- a/gematria/model/python/token_model.py +++ b/gematria/model/python/token_model.py @@ -136,6 +136,14 @@ def __init__( if self._oov_token is None: raise ValueError(f'Token {replacement_token} was not found in tokens.') + token_list_array = np.frombuffer( + b'\0'.join(token.encode('utf-8') for token in self._token_list), + dtype=np.uint8, + ) + self._token_list_tensor = tf.constant( + token_list_array, name=TokenModel.TOKENS_TENSOR_NAME + ) + super().__init__(**kwargs) @property @@ -195,16 +203,7 @@ def validate_basic_blockTokens(self, block: basic_block.BasicBlock) -> bool: return False return True - def _create_tf_graph(self): + def _forward(self, feed_dict: model_base.FeedDict): """See base class.""" - super()._create_tf_graph() - # Convert the token list into an array of bytes. We need to go through NumPy - # because tf.constant() always treats a bytes() object as a string and can't - # use it with any other dtype. - token_list_array = np.frombuffer( - b'\0'.join(token.encode('utf-8') for token in self._token_list), - dtype=np.uint8, - ) - self._token_list_tensor = tf.constant( - token_list_array, name=TokenModel.TOKENS_TENSOR_NAME - ) + del feed_dict # Unused. + raise NotImplementedError() diff --git a/gematria/model/python/token_model_test.py b/gematria/model/python/token_model_test.py index 5718bb3d..835ab458 100644 --- a/gematria/model/python/token_model_test.py +++ b/gematria/model/python/token_model_test.py @@ -18,7 +18,7 @@ from gematria.model.python import oov_token_behavior from gematria.model.python import token_model from gematria.testing.python import model_test -import tensorflow.compat.v1 as tf +import tensorflow as tf _OutOfVocabularyTokenBehavior = oov_token_behavior.OutOfVocabularyTokenBehavior @@ -66,9 +66,7 @@ def test_token_list_tensor(self): model.initialize() self.assertSequenceEqual(model._token_list, self.tokens) - with self.session() as sess: - raw_token_list = sess.run(model.token_list_tensor) - token_list = raw_token_list.tobytes().split(b'\0') + token_list = bytes(model.token_list_tensor).split(b'\0') self.assertLen(token_list, len(set(token_list))) for token in self.tokens: self.assertIn(token.encode(), token_list) @@ -167,5 +165,4 @@ def test_validate_basic_block_with_replacement_token(self): if __name__ == '__main__': - tf.disable_v2_behavior() tf.test.main()