Fix T5 prompts with spaces at eol
helena-intel committed Mar 25, 2024
Commit: d6b7a21 · Parent: f2ba861
Showing 3 changed files with 10 additions and 2 deletions.
setup.py (1 addition, 1 deletion)
@@ -6,7 +6,7 @@
 version = "0.1"
 try:
     repo_root = os.path.dirname(os.path.realpath(__file__))
-    commit_id = subprocess.run(["git", "rev-parse", "--short", "HEAD"], cwd=repo_root, text=True).stdout
+    commit_id = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], text=True).strip()
     version += f"+{commit_id}"
 except subprocess.CalledProcessError:
     pass
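For context, a minimal runnable sketch of the fixed version-detection logic (a hypothetical standalone script, not part of the commit). The old line never captured output, since `subprocess.run(...).stdout` is `None` unless `capture_output=True` is passed; `check_output` captures stdout and raises `CalledProcessError` on a non-zero exit, and `.strip()` removes the trailing newline git prints:

```python
# Hypothetical standalone sketch of the fixed setup.py logic.
import subprocess

version = "0.1"
try:
    # check_output captures stdout and raises CalledProcessError on failure;
    # run(...).stdout would be None here without capture_output=True.
    commit_id = subprocess.check_output(
        ["git", "rev-parse", "--short", "HEAD"], text=True
    ).strip()  # drop the trailing newline that git prints
    version += f"+{commit_id}"  # e.g. "0.1+d6b7a21"
except subprocess.CalledProcessError:
    pass  # not a git checkout: keep the plain version

print(version)
```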
test_prompt_generator/test_prompt_generator.py (8 additions, 0 deletions)
@@ -101,6 +101,14 @@ def generate_prompt(
     )

     prompt = tokenizer.decode(tokens[:num_tokens_from_source], skip_special_tokens=True)
+
+    # If a prompt ends with a space, t5 will drop that from tokenization and the prompt will
+    # not have enough tokens. Just increasing num_tokens_from_source doesn't help because
+    # then a space and a new word will be added, making a prompt with too many tokens.
+    # This solution is not great, but it's simple and works.
+    if prompt.endswith(" ") and "t5" in tokenizer_id:
+        prompt = prompt[:-1] + "."
+
     if "chatglm" in tokenizer_id:
         # chatglm adds these tokens even when skip_special_tokens=True
         prompt = prompt.replace("[gMASK] sop ", "")
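The behavior the new branch works around can be sanity-checked with a short illustrative snippet (not part of the commit; it assumes the transformers package and the public "t5-small" checkpoint). Per the comment in the diff, T5's tokenizer drops a trailing space, so a decoded prompt ending in " " re-encodes to fewer tokens than requested, which is why the fix swaps the space for a ".":

```python
# Illustrative check, not part of the commit: encode the same text with and
# without a trailing space and compare the token ids.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
with_space = tokenizer("hello world ")["input_ids"]
without_space = tokenizer("hello world")["input_ids"]
# Per the commit's comment, T5 ignores the trailing space, so the two
# encodings are expected to match.
print(with_space == without_space)
```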
tests/test_prompts.py (1 addition, 1 deletion)
@@ -47,7 +47,7 @@ class PromptGeneratorTest(unittest.TestCase):
     @parameterized.expand(_preset_tokenizers.items())
     def test_prompt_generator_from_preset(self, preset_name, tokenizer_id):
         tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
-        for num_tokens in 32, 55, 1029:
+        for num_tokens in 16, 32, 64, 128, 256, 512, 768, 1024, 2048:
             prompt = generate_prompt(preset_name, num_tokens)
             tokens = tokenizer(prompt)["input_ids"]
             self.assertEqual(len(tokens), num_tokens)
