diff --git a/sdk/downloader.py b/sdk/downloader.py index 675c915..f924292 100644 --- a/sdk/downloader.py +++ b/sdk/downloader.py @@ -4,7 +4,6 @@ import sys import importlib -import diffusers import transformers # Authorized module names for download @@ -396,7 +395,7 @@ def parse_arguments(): help="Path to the downloads directory") parser.add_argument("model_name", type=str, help="Model name") - parser.add_argument("model_module", type=str, help=f"Module name", + parser.add_argument("model_module", type=str, help="Module name", choices=AUTHORIZED_MODULE_NAMES) # Optional arguments regarding the model diff --git a/sdk/tests/test_downloader.py b/sdk/tests/test_downloader.py index b95d4c3..61e0a75 100644 --- a/sdk/tests/test_downloader.py +++ b/sdk/tests/test_downloader.py @@ -4,12 +4,13 @@ import argparse import json from unittest.mock import patch, MagicMock + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from downloader import (Model, Tokenizer, download_model, - download_transformers_tokenizer, - is_path_valid_for_download, process_options, - map_args_to_model, parse_arguments, main, exit_error) +from downloader import (Model, Tokenizer, download_model, # noqa: E402 + download_transformers_tokenizer, + is_path_valid_for_download, process_options, + map_args_to_model, main, exit_error) class TestDownloader(unittest.TestCase): @@ -46,13 +47,13 @@ def test_build_paths(self): def test_build_paths_transformers(self): model = Model(name="TestModel", module="transformers") model.build_paths("/models") - expected_download_path = f"/models\\TestModel\\model" + expected_download_path = "/models\\TestModel\\model" self.assertEqual(model.download_path, expected_download_path) def test_build_paths_non_transformers(self): model = Model(name="TestModel", module="non_transformers") model.build_paths("/models") - expected_download_path = f"/models\\TestModel" + expected_download_path = "/models\\TestModel" self.assertEqual(model.download_path, expected_download_path) @patch('os.path.exists') @@ -269,11 +270,13 @@ def test_download(self, mock_download_model): @patch('os.path.exists') @patch('os.listdir') - def test_download_transformers_tokenizer_exception(self, mock_listdir, mock_exists): + def test_download_transformers_tokenizer_exception(self, mock_listdir, + mock_exists): mock_exists.return_value = False mock_listdir.return_value = [] model = Model(name="TestModel", module="transformers", - tokenizer=Tokenizer(class_name="PreTrainedTokenizerFast")) + tokenizer=Tokenizer( + class_name="PreTrainedTokenizerFast")) model.tokenizer.download_path = "/tokenizer_path" model.base_path = "/model_path" @@ -284,4 +287,4 @@ def test_download_transformers_tokenizer_exception(self, mock_listdir, mock_exis if __name__ == '__main__': - unittest.main() # pragma: no cover + unittest.main() # pragma: no cover