Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix transformers class #79

Merged
merged 3 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions sdk/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,11 @@ def set_transformers_class_names(model: Model) -> None:
# Get the configuration
config = transformers.AutoConfig.from_pretrained(model.name)

# Map model class from model type
model_mapping = transformers.AutoModel._model_mapping._model_mapping

# Set model class name if not already set
model.class_name = model.class_name or model_mapping.get(
config.model_type)
model.class_name = model.class_name or config.architectures[0] \
if config.architectures and config.architectures[0] else (
model_config_default_class_for_module[TRANSFORMERS]
)

# Set tokenizer class name if not already set
# and config.tokenizer_class exists
Expand Down
6 changes: 3 additions & 3 deletions sdk/tests/test_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ def test_set_diffusers_class_names_with_configured_model(self,
self.assertEqual(model.class_name, 'TestModel')

@patch('transformers.AutoConfig.from_pretrained',
return_value=MagicMock(model_type='t5',
return_value=MagicMock(architectures=['T5Model'],
tokenizer_class='TokenizerClass'))
def test_set_transformers_class_names(self, mock_load_config):
# Init
Expand All @@ -909,7 +909,7 @@ def test_set_transformers_class_names(self, mock_load_config):
self.assertEqual(model.tokenizer.class_name, 'TokenizerClass')

@patch('transformers.AutoConfig.from_pretrained',
return_value=MagicMock(model_type='t5',
return_value=MagicMock(architectures=['T5Model'],
tokenizer_class=None))
def test_set_transformers_class_names_with_default_tokenizer(
self, mock_load_config
Expand All @@ -927,7 +927,7 @@ def test_set_transformers_class_names_with_default_tokenizer(
self.assertEqual(model.tokenizer.class_name, 'AutoTokenizer')

@patch('transformers.AutoConfig.from_pretrained',
return_value=MagicMock(model_type='t5',
return_value=MagicMock(architectures=['T5Model'],
tokenizer_class='TokenizerClass'))
def test_set_transformers_class_names_with_configured_model(
self, mock_load_config
Expand Down
Loading