diff --git a/sdk/downloader.py b/sdk/downloader.py index cbd1cf4..1904c7d 100644 --- a/sdk/downloader.py +++ b/sdk/downloader.py @@ -212,18 +212,20 @@ def download(self, skip: str, overwrite: bool, result_dict: dict) -> None: # Checking for model download if skip != DOWNLOAD_MODEL: # Downloading the model - download_model(self, overwrite) + options = download_model(self, overwrite) - # Adding downloaded model path to result + # Adding downloaded model properties to result result_dict["path"] = self.download_path + result_dict["options"] = get_options_for_json(options) # Checking for tokenizer download if self.belongs_to_module(TRANSFORMERS) and skip != DOWNLOAD_TOKENIZER: # Download a tokenizer for the model - download_transformers_tokenizer(self, overwrite) + options = download_transformers_tokenizer(self, overwrite) - # Adding downloaded tokenizer path to result + # Adding downloaded tokenizer properties to result result_dict["tokenizer"]["path"] = self.tokenizer.download_path + result_dict["tokenizer"]["options"] = get_options_for_json(options) def set_class_names(model: Model) -> None: @@ -279,7 +281,7 @@ def set_diffusers_class_names(model: Model) -> None: model.class_name = config['_class_name'] -def download_model(model: Model, overwrite: bool) -> None: +def download_model(model: Model, overwrite: bool) -> dict: """ Download the model. @@ -289,7 +291,7 @@ def download_model(model: Model, overwrite: bool) -> None: it exists. Returns: - None. Exit with error if anything goes wrong. + dict: A dictionary containing options used for model downloading. """ # Check if the model already exists at path @@ -302,7 +304,7 @@ def download_model(model: Model, overwrite: bool) -> None: model.module) # Processing options - options = process_options(model.options or []) + options = process_options(model.options) # Processing access token access_token = process_access_token(options, model) @@ -322,14 +324,16 @@ def download_model(model: Model, overwrite: bool) -> None: # Downloading the model try: model_downloaded = model_class_obj.from_pretrained( - model.name, **options, token=access_token) + model.name, token=access_token, **options) model_downloaded.save_pretrained(model.download_path) except Exception as e: err = f"Error downloading model {model.name}: {e}" exit_error(err, ERROR_EXIT_MODEL) + return options -def download_transformers_tokenizer(model: Model, overwrite: bool) -> None: + +def download_transformers_tokenizer(model: Model, overwrite: bool) -> dict: """ Download a transformers tokenizer for the model. @@ -339,7 +343,7 @@ def download_transformers_tokenizer(model: Model, overwrite: bool) -> None: it exists. Returns: - None. Exit with error if anything goes wrong. + dict: A dictionary containing options used for model downloading. """ # Retrieving tokenizer class from module @@ -367,7 +371,7 @@ def download_transformers_tokenizer(model: Model, overwrite: bool) -> None: exit_error(err) # Processing options - options = process_options(model.tokenizer.options or []) + options = process_options(model.tokenizer.options) # Downloading the tokenizer try: @@ -378,6 +382,8 @@ def download_transformers_tokenizer(model: Model, overwrite: bool) -> None: err = f"Error downloading tokenizer {model.tokenizer.class_name}: {e}" exit_error(err, ERROR_EXIT_TOKENIZER) + return options + def is_path_valid_for_download(path: str, overwrite: bool) -> bool: """ @@ -502,6 +508,23 @@ def process_access_token(options: dict, model: Model) -> str | None: return access_token +def get_options_for_json(options_dict: dict) -> dict: + """ + Prepares a dictionary containing options for conversion to JSON. + Args: + options_dict (dict): A dictionary containing options as key-value pairs. + Returns: + dict: A new dictionary with the same keys but with values prepared for + JSON serialization (strings with quotes for string values). + """ + for key, value in options_dict.items(): + if isinstance(value, str): + options_dict[key] = "\"{}\"".format(value) + else: + options_dict[key] = str(value) + return options_dict + + def map_args_to_model(args) -> Model: """ Maps command-line arguments to a Model object. diff --git a/sdk/tests/test_downloader.py b/sdk/tests/test_downloader.py index 0d32a96..a532163 100644 --- a/sdk/tests/test_downloader.py +++ b/sdk/tests/test_downloader.py @@ -19,6 +19,7 @@ is_path_valid_for_download, process_options, process_access_token, + get_options_for_json, map_args_to_model, main, exit_error, @@ -316,9 +317,54 @@ def test_process_access_token_missing(self): # Assert self.assertEqual("", result) + def test_get_options_for_json_input_empty(self): + # Init + input_options = {} + expected_options = {} + + # Execute + output_options = get_options_for_json(input_options) + + # Assert + self.assertEqual(expected_options, output_options) + + def test_get_options_for_json_value_string(self): + # Init + input_options = { + "key": "value1" + } + expected_options = { + "key": "\"value1\"" + } + + # Execute + output_options = get_options_for_json(input_options) + + # Assert + self.assertEqual(expected_options, output_options) + + def test_get_options_for_json_value_not_string(self): + # Init + input_options = { + "key": True + } + expected_options = { + "key": "True" + } + + # Execute + output_options = get_options_for_json(input_options) + + # Assert + self.assertEqual(expected_options, output_options) + @patch('downloader.is_path_valid_for_download', return_value=False) - def test_download_model_path_invalid( - self, mock_is_path_valid_for_download): + @patch('downloader.process_options') + @patch('downloader.process_access_token') + @patch('transformers.models.auto.modeling_auto.AutoModel.from_pretrained') + def test_download_model_with_path_invalid( + self, mock_from_pretrained, mock_process_access_token, + mock_process_options, mock_is_path_valid_for_download): # Init model = Model(name="TestModel", module="") @@ -329,13 +375,17 @@ def test_download_model_path_invalid( # Assert mock_is_path_valid_for_download.assert_called_once() + mock_process_options.assert_not_called() + mock_process_access_token.assert_not_called() + mock_from_pretrained.assert_not_called() @patch('downloader.is_path_valid_for_download', return_value=True) - @patch('downloader.process_options', return_value=[]) + @patch('downloader.process_options', return_value={}) @patch('downloader.process_access_token', return_value="") + @patch('transformers.models.auto.modeling_auto.AutoModel.from_pretrained') def test_download_model_with_objects_error( - self, mock_is_path_valid_for_download, mock_process_options, - mock_process_access_token): + self, mock_from_pretrained, mock_process_access_token, + mock_process_options, mock_is_path_valid_for_download): # Init model = Model(name="TestModel", module="") @@ -348,13 +398,16 @@ def test_download_model_with_objects_error( mock_is_path_valid_for_download.assert_called_once() mock_process_options.assert_called_once() mock_process_access_token.assert_called_once() + mock_from_pretrained.assert_not_called() @patch('downloader.is_path_valid_for_download', return_value=True) - @patch('downloader.process_options', return_value=[]) + @patch('downloader.process_options', return_value={}) @patch('downloader.process_access_token', return_value="") - def test_download_model_with_download_error( - self, mock_is_path_valid_for_download, mock_process_options, - mock_process_access_token): + @patch('transformers.models.auto.modeling_auto.AutoModel.from_pretrained' + '', side_effect=Exception("Download failed")) + def test_download_model_with_from_pretrained_error( + self, mock_from_pretrained, mock_process_access_token, + mock_process_options, mock_is_path_valid_for_download): # Init model = Model(name="TestModel", module=TRANSFORMERS) @@ -367,10 +420,102 @@ def test_download_model_with_download_error( mock_is_path_valid_for_download.assert_called_once() mock_process_options.assert_called_once() mock_process_access_token.assert_called_once() + mock_from_pretrained.assert_called_once() @patch('downloader.is_path_valid_for_download', return_value=True) + @patch('downloader.process_options', return_value={}) + @patch('downloader.process_access_token', return_value="") + @patch('transformers.models.auto.modeling_auto.AutoModel.from_pretrained') + def test_download_model_with_save_pretrained_error( + self, mock_from_pretrained, mock_process_access_token, + mock_process_options, mock_is_path_valid_for_download): + # Mockers : save_pretrained + mock_save_pretrained = MagicMock(side_effect=Exception("Save failed")) + + # Adding save_pretrained to from_pretrained returned value + data_from_pretrained = MagicMock() + data_from_pretrained.save_pretrained = mock_save_pretrained + mock_from_pretrained.return_value = data_from_pretrained + + # Init + model = Model(name="TestModel", module=TRANSFORMERS) + + # Execute : error = success + with self.assertRaises(SystemExit) as context: + download_model(model, overwrite=False) + self.assertEqual(context.exception.code, ERROR_EXIT_MODEL) + + # Assert + mock_is_path_valid_for_download.assert_called_once() + mock_process_options.assert_called_once() + mock_process_access_token.assert_called_once() + mock_from_pretrained.assert_called_once() + mock_save_pretrained.assert_called_once() + + @patch('downloader.is_path_valid_for_download', return_value=True) + @patch('downloader.process_options') + @patch('downloader.process_access_token', return_value="") + @patch('transformers.models.auto.modeling_auto.AutoModel.from_pretrained') + def test_download_model_success( + self, mock_from_pretrained, mock_process_access_token, + mock_process_options, mock_is_path_valid_for_download): + # Mockers : save_pretrained + mock_save_pretrained = MagicMock(return_value=None) + + # Adding save_pretrained to from_pretrained returned value + data_from_pretrained = MagicMock() + data_from_pretrained.save_pretrained = mock_save_pretrained + mock_from_pretrained.return_value = data_from_pretrained + + # Options + expected_options = {"key1": "value1"} + mock_process_options.return_value = expected_options + + # Init + model = Model(name="TestModel", module=TRANSFORMERS) + model.options = ["key1='value1'"] + + # Execute : + result_options = download_model(model, overwrite=False) + + # Assert + self.assertEqual(expected_options, result_options) + mock_is_path_valid_for_download.assert_called_once() + mock_process_options.assert_called_once() + mock_process_access_token.assert_called_once() + mock_from_pretrained.assert_called_once() + mock_save_pretrained.assert_called_once() + + @patch('downloader.is_path_valid_for_download', return_value=False) + @patch('downloader.process_options') + @patch('transformers.models.auto.tokenization_auto.AutoTokenizer' + '.from_pretrained') + def test_download_transformers_tokenizer_with_path_invalid( + self, mock_from_pretrained, mock_process_options, + mock_is_path_valid_for_download): + # Init + model = Model(name="TestModel", module=TRANSFORMERS) + model.base_path = "path/to/model" + tokenizer = Tokenizer(class_name="AutoTokenizer") + model.tokenizer = tokenizer + + # Execute : error = success + with self.assertRaises(SystemExit) as context: + download_transformers_tokenizer(model, overwrite=False) + self.assertEqual(context.exception.code, ERROR_EXIT_DEFAULT) + + # Assert + mock_is_path_valid_for_download.assert_called_once() + mock_process_options.assert_not_called() + mock_from_pretrained.assert_not_called() + + @patch('downloader.is_path_valid_for_download', return_value=True) + @patch('downloader.process_options') + @patch('transformers.models.auto.tokenization_auto.AutoTokenizer' + '.from_pretrained') def test_download_transformers_tokenizer_with_objects_error( - self, mock_is_path_valid_for_download): + self, mock_from_pretrained, mock_process_options, + mock_is_path_valid_for_download): # Init model = Model(name="TestModel", module="") model.tokenizer = Tokenizer(class_name="error") @@ -382,10 +527,16 @@ def test_download_transformers_tokenizer_with_objects_error( # Assert mock_is_path_valid_for_download.assert_not_called() + mock_process_options.assert_not_called() + mock_from_pretrained.assert_not_called() - @patch('downloader.is_path_valid_for_download', return_value=False) - def test_download_transformers_path_invalid( - self, mock_is_path_valid_for_download): + @patch('downloader.is_path_valid_for_download', return_value=True) + @patch('downloader.process_options', return_value={}) + @patch('transformers.models.auto.tokenization_auto.AutoTokenizer' + '.from_pretrained', side_effect=Exception("Download failed")) + def test_download_transformers_with_from_pretrained_error( + self, mock_from_pretrained, mock_process_options, + mock_is_path_valid_for_download): # Init model = Model(name="TestModel", module=TRANSFORMERS) model.base_path = "path/to/model" @@ -395,15 +546,28 @@ def test_download_transformers_path_invalid( # Execute : error = success with self.assertRaises(SystemExit) as context: download_transformers_tokenizer(model, overwrite=False) - self.assertEqual(context.exception.code, ERROR_EXIT_DEFAULT) + self.assertEqual(context.exception.code, ERROR_EXIT_TOKENIZER) # Assert mock_is_path_valid_for_download.assert_called_once() + mock_process_options.assert_called_once() + mock_from_pretrained.assert_called_once() @patch('downloader.is_path_valid_for_download', return_value=True) - @patch('downloader.process_options', return_value=[]) - def test_download_transformers_download_error( - self, mock_is_path_valid_for_download, mock_process_options): + @patch('downloader.process_options', return_value={}) + @patch('transformers.models.auto.tokenization_auto.AutoTokenizer' + '.from_pretrained') + def test_download_transformers_with_save_pretrained_error( + self, mock_from_pretrained, mock_process_options, + mock_is_path_valid_for_download): + # Mockers : save_pretrained + mock_save_pretrained = MagicMock(side_effect=Exception("Save failed")) + + # Adding save_pretrained to from_pretrained returned value + data_from_pretrained = MagicMock() + data_from_pretrained.save_pretrained = mock_save_pretrained + mock_from_pretrained.return_value = data_from_pretrained + # Init model = Model(name="TestModel", module=TRANSFORMERS) model.base_path = "path/to/model" @@ -418,21 +582,64 @@ def test_download_transformers_download_error( # Assert mock_is_path_valid_for_download.assert_called_once() mock_process_options.assert_called_once() + mock_from_pretrained.assert_called_once() + mock_save_pretrained.assert_called_once() + + @patch('downloader.is_path_valid_for_download', return_value=True) + @patch('downloader.process_options') + @patch('transformers.models.auto.tokenization_auto.AutoTokenizer' + '.from_pretrained') + def test_download_transformers_success( + self, mock_from_pretrained, mock_process_options, + mock_is_path_valid_for_download): + # Mockers : save_pretrained + mock_save_pretrained = MagicMock(return_value=None) + + # Adding save_pretrained to from_pretrained returned value + data_from_pretrained = MagicMock() + data_from_pretrained.save_pretrained = mock_save_pretrained + mock_from_pretrained.return_value = data_from_pretrained + + # Options + expected_options = {"key1": "value1"} + mock_process_options.return_value = expected_options - @patch('downloader.download_model', return_value=None) + # Init + model = Model(name="TestModel", module=TRANSFORMERS) + model.base_path = "path/to/model" + tokenizer = Tokenizer(class_name="AutoTokenizer") + tokenizer.options = expected_options + model.tokenizer = tokenizer + + # Execute : + result = download_transformers_tokenizer(model, overwrite=False) + + # Assert + self.assertEqual(expected_options, result) + mock_is_path_valid_for_download.assert_called_once() + mock_process_options.assert_called_once() + mock_from_pretrained.assert_called_once() + mock_save_pretrained.assert_called_once() + + @patch('downloader.download_model') def test_download_model_skip_success(self, mock_download_model): # Init model = Model(name="TestModel", module="module") model.download_path = "path/to/model" model.class_name = "class_name" + model.options = ["test='test'"] model.validate = MagicMock() model.build_paths = MagicMock() # Prepare + mock_download_model.return_value = {"test": "test"} expected_result = { "path": model.download_path, "module": model.module, - "class": model.class_name + "class": model.class_name, + "options": { + "test": "\"test\"" + } } result_dict = { @@ -448,22 +655,27 @@ def test_download_model_skip_success(self, mock_download_model): self.assertEqual(result_dict, expected_result) mock_download_model.assert_called_once() - @patch('downloader.download_transformers_tokenizer', return_value=None) + @patch('downloader.download_transformers_tokenizer') def test_download_tokenizer_skip_success( self, mock_download_transformers_tokenizer): # Init model = Model(name="TestModel", module=TRANSFORMERS) tokenizer = Tokenizer(class_name="AutoTokenizer") tokenizer.download_path = "path/to/tokenizer" + tokenizer.options = ["test='test'"] model.tokenizer = tokenizer model.validate = MagicMock() model.build_paths = MagicMock() # Prepare + mock_download_transformers_tokenizer.return_value = {"test": "test"} expected_result = { "tokenizer": { "path": model.tokenizer.download_path, "class": model.tokenizer.class_name, + "options": { + "test": "\"test\"" + } } } result_dict = {"tokenizer": {