Skip to content

Commit

Permalink
downloader script to return options (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zapharaos authored Mar 12, 2024
1 parent 6e7f90d commit 72823fd
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 31 deletions.
45 changes: 34 additions & 11 deletions sdk/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,20 @@ def download(self, skip: str, overwrite: bool, result_dict: dict) -> None:
# Checking for model download
if skip != DOWNLOAD_MODEL:
# Downloading the model
download_model(self, overwrite)
options = download_model(self, overwrite)

# Adding downloaded model path to result
# Adding downloaded model properties to result
result_dict["path"] = self.download_path
result_dict["options"] = get_options_for_json(options)

# Checking for tokenizer download
if self.belongs_to_module(TRANSFORMERS) and skip != DOWNLOAD_TOKENIZER:
# Download a tokenizer for the model
download_transformers_tokenizer(self, overwrite)
options = download_transformers_tokenizer(self, overwrite)

# Adding downloaded tokenizer path to result
# Adding downloaded tokenizer properties to result
result_dict["tokenizer"]["path"] = self.tokenizer.download_path
result_dict["tokenizer"]["options"] = get_options_for_json(options)


def set_class_names(model: Model) -> None:
Expand Down Expand Up @@ -279,7 +281,7 @@ def set_diffusers_class_names(model: Model) -> None:
model.class_name = config['_class_name']


def download_model(model: Model, overwrite: bool) -> None:
def download_model(model: Model, overwrite: bool) -> dict:
"""
Download the model.
Expand All @@ -289,7 +291,7 @@ def download_model(model: Model, overwrite: bool) -> None:
it exists.
Returns:
None. Exit with error if anything goes wrong.
dict: A dictionary containing options used for model downloading.
"""

# Check if the model already exists at path
Expand All @@ -302,7 +304,7 @@ def download_model(model: Model, overwrite: bool) -> None:
model.module)

# Processing options
options = process_options(model.options or [])
options = process_options(model.options)

# Processing access token
access_token = process_access_token(options, model)
Expand All @@ -322,14 +324,16 @@ def download_model(model: Model, overwrite: bool) -> None:
# Downloading the model
try:
model_downloaded = model_class_obj.from_pretrained(
model.name, **options, token=access_token)
model.name, token=access_token, **options)
model_downloaded.save_pretrained(model.download_path)
except Exception as e:
err = f"Error downloading model {model.name}: {e}"
exit_error(err, ERROR_EXIT_MODEL)

return options

def download_transformers_tokenizer(model: Model, overwrite: bool) -> None:

def download_transformers_tokenizer(model: Model, overwrite: bool) -> dict:
"""
Download a transformers tokenizer for the model.
Expand All @@ -339,7 +343,7 @@ def download_transformers_tokenizer(model: Model, overwrite: bool) -> None:
it exists.
Returns:
None. Exit with error if anything goes wrong.
dict: A dictionary containing options used for model downloading.
"""

# Retrieving tokenizer class from module
Expand Down Expand Up @@ -367,7 +371,7 @@ def download_transformers_tokenizer(model: Model, overwrite: bool) -> None:
exit_error(err)

# Processing options
options = process_options(model.tokenizer.options or [])
options = process_options(model.tokenizer.options)

# Downloading the tokenizer
try:
Expand All @@ -378,6 +382,8 @@ def download_transformers_tokenizer(model: Model, overwrite: bool) -> None:
err = f"Error downloading tokenizer {model.tokenizer.class_name}: {e}"
exit_error(err, ERROR_EXIT_TOKENIZER)

return options


def is_path_valid_for_download(path: str, overwrite: bool) -> bool:
"""
Expand Down Expand Up @@ -502,6 +508,23 @@ def process_access_token(options: dict, model: Model) -> str | None:
return access_token


def get_options_for_json(options_dict: dict) -> dict:
"""
Prepares a dictionary containing options for conversion to JSON.
Args:
options_dict (dict): A dictionary containing options as key-value pairs.
Returns:
dict: A new dictionary with the same keys but with values prepared for
JSON serialization (strings with quotes for string values).
"""
for key, value in options_dict.items():
if isinstance(value, str):
options_dict[key] = "\"{}\"".format(value)
else:
options_dict[key] = str(value)
return options_dict


def map_args_to_model(args) -> Model:
"""
Maps command-line arguments to a Model object.
Expand Down
Loading

0 comments on commit 72823fd

Please sign in to comment.