Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making comments consistent #117

Merged
merged 13 commits into from
May 12, 2024
11 changes: 10 additions & 1 deletion sdk/demo/demo_text_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@
class DemoTextConv:
"""
This class demonstrates a text conversation using a chatbot model.

Attributes:
model_path (str): The path to the model.
tokenizer_path (str): The path to the tokenizer.
model_management (ModelsManagement): An instance of
the ModelsManagement class.
model (ModelsTextConversation): An instance of the
ModelsTextConversation class.
"""

def __init__(self):
"""
Initializes the DemoTextConv class with predefined options and models.
__init__ Initializes the DemoTextConv class with predefined
options and models and returns result of prompt .
"""
model_path = "microsoft/phi-2"
tokenizer_path = "microsoft/phi-2"
Expand Down
6 changes: 4 additions & 2 deletions sdk/demo/demo_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

class DemoTextGen:
"""
This class demonstrates a text conversation using a chatbot model.
This class demonstrates text generation using a chatbot model.

"""

def __init__(self):
"""
Initializes the DemoTextConv class with predefined options and models.
__init__ Initializes the DemoTextGen class with predefined
options and models and returns result of prompt .
"""
model_path = "microsoft/phi-2"
tokenizer_path = "microsoft/phi-2"
Expand Down
8 changes: 8 additions & 0 deletions sdk/demo/demo_text_to_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@


class DemoTextToImg:
"""
This class demonstrates text-to-image generation using a diffusion model.

"""

def __init__(self):
"""
__init__ Initializes the DemoTextToImg class with
predefined options and models and returns result of prompt .
"""
model_stabilityai_name = "stabilityai/sdxl-turbo"
model_stabilityai_path = "stabilityai/sdxl-turbo"

Expand Down
7 changes: 7 additions & 0 deletions sdk/demo/demo_text_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@


class DemoTextToVideo:
"""
This class demonstrates text generation using a TextToVideo model.

"""
def __init__(self):
"""
__init__ Initializes the DemoTextGen class with predefined
options and models and returns result of prompt .
"""
model_name = "damo-vilab/text-to-video-ms-1.7b"
model_path = "damo-vilab/text-to-video-ms-1.7b"

Expand Down
136 changes: 59 additions & 77 deletions sdk/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def validate(self):
"""
Validate the model.

Returns:
Program exits if invalid
:returns: Program exits if invalid
OmarElChamaa marked this conversation as resolved.
Show resolved Hide resolved
"""

# Check if the model name is not empty
Expand All @@ -117,20 +116,17 @@ def belongs_to_module(self, module: str) -> bool:
"""
Check if the model belongs to a given module.

Returns:
bool: True if the model belongs to the module, False otherwise.
:param module: module to check
:returns: True if the model belongs to the module, False otherwise.
OmarElChamaa marked this conversation as resolved.
Show resolved Hide resolved
"""
return self.module == module

def build_paths(self, models_path: str) -> None:
"""
Build paths for the model.

Args:
models_path (str): The base path where all the models are located.

Returns:
None
:param models_path: The base path where all the models are located.
:return: None
"""

# Local path to the model directory
Expand All @@ -154,19 +150,16 @@ def process(self, models_path: str, skip: str = "",
"""
Process the model.

Args:
models_path (str): The base path where all the models are located.
skip (str): Optional. Skips the download process of either the
model or the tokenizer.
only_configuration (bool): Optional. Whether to only get the
configuration properties without downloading anything or not.
overwrite (bool): Optional. Whether to overwrite the downloaded
model if it exists.

Returns:
Program exits with error if the process fails.
If it succeeds, it returns the JSON props used for downloading
the model.
:param models_path: The base path where all the models are located.
:param skip: Optional. Skips the download process of either the
model or the tokenizer.
:param only_configuration: Optional. Whether to only get the
configuration properties without downloading anything or not.
:param overwrite: Optional. Whether to overwrite the downloaded
model if it exists.
:return: If the process succeeds, it returns the JSON properties
used for downloading the model.
Otherwise, the program exits with an error.
"""

# Validate mandatory arguments
Expand Down Expand Up @@ -226,15 +219,15 @@ def download(self, skip: str, overwrite: bool,
"""
Download the model.

Args:
skip (str): Skips the download process of either the model
or the tokenizer.
overwrite (bool): Whether to overwrite the downloaded model
if it exists.
options (dict): The options dictionary for the model.
options_tokenizer (dict): The options dictionary for the tokenizer.
access_token (str): The access token for the model
:param skip: Skips the download process of either the model
or the tokenizer.
:param overwrite: Whether to overwrite the downloaded model
if it exists.
:param options: The options dictionary for the model.
:param options_tokenizer: The options dictionary for the tokenizer.
:param access_token: The access token for the model.
"""

# Checking for model download
if skip != DOWNLOAD_MODEL:
# Downloading the model
Expand All @@ -256,10 +249,10 @@ def set_class_names(model: Model, access_token: str | None) -> None:
Set the appropriate model class name based on the model's module.
And Set the appropriate tokenizer class name if needed.

Args:
model (Model): The model object.
access_token (str): The access token for the model
:param model: The model object.
:param access_token: The access token for the model.
"""

if model.belongs_to_module(TRANSFORMERS):
set_transformers_class_names(model, access_token)
elif model.belongs_to_module(DIFFUSERS):
Expand All @@ -272,10 +265,10 @@ def set_transformers_class_names(model: Model,
Set the appropriate model class for a Transformers module model
and tokenizer.

Args:
model (Model): The model object.
access_token (str): The access token for the model
:param model: The model object.
:param access_token: The access token for the model.
"""

try:
# Get the configuration
config = transformers.AutoConfig.from_pretrained(
Expand Down Expand Up @@ -307,10 +300,10 @@ def set_diffusers_class_names(model: Model, access_token: str | None) -> None:
"""
Set the appropriate model class for a Diffusers module model.

Args:
model (Model): The model object.
access_token (str): The access token for the model
:param model: The model object.
:param access_token: The access token for the model.
"""

if model.class_name is not None and model.class_name != "":
return

Expand All @@ -330,12 +323,10 @@ def download_model(model: Model, overwrite: bool, options: dict,
"""
Download the model.

Args:
model (Model): Model to be downloaded.
overwrite (bool): Whether to overwrite the downloaded model if
it exists.
options: A dictionary containing options used for model downloading.
access_token (str): The access token for the model
:param model: Model to be downloaded.
:param overwrite: Whether to overwrite the downloaded model if it exists.
:param options: A dictionary containing options used for model downloading.
:param access_token: The access token for the model.
"""

# Check if the model already exists at path
Expand Down Expand Up @@ -374,11 +365,10 @@ def download_transformers_tokenizer(model: Model, overwrite: bool,
"""
Download a transformers tokenizer for the model.

Args:
model (Model): Model to be downloaded.
overwrite (bool): Whether to overwrite the downloaded model if
it exists.
options: A dictionary containing options used for model downloading.
:param model: Model to be downloaded.
:param overwrite: Whether to overwrite the downloaded model if
it exists.
:param options: A dictionary containing options used for model downloading.
"""

# Retrieving tokenizer class from module
Expand Down Expand Up @@ -415,13 +405,12 @@ def is_path_valid_for_download(path: str, overwrite: bool) -> bool:
"""
Check if the path is valid for downloading.

Args:
path (str): The path to check.
overwrite (bool): Whether to overwrite existing files.
:param path: The path to check.
:param overwrite: Whether to overwrite existing files.

Returns:
bool: True if the path is valid for download, False otherwise.
:return: True if the path is valid for download, False otherwise.
"""

return overwrite or not os.path.exists(path) or not os.listdir(path)


Expand All @@ -430,13 +419,11 @@ def process_options(options_list: list) -> dict:
Process the options provided as a list of strings and convert them into a
dictionary.

Args:
options_list (list): A list of options in the form of strings, where
each string is in the format 'key=value'.
:param options_list: A list of options in the form of strings, where
each string is in the format 'key=value'.

Returns:
dict: A dictionary containing the processed options, where keys are
the option names and values are the corresponding values.
:return: A dictionary containing the processed options, where keys are
the option names and values are the corresponding values.
"""

# Processed options
Expand Down Expand Up @@ -505,12 +492,10 @@ def process_access_token(options: dict, model: Model) -> str | None:
"""
Process the access token since it can be provided through options and flags

Args:
options (dict): A dictionary containing the processed options.
model (Model): Model to be downloaded.
:param options: A dictionary containing the processed options.
:param model: Model to be downloaded.

Returns:
str: The value of the access token (if provided).
:return: The value of the access token (if provided).
"""

# If conflicting access tokens are provided, raise an error
Expand All @@ -537,11 +522,11 @@ def process_access_token(options: dict, model: Model) -> str | None:
def get_options_for_json(options_dict: dict) -> dict:
"""
Prepares a dictionary containing options for conversion to JSON.
Args:
options_dict (dict): A dictionary containing options as key-value pairs.
Returns:
dict: A new dictionary with the same keys but with values prepared for
JSON serialization (strings with quotes for string values).

:param options_dict: A dictionary containing options as key-value pairs.

:return: A new dictionary with the same keys but with values prepared for
JSON serialization (strings with quotes for string values).
"""

# Create a shallow copy of the input dictionary
Expand All @@ -559,11 +544,9 @@ def map_args_to_model(args) -> Model:
"""
Maps command-line arguments to a Model object.

Args:
args (argparse.Namespace): Parsed command-line arguments.
:param args: Parsed command-line arguments.

Returns:
Model: Model object representing the configuration.
:return: Model object representing the configuration.
"""

# Mapping to tokenizer
Expand All @@ -578,8 +561,7 @@ def parse_arguments():
"""
Parse command-line arguments.

Returns:
argparse.Namespace: Parsed command-line arguments.
:return: Parsed command-line arguments as argparse.Namespace.
"""

parser = argparse.ArgumentParser(
Expand Down
39 changes: 32 additions & 7 deletions sdk/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@

class Model:
"""
Abstract base class for all models
Model Abstract base class for all models.

model_name (str): The name of the model.
model_path (str): The path of the model.
device (Union[str, Devices]): Which device the model must be on.
loaded (bool): Indicates if the model is loaded.
single_file (bool): Indicates if the model is a single file.
"""

model_name: str
model_path: str
device: Union[str, Devices]
Expand All @@ -17,13 +24,15 @@ class Model:
def __init__(self, model_name, model_path: str,
device: Union[str, Devices], single_file: bool = False):
"""
Initializes the model with the given name
__init__ Initializes the model with the given name.


:param model_name: The name of the model.
:param model_path: The path of the model.
:param device: Defines which device the model must be on.
:param single_file: Defines whether model is single file or not.

Args:
model_name (str): The name of the model
model_path (str): The path of the model
device (Union[str, Devices]): Which device the model must be on
:param single_file: Whether model is single file or not
:return: Returns an instance of model
"""
self.model_name = model_name
self.model_path = model_path
Expand All @@ -33,12 +42,28 @@ def __init__(self, model_name, model_path: str,

@abstractmethod
def load_model(self) -> bool:
"""
load_model Loads the model on the given device.

:returns: True if the model is successfully loaded.
OmarElChamaa marked this conversation as resolved.
Show resolved Hide resolved
"""
raise NotImplementedError

@abstractmethod
def unload_model(self) -> bool:
"""
unload_model unloads the model.

:returns: True if the model is successfully unloaded.
OmarElChamaa marked this conversation as resolved.
Show resolved Hide resolved
"""
raise NotImplementedError

@abstractmethod
def generate_prompt(self, prompt: str, **kwargs):
"""
generate_prompt Generates the prompt with the given option.

:param prompt: The prompt used to generate
:param kwargs: Additional parameters for generating the prompt.
"""
raise NotImplementedError
Loading
Loading