-
Notifications
You must be signed in to change notification settings - Fork 231
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into angelayi/aoti_metadata
- Loading branch information
Showing
5 changed files
with
169 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
""" | ||
Abstract base class for all tokenizer classes in python matching c++ interface. | ||
""" | ||
|
||
# Standard | ||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
|
||
class TokenizerBase(ABC): | ||
__doc__ = __doc__ | ||
|
||
@abstractmethod | ||
def encode(self, s: str, *, bos: bool = False, eos: bool = False) -> List[int]: | ||
"""Encode the given string and optionally include bos/eos tokens""" | ||
|
||
@abstractmethod | ||
def decode(self, ids: List[int]) -> str: | ||
"""Decode the given token ids into a string""" | ||
|
||
@abstractmethod | ||
def bos_id(self) -> int: | ||
"""The id of the begin-of-string token""" | ||
|
||
@abstractmethod | ||
def eos_id(self) -> int: | ||
"""The id of the end-of-string token""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# Standard | ||
from typing import List, Optional | ||
import json | ||
import os | ||
|
||
# Third Party | ||
from tokenizers import Tokenizer | ||
|
||
# Local | ||
from .base import TokenizerBase | ||
|
||
|
||
class HFTokenizer(TokenizerBase): | ||
""" | ||
Wrapper around the Huggingface `tokenizers` library for API compatibility | ||
""" | ||
|
||
def __init__(self, file_path: str): | ||
# If the path is a directory, look for "tokenizer.json" which is | ||
# standard for transformers checkpoints and also look for the | ||
# "tokenizer_config.json" file to parse eos/bos tokens | ||
if os.path.isdir(file_path): | ||
tokenizer_path = os.path.join(file_path, "tokenizer.json") | ||
tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json") | ||
else: | ||
tokenizer_path = file_path | ||
tokenizer_config_path = os.path.join(os.path.dirname(file_path), "tokenizer_config.json") | ||
if not os.path.isfile(tokenizer_path): | ||
tokenizer_config_path = None | ||
|
||
# Load the tokenizer itself | ||
self._tokenizer = Tokenizer.from_file(tokenizer_path) | ||
|
||
# If available, parse bos/eos tokens from the tokenizer config | ||
self._bos_id, self._eos_id = None, None | ||
if tokenizer_config_path is not None: | ||
with open(tokenizer_config_path, "r") as handle: | ||
tok_config = json.load(handle) | ||
bos_token = tok_config.get("bos_token") | ||
eos_token = tok_config.get("eos_token") | ||
if bos_token is not None: | ||
self._bos_id = self._tokenizer.token_to_id(bos_token) | ||
if eos_token is not None: | ||
self._eos_id = self._tokenizer.token_to_id(eos_token) | ||
|
||
# If no eos/bos tokens found, go looking for them! | ||
if None in [self._bos_id, self._eos_id]: | ||
tok_content = json.loads(self._tokenizer.to_str()) | ||
if self._bos_id is None: | ||
self._bos_id = self._look_for_special_token(tok_content, ["begin", "text"]) | ||
if self._eos_id is None: | ||
self._eos_id = self._look_for_special_token(tok_content, ["end", "text"]) | ||
|
||
assert None not in [self._bos_id, self._eos_id], "Unable to find an BOS/EOS tokens" | ||
|
||
@staticmethod | ||
def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optional[int]: | ||
candidate_toks = added_tokens | ||
for search_str in search_strs: | ||
candidate_toks = [ | ||
tok for tok in candidate_toks | ||
if tok["special"] and search_str in tok["content"] | ||
] | ||
if len(candidate_toks) == 1: | ||
return candidate_toks[0]["id"] | ||
|
||
def encode( | ||
self, | ||
s: str, | ||
*, | ||
bos: bool = False, | ||
eos: bool = False, | ||
) -> List[int]: | ||
res = self._tokenizer.encode(s, add_special_tokens=bos).ids | ||
if eos and (not res or res[-1] != self._eos_token): | ||
res.append(self._eos_token) | ||
return res | ||
|
||
def decode(self, ids: List[int]) -> str: | ||
return self._tokenizer.decode(ids) | ||
|
||
def bos_id(self) -> int: | ||
return self._bos_id | ||
|
||
def eos_id(self) -> int: | ||
return self._eos_id |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters