Skip to content

Commit

Permalink
modif for linter
Browse files Browse the repository at this point in the history
  • Loading branch information
vhahnschutz committed Mar 6, 2024
1 parent 21a21d5 commit b8ac91d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
3 changes: 1 addition & 2 deletions sdk/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import sys
import importlib

import diffusers
import transformers

# Authorized module names for download
Expand Down Expand Up @@ -396,7 +395,7 @@ def parse_arguments():
help="Path to the downloads directory")
parser.add_argument("model_name", type=str,
help="Model name")
parser.add_argument("model_module", type=str, help=f"Module name",
parser.add_argument("model_module", type=str, help="Module name",
choices=AUTHORIZED_MODULE_NAMES)

# Optional arguments regarding the model
Expand Down
21 changes: 12 additions & 9 deletions sdk/tests/test_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import argparse
import json
from unittest.mock import patch, MagicMock

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__),
"..")))
from downloader import (Model, Tokenizer, download_model,
download_transformers_tokenizer,
is_path_valid_for_download, process_options,
map_args_to_model, parse_arguments, main, exit_error)
from downloader import (Model, Tokenizer, download_model, # noqa: E402
download_transformers_tokenizer,
is_path_valid_for_download, process_options,
map_args_to_model, main, exit_error)


class TestDownloader(unittest.TestCase):
Expand Down Expand Up @@ -46,13 +47,13 @@ def test_build_paths(self):
def test_build_paths_transformers(self):
model = Model(name="TestModel", module="transformers")
model.build_paths("/models")
expected_download_path = f"/models\\TestModel\\model"
expected_download_path = "/models\\TestModel\\model"
self.assertEqual(model.download_path, expected_download_path)

def test_build_paths_non_transformers(self):
model = Model(name="TestModel", module="non_transformers")
model.build_paths("/models")
expected_download_path = f"/models\\TestModel"
expected_download_path = "/models\\TestModel"
self.assertEqual(model.download_path, expected_download_path)

@patch('os.path.exists')
Expand Down Expand Up @@ -269,11 +270,13 @@ def test_download(self, mock_download_model):

@patch('os.path.exists')
@patch('os.listdir')
def test_download_transformers_tokenizer_exception(self, mock_listdir, mock_exists):
def test_download_transformers_tokenizer_exception(self, mock_listdir,
mock_exists):
mock_exists.return_value = False
mock_listdir.return_value = []
model = Model(name="TestModel", module="transformers",
tokenizer=Tokenizer(class_name="PreTrainedTokenizerFast"))
tokenizer=Tokenizer(
class_name="PreTrainedTokenizerFast"))
model.tokenizer.download_path = "/tokenizer_path"
model.base_path = "/model_path"

Expand All @@ -284,4 +287,4 @@ def test_download_transformers_tokenizer_exception(self, mock_listdir, mock_exis


if __name__ == '__main__':
unittest.main() # pragma: no cover
unittest.main() # pragma: no cover

0 comments on commit b8ac91d

Please sign in to comment.