Skip to content

Commit

Permalink
Revert "Randomly initialize new embedding vectors when updating vocab…
Browse files Browse the repository at this point in the history
…ulary (#499)"

This reverts commit 1b3cf14.
  • Loading branch information
guillaumekln committed Oct 22, 2019
1 parent 9beb15f commit 1b73cce
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
6 changes: 2 additions & 4 deletions opennmt/tests/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,10 @@ def testVocabMappingReplace(self):

def testVocabVariableUpdate(self):
mapping = [0, -1, -1, 2, -1, 4]
old = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
old = np.array([1, 2, 3, 4, 5, 6, 7])
vocab_size = 7
new = checkpoint._update_vocabulary_variable(old, vocab_size, mapping)
for index, value in zip(mapping, new):
if index >= 0:
self.assertEqual(value, old[index])
self.assertAllEqual([1, 0, 0, 3, 0, 5], new)

def _generateCheckpoint(self,
model_dir,
Expand Down
2 changes: 1 addition & 1 deletion opennmt/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _update_vocabulary_variable(variable, vocab_size, mapping):
variable_t = np.transpose(variable, axes=perm)
new_shape = list(variable_t.shape)
new_shape[0] = len(mapping)
new_variable_t = np.random.uniform(-1.0, 1.0, size=new_shape).astype(variable.dtype)
new_variable_t = np.zeros(new_shape, dtype=variable.dtype)
for i, j in enumerate(mapping):
if j >= 0:
new_variable_t[i] = variable_t[j]
Expand Down

0 comments on commit 1b73cce

Please sign in to comment.