Skip to content

Commit

Permalink
Refactor tiktoken import bare except (#1024)
Browse files Browse the repository at this point in the history
* Refactor tiktoken import bare except

We shouldn't suppress the import error if tiktoken is not installed

Instead of importing it at the top of the file, we can import it inside the only functions that use these imports; we can consider a different import structure if we end up needing to access these modules in more places

This lazy importing should also decrease the load time of these modules

Also do the same thing for sentencepiece

* Remove __future__.annotations import
  • Loading branch information
ringohoffman authored Oct 7, 2024
1 parent e2301e9 commit e7331ab
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions torchao/_models/llama/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
# copied from https://github.com/pytorch-labs/gpt-fast/blob/main/tokenizer.py

import os
import sentencepiece as spm
try:
import tiktoken
from tiktoken.load import load_tiktoken_bpe
except:
pass
from pathlib import Path
from typing import Dict

class TokenizerInterface:
def __init__(self, model_path):
Expand All @@ -28,6 +21,8 @@ def eos_id(self):

class SentencePieceWrapper(TokenizerInterface):
def __init__(self, model_path):
import sentencepiece as spm

super().__init__(model_path)
self.processor = spm.SentencePieceProcessor(str(model_path))
self.bos_token_id = self.bos_id()
Expand All @@ -50,16 +45,19 @@ class TiktokenWrapper(TokenizerInterface):
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""

special_tokens: Dict[str, int]
special_tokens: dict[str, int]

num_reserved_special_tokens = 256

pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501

def __init__(self, model_path):
import tiktoken
import tiktoken.load

super().__init__(model_path)
assert os.path.isfile(model_path), str(model_path)
mergeable_ranks = load_tiktoken_bpe(str(model_path))
mergeable_ranks = tiktoken.load.load_tiktoken_bpe(str(model_path))
num_base_tokens = len(mergeable_ranks)
special_tokens = [
"<|begin_of_text|>",
Expand Down

0 comments on commit e7331ab

Please sign in to comment.