Upgrade to transformers 4.34.1 (#2635)
* bump transformers version

* add new special casing to tokenizer equivalence check

* try/except for flash v1 issue
dakinggg authored Oct 23, 2023
1 parent 5430213 commit 40cf910
Showing 3 changed files with 60 additions and 3 deletions.
13 changes: 12 additions & 1 deletion composer/models/huggingface.py
@@ -582,7 +582,18 @@ def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool:
         raise MissingConditionalImportError(extra_deps_group='nlp',
                                             conda_package='transformers',
                                             conda_channel='conda-forge') from e
-    causal_lm_classes = list(MODEL_FOR_CAUSAL_LM_MAPPING.values())
+
+    # This try/except is needed until https://github.com/huggingface/transformers/issues/26778
+    # is resolved in a release. This means that this attempt to automatically detect causal LMs
+    # does not currently work in an environment with flash attention <2 installed.
+    try:
+        causal_lm_classes = list(MODEL_FOR_CAUSAL_LM_MAPPING.values())
+    except RuntimeError as e:
+        if 'Failed to import transformers.models' in str(e):
+            MODEL_FOR_CAUSAL_LM_MAPPING = {}
+            return False
+        else:
+            raise e
     return any(isinstance(model, causal_lm_class) for causal_lm_class in causal_lm_classes)
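A hedged usage sketch of the guarded path (illustrative only; the 'gpt2' checkpoint and the import of this private helper are assumptions for demonstration, not part of the diff):

import transformers

from composer.models.huggingface import _is_registered_causal_lm

model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
# In a healthy environment the causal-LM mapping imports cleanly and this prints True.
# With flash attention <2 installed, the RuntimeError branch above now makes the check
# return False instead of crashing.
print(_is_registered_causal_lm(model))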


2 changes: 1 addition & 1 deletion setup.py
@@ -184,7 +184,7 @@ def package_files(prefix: str, directory: str, extension: str):
 ]
 
 extra_deps['nlp'] = [
-    'transformers>=4.11,<4.34',
+    'transformers>=4.11,<4.35,!=4.34.0',
     'datasets>=2.4,<3',
 ]
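For reference, a small sketch of what the new pin permits, using the third-party packaging library (an assumption for illustration; setup.py itself does not run this check):

from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet('>=4.11,<4.35,!=4.34.0')
print(Version('4.34.1') in spec)   # True: the release this commit targets
print(Version('4.34.0') in spec)   # False: explicitly excluded
print(Version('4.35.0') in spec)   # False: above the upper bound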

48 changes: 47 additions & 1 deletion tests/models/test_hf_model.py
@@ -261,6 +261,46 @@ def check_hf_tokenizer_equivalence(tokenizer1, tokenizer2):
     tokenizer1.__dict__['init_kwargs'].pop('auto_map', None)
     tokenizer2.__dict__['init_kwargs'].pop('auto_map', None)
 
+    # Additional special tokens do not match between original tokenizer and loaded tokenizer due to transformers
+    # constructor differences
+    additional_special_tokens_1 = {
+        t if isinstance(t, str) else t.content for t in tokenizer1.__dict__.pop('_additional_special_tokens', [])
+    }
+    additional_special_tokens_2 = {
+        t if isinstance(t, str) else t.content for t in tokenizer2.__dict__.pop('_additional_special_tokens', [])
+    }
+    # Also pop it out of init_kwargs
+    tokenizer1.__dict__['init_kwargs'].pop('additional_special_tokens', None)
+    tokenizer2.__dict__['init_kwargs'].pop('additional_special_tokens', None)
+    tokenizer1.__dict__['init_kwargs'].pop('added_tokens_decoder', None)
+    tokenizer2.__dict__['init_kwargs'].pop('added_tokens_decoder', None)
+    # If the additional special tokens are the same (or a subset of each other), or if one of them is empty, then we are good
+    assert additional_special_tokens_1.issubset(additional_special_tokens_2) or additional_special_tokens_2.issubset(
+        additional_special_tokens_1)
+
+    # The special token attributes may be strings or they may be AddedToken objects, so we just check string values
+    # First check that they have the same attrs
+    assert tokenizer1.SPECIAL_TOKENS_ATTRIBUTES == tokenizer2.SPECIAL_TOKENS_ATTRIBUTES
+    # Then check that the values are the same
+    for special_token_attr in tokenizer1.SPECIAL_TOKENS_ATTRIBUTES:
+        # Skip additional_special_tokens because we already checked it above
+        if special_token_attr == 'additional_special_tokens':
+            continue
+
+        # The init_kwargs can change between the original tokenizer and the loaded tokenizer,
+        # so we just pop them
+        tokenizer1.__dict__['init_kwargs'].pop(special_token_attr, None)
+        tokenizer2.__dict__['init_kwargs'].pop(special_token_attr, None)
+
+        attr1 = tokenizer1.__dict__.pop('_' + special_token_attr, None)
+        attr2 = tokenizer2.__dict__.pop('_' + special_token_attr, None)
+        if attr1 is None and attr2 is None:
+            continue
+
+        attr_value1 = attr1 if isinstance(attr1, str) else attr1.content
+        attr_value2 = attr2 if isinstance(attr2, str) else attr2.content
+        assert attr_value1 == attr_value2
+
     assert tokenizer1.__dict__ == tokenizer2.__dict__
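A minimal sketch of how this helper is typically exercised (the 'gpt2' checkpoint and the temporary directory are illustrative): save a tokenizer, reload it, and check that the round trip is equivalent.

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
tokenizer.save_pretrained('/tmp/roundtrip-tokenizer')
reloaded = transformers.AutoTokenizer.from_pretrained('/tmp/roundtrip-tokenizer')
check_hf_tokenizer_equivalence(tokenizer, reloaded)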


@@ -559,7 +599,10 @@ def test_hf_loading_sentencepiece_tokenizer(modify_tokenizer: bool, tmp_path: Pa
     if modify_tokenizer:
         assert t0_pp_tokenizer is not None  # pyright
         t0_pp_tokenizer.add_special_tokens({'bos_token': '[NEWSPECIAL]'})
-        t0_pp_tokenizer.add_special_tokens({'additional_special_tokens': ['[MOSAICML']})
+        # This is apparently not allowed anymore
+        # It results in ValueError: Both extra_ids (100) and additional_special_tokens (['[MOSAICML'])
+        # are provided to T5Tokenizer. In this case the additional_special_tokens must include the extra_ids tokens
+        # t0_pp_tokenizer.add_special_tokens({'additional_special_tokens': ['[MOSAICML']})
         t0_pp_tokenizer.add_tokens(['totallyarealtoken', 'mosaicml'])
         tiny_t5_model.resize_token_embeddings(len(t0_pp_tokenizer))

@@ -585,6 +628,8 @@ def test_hf_loading_sentencepiece_tokenizer(modify_tokenizer: bool, tmp_path: Pa
 
 
 @pytest.mark.parametrize('modify_tokenizer', [False, True])
+# https://github.com/huggingface/transformers/issues/26777
+@pytest.mark.skip('This tokenizer no longer loads at all as of transformers 4.34')
 def test_hf_loading_tokenizer_with_python_file(modify_tokenizer: bool, tmp_path: Path, tiny_gpt2_model):
     transformers = pytest.importorskip('transformers')
     replit_tokenizer = transformers.AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
@@ -618,6 +663,7 @@ def test_hf_loading_llama_tokenizer(modify_tokenizer: bool, tmp_path: Path, tiny
         llama_tokenizer.add_special_tokens({'bos_token': '[NEWSPECIAL]'})
         llama_tokenizer.add_special_tokens({'additional_special_tokens': ['[MOSAICML']})
         llama_tokenizer.add_tokens(['totallyarealtoken', 'mosaicml'])
+        llama_tokenizer.update_post_processor()
 
     # we don't actually need the right model here, so avoiding adding llama
     tiny_gpt2_model.resize_token_embeddings(len(llama_tokenizer))
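The added update_post_processor() call matters because LlamaTokenizerFast builds its post-processing template from the BOS/EOS tokens at construction time, so after swapping bos_token the template must be rebuilt or encodings may keep the old BOS id. A hedged sketch of the effect (the checkpoint name is an illustrative assumption):

import transformers

llama_tokenizer = transformers.AutoTokenizer.from_pretrained('huggyllama/llama-7b')
llama_tokenizer.add_special_tokens({'bos_token': '[NEWSPECIAL]'})
llama_tokenizer.update_post_processor()
ids = llama_tokenizer('hello')['input_ids']
print(ids[0] == llama_tokenizer.bos_token_id)  # expected True once the template uses the new BOS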
