Port sequence_model to TF2 #286

Open

wants to merge 1 commit into base: users/boomanaiden154/main.port-sequence_model-to-tf2
3 changes: 0 additions & 3 deletions gematria/sequence/python/BUILD.bazel
@@ -38,9 +38,6 @@ gematria_py_test(
timeout = "moderate",
srcs = ["sequence_model_test.py"],
shard_count = 10,
tags = [
"manual",
],
deps = [
":sequence_model",
"//gematria/basic_block/python:basic_block",
40 changes: 13 additions & 27 deletions gematria/sequence/python/sequence_model.py
@@ -86,6 +86,10 @@ class SequenceModelBase(token_model.TokenModel, model_base.ModelBase):
   _num_tokens_per_instruction_placeholder: tf.Tensor
   _num_instructions_per_block_placeholder: tf.Tensor
 
+  def __init__(self, **kwargs):
+    super().__init__(**kwargs)
+    self._model = self._create_model()
+
   @abc.abstractmethod
   def _create_model(self) -> tf_keras.Model:
     """Creates the Keras model for this class.
@@ -101,35 +105,17 @@ def _create_model(self) -> tf_keras.Model:
       The Keras model matching the input/output specification.
     """
 
-  def _create_tf_graph(self) -> None:
-    super()._create_tf_graph()
-    self._model = self._create_model()
-    self._token_sequence_placeholder = tf.placeholder(
-        dtype=tf.dtypes.int32,
-        shape=(None,),
-        name='SequenceModelBase.token_sequence',
-    )
-    self._num_tokens_per_instruction_placeholder = tf.placeholder(
-        dtype=tf.dtypes.int32,
-        shape=(None,),
-        name='SequenceModelBase.num_tokens_per_instruction',
-    )
-    self._num_instructions_per_block_placeholder = tf.placeholder(
-        dtype=tf.dtypes.int32,
-        shape=(None,),
-        name='SequenceModelBase.num_instructions_per_block',
-    )
-
   def _forward(self, feed_dict):
     model_output = self._model((
-        self._token_sequence_placeholder,
-        self._num_tokens_per_instruction_placeholder,
-        self._num_instructions_per_block_placeholder,
+        feed_dict['token_sequence'],
+        feed_dict['num_tokens_per_instruction'],
+        feed_dict['num_instructions_per_block'],
     ))
 
     if self._use_deltas:
-      self._output_tensor_deltas = model_output
+      return {'output_deltas': model_output}
     else:
-      self._output_tensor = model_output
+      return {'output': model_output}
 
   # @Override
   def _start_batch(self) -> None:
@@ -151,11 +137,11 @@ def _make_batch_feed_dict(self) -> model_base.FeedDict:
     batch_tokens[oov_injection_mask] = self._oov_token
 
     return {
-        self._token_sequence_placeholder: batch_tokens,
-        self._num_tokens_per_instruction_placeholder: np.array(
+        'token_sequence': batch_tokens,
+        'num_tokens_per_instruction': np.array(
             self._batch_num_tokens_per_instruction, dtype=np.int32
         ),
-        self._num_instructions_per_block_placeholder: np.array(
+        'num_instructions_per_block': np.array(
             self._batch_num_instructions_per_block, dtype=np.int32
         ),
     }
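The diff above replaces the TF1 `tf.placeholder` graph setup with a Keras model created in `__init__` and a `_forward` that reads a string-keyed feed dict. Below is a minimal, self-contained sketch of that calling convention; only the feed-dict keys and the shape of `_forward` mirror the diff, while `_ToySequenceKerasModel`, `SequenceModelSketch`, and the driver values are illustrative assumptions, not code from this PR.

```python
# Minimal sketch of the TF2 calling convention used in the diff above.
# Only the feed-dict keys and the structure of _forward mirror this PR;
# the toy Keras model and the driver code are illustrative assumptions.
import numpy as np
import tensorflow as tf


class _ToySequenceKerasModel(tf.keras.Model):
  """Stand-in for whatever a concrete _create_model() would return."""

  def __init__(self, vocab_size: int):
    super().__init__()
    self._embed = tf.keras.layers.Embedding(vocab_size, 8)

  def call(self, inputs):
    token_sequence, num_tokens_per_instruction, num_instructions_per_block = inputs
    # The toy model ignores the two size vectors and emits a single scalar.
    del num_tokens_per_instruction, num_instructions_per_block
    return tf.reduce_mean(self._embed(token_sequence))


class SequenceModelSketch:
  """Mirrors the TF2 interface from the diff: no placeholders, no TF1 graph."""

  def __init__(self, vocab_size: int):
    # The Keras model is created eagerly at construction time, as in the new
    # SequenceModelBase.__init__ above.
    self._model = _ToySequenceKerasModel(vocab_size)

  def _forward(self, feed_dict):
    # The feed dict is keyed by strings instead of placeholder tensors.
    model_output = self._model((
        feed_dict['token_sequence'],
        feed_dict['num_tokens_per_instruction'],
        feed_dict['num_instructions_per_block'],
    ))
    return {'output': model_output}


if __name__ == '__main__':
  sketch = SequenceModelSketch(vocab_size=16)
  print(sketch._forward({
      'token_sequence': np.array([1, 2, 3, 4], dtype=np.int32),
      'num_tokens_per_instruction': np.array([2, 2], dtype=np.int32),
      'num_instructions_per_block': np.array([2], dtype=np.int32),
  }))
```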
21 changes: 6 additions & 15 deletions gematria/sequence/python/sequence_model_test.py
@@ -96,13 +96,9 @@ def test_schedule_batch(self, use_deltas):
     model = TestSequenceModel(tokens=self.tokens, use_deltas=use_deltas)
     model.initialize()
     schedule = model.schedule_batch(self.blocks_with_throughput)
-    self.assertEqual(schedule[model._token_sequence_placeholder].shape, (58,))
-    self.assertEqual(
-        schedule[model._num_tokens_per_instruction_placeholder].shape, (7,)
-    )
-    self.assertEqual(
-        schedule[model._num_instructions_per_block_placeholder].shape, (3,)
-    )
+    self.assertEqual(schedule['token_sequence'].shape, (58,))
+    self.assertEqual(schedule['num_tokens_per_instruction'].shape, (7,))
+    self.assertEqual(schedule['num_instructions_per_block'].shape, (3,))
 
   def test_schedule_batch_with_invalid_block(self):
     model = TestSequenceModel(tokens=self.tokens)
@@ -181,7 +177,7 @@ def test_inject_out_of_vocabulary_tokens(self):
     self.assertGreaterEqual(model._oov_token, 0)
 
     schedule = model.schedule_batch(self.blocks_with_throughput)
-    token_sequence = schedule[model._token_sequence_placeholder]
+    token_sequence = schedule['token_sequence']
     expected_token_sequence = np.full_like(token_sequence, model._oov_token)
     self.assertAllEqual(token_sequence, expected_token_sequence)
 
@@ -206,9 +202,7 @@ def test_inject_out_of_vocabulary_estimate(self):
     num_all_elements = 0
     for _ in range(num_trials):
       schedule = model.schedule_batch(self.blocks_with_throughput)
-      oov_token_mask = (
-          schedule[model._token_sequence_placeholder] == model._oov_token
-      )
+      oov_token_mask = schedule['token_sequence'] == model._oov_token
       num_ones += sum(oov_token_mask)
       num_all_elements += oov_token_mask.size
 
@@ -233,9 +227,7 @@ def test_inject_out_of_vocabulary_tokens_zero_probability(self):
     # not contain unknown tokens, we expect that the out-of-vocabulary
     # replacement token is never used.
     schedule = model.schedule_batch(self.blocks_with_throughput)
-    oov_token_mask = (
-        schedule[model._token_sequence_placeholder] == model._oov_token
-    )
+    oov_token_mask = schedule['token_sequence'] == model._oov_token
     expected_oov_token_mask = np.zeros_like(oov_token_mask)
     self.assertAllEqual(oov_token_mask, expected_oov_token_mask)
 
@@ -259,5 +251,4 @@ def test_validate_basic_block(self):


 if __name__ == '__main__':
-  tf.disable_v2_behavior()
   tf.test.main()
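The OOV-injection tests above now compute their mask directly on the `'token_sequence'` array from the schedule. A plain-NumPy sketch of that check follows; the token values and `oov_token` below are made up for illustration and are not from this PR.

```python
# Plain-NumPy sketch of the OOV-mask check used in the tests above; the
# sample token values and oov_token are illustrative assumptions.
import numpy as np

oov_token = 0
token_sequence = np.array([0, 5, 0, 7, 3], dtype=np.int32)  # stand-in for schedule['token_sequence']

oov_token_mask = token_sequence == oov_token
num_ones = int(np.sum(oov_token_mask))
print(oov_token_mask)                  # [ True False  True False False]
print(num_ones / oov_token_mask.size)  # empirical injection rate: 0.4
```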