Correction to match with the linter (#47)
* Correction to match with the linter

* fix
SkelNeXus authored Feb 20, 2024
1 parent 778bead commit ce23ed7
Showing 4 changed files with 146 additions and 87 deletions.
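
Every visible change wraps a line that exceeded the linter's length limit. The commit does not name the linter, but each rewrapped line now fits the 79-character maximum that PEP 8 specifies and that flake8 enforces by default as rule E501. A minimal sketch of such a setup, assuming flake8 (the tool and file name are assumptions, not stated in the commit):

    # setup.cfg -- hypothetical flake8 configuration; the repository's
    # actual linter and settings are not shown on this page.
    [flake8]
    max-line-length = 79

Running "python -m flake8 models/" against the pre-commit code would then report E501 on each of the long lines removed in the hunks below.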
15 changes: 10 additions & 5 deletions models/model_text_to_image.py
@@ -64,10 +64,12 @@ def unload_model(self) -> bool:
         self.loaded = False
         return True
 
-    def generate_prompt(self, prompt: Optional[str], options: OptionsTextToImage):
+    def generate_prompt(self, prompt: Optional[str],
+                        options: OptionsTextToImage):
         """
         Generates the prompt with the given option
-        :param prompt: The optional prompt (if the prompt is empty, the options.prompt will be used)
+        :param prompt: The optional prompt
+        (if the prompt is empty, the options.prompt will be used)
         :param options: The options of text to image model
         :return: An object image resulting from the model
         """
@@ -89,7 +91,8 @@ def generate_prompt(self, prompt: Optional[str], options: OptionsTextToImage):
             prompt_embeds=options.prompt_embeds,
             negative_prompt_embeds=options.negative_prompt_embeds,
             pooled_prompt_embeds=options.pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=options.negative_pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=(
+                options.negative_pooled_prompt_embeds),
             ip_adapter_image=options.ip_adapter_image,
             output_type=options.output_type,
             return_dict=options.return_dict,
@@ -99,9 +102,11 @@ def generate_prompt(self, prompt: Optional[str], options: OptionsTextToImage):
             crops_coords_top_left=options.crops_coords_top_left,
             target_size=options.target_size,
             negative_original_size=options.negative_original_size,
-            negative_crops_coords_top_left=options.negative_crops_coords_top_left,
+            negative_crops_coords_top_left=(
+                options.negative_crops_coords_top_left),
             negative_target_size=options.negative_target_size,
             clip_skip=options.clip_skip,
             callback_on_step_end=options.callback_on_step_end,
-            callback_on_step_end_tensor_inputs=options.callback_on_step_end_tensor_inputs
+            callback_on_step_end_tensor_inputs=(
+                options.callback_on_step_end_tensor_inputs)
         ).images[0]
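
For context, a short usage sketch of the model class this file defines. Only unload_model, generate_prompt, OptionsTextToImage, and a load_model(option=...) call (visible in models_management.py below) appear in the diff; the class name ModelTextToImage, its constructor, and the option fields are assumptions:

    # Hypothetical usage sketch; names flagged in comments are not in the diff.
    model = ModelTextToImage()                    # assumed class name/constructor
    options = OptionsTextToImage(
        prompt="A sunset over the sea")           # assumed option field
    model.load_model(option=options)              # keyword mirrors models_management.py
    image = model.generate_prompt(None, options)  # empty prompt falls back to options.prompt
    model.unload_model()                          # returns True on success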
22 changes: 16 additions & 6 deletions models/models_management.py
@@ -6,7 +6,8 @@
 class ModelsManagement:
     """
     The ModelsManagement class controls all instantiated models.
-    It is with this class that you can deploy a model on a device and generate a prompt.
+    It is with this class that you can deploy a model on a device and
+    generate a prompt.
     """
 
     def __init__(self):
@@ -39,15 +40,17 @@ def load_model(self, model_name: str) -> bool:
         :return: True if the model is successfully loaded
         """
         if self.loaded_model:
-            print("Unload the currently loaded model before loading a new one.")
+            print(
+                "Unload the currently loaded model before loading a new one.")
             return False
 
         if model_name not in self.loaded_models_cache:
             print(f"Model '{model_name}' cannot be loaded: not found.")
             return False
 
         self.loaded_model = self.loaded_models_cache[model_name]
-        if not self.loaded_model.load_model(option=self.options_models[model_name]):
+        if not self.loaded_model.load_model(option=(
+                self.options_models[model_name])):
             print("Something went wrong while unloading the model.")
             self.loaded_model = None
             return False
@@ -88,20 +91,27 @@ def set_model_options(self, model_name: str, options: Options):
     def generate_prompt(self, prompt: Optional[str] = None):
         """
         Generates the prompt for the loaded model with his stored options
-        :param prompt: The prompt to generate (if the prompt is empty, the options.prompt will be used)
+        :param prompt: The prompt to generate (if the prompt is empty, the
+        options.prompt will be used)
         :return: The object of type link with the model category
         """
         if not self.loaded_model:
             print("No model loaded. Load a model before generating prompts.")
             return
 
-        return self.loaded_model.generate_prompt(prompt, self.options_models[self.loaded_model.model_name])
+        return (
+            self.loaded_model.generate_prompt(prompt,
+                                              self.options_models[
+                                                  self.loaded_model.model_name]
+                                              )
+        )
 
     def print_models(self):
         """
         Prints all models in the cache
         """
         print("Models in cache:")
         for model_name, model_instance in self.loaded_models_cache.items():
-            selected_indicator = "(selected)" if model_instance == self.loaded_model else ""
+            selected_indicator = (
+                "(selected)" if model_instance == self.loaded_model else "")
             print(f"- {model_name} {selected_indicator}")