Skip to content

Commit

Permalink
Port token_model to TF2
Browse files Browse the repository at this point in the history
This patch ports token_model to TF2. This patch is relatively simple as
most of the functionality provided by the token_model class is already
eager.

Pull Request: google#282
  • Loading branch information
boomanaiden154 committed Jan 4, 2025
1 parent 19639e9 commit 09bf1e3
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 20 deletions.
3 changes: 0 additions & 3 deletions gematria/model/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,6 @@ gematria_py_test(
size = "small",
timeout = "moderate",
srcs = ["token_model_test.py"],
tags = [
"manual",
],
deps = [
":oov_token_behavior",
":token_model",
Expand Down
23 changes: 11 additions & 12 deletions gematria/model/python/token_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ def __init__(
if self._oov_token is None:
raise ValueError(f'Token {replacement_token} was not found in tokens.')

token_list_array = np.frombuffer(
b'\0'.join(token.encode('utf-8') for token in self._token_list),
dtype=np.uint8,
)
self._token_list_tensor = tf.constant(
token_list_array, name=TokenModel.TOKENS_TENSOR_NAME
)

super().__init__(**kwargs)

@property
Expand Down Expand Up @@ -195,16 +203,7 @@ def validate_basic_blockTokens(self, block: basic_block.BasicBlock) -> bool:
return False
return True

def _create_tf_graph(self):
def _forward(self, feed_dict: model_base.FeedDict):
"""See base class."""
super()._create_tf_graph()
# Convert the token list into an array of bytes. We need to go through NumPy
# because tf.constant() always treats a bytes() object as a string and can't
# use it with any other dtype.
token_list_array = np.frombuffer(
b'\0'.join(token.encode('utf-8') for token in self._token_list),
dtype=np.uint8,
)
self._token_list_tensor = tf.constant(
token_list_array, name=TokenModel.TOKENS_TENSOR_NAME
)
del feed_dict # Unused.
raise NotImplementedError()
7 changes: 2 additions & 5 deletions gematria/model/python/token_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from gematria.model.python import oov_token_behavior
from gematria.model.python import token_model
from gematria.testing.python import model_test
import tensorflow.compat.v1 as tf
import tensorflow as tf

_OutOfVocabularyTokenBehavior = oov_token_behavior.OutOfVocabularyTokenBehavior

Expand Down Expand Up @@ -66,9 +66,7 @@ def test_token_list_tensor(self):
model.initialize()

self.assertSequenceEqual(model._token_list, self.tokens)
with self.session() as sess:
raw_token_list = sess.run(model.token_list_tensor)
token_list = raw_token_list.tobytes().split(b'\0')
token_list = bytes(model.token_list_tensor).split(b'\0')
self.assertLen(token_list, len(set(token_list)))
for token in self.tokens:
self.assertIn(token.encode(), token_list)
Expand Down Expand Up @@ -167,5 +165,4 @@ def test_validate_basic_block_with_replacement_token(self):


if __name__ == '__main__':
tf.disable_v2_behavior()
tf.test.main()

0 comments on commit 09bf1e3

Please sign in to comment.