Skip to content

Commit

Permalink
#42 sdk create the generate model schematic to prepare the model gene…
Browse files Browse the repository at this point in the history
…rator (#49)

* Add path to model and kargs to generate method

* Creation of __init__.py in all directories

* Change all import to use the __init__.py mecanisme

* Fix bug import class in the same directory

* add generated model into the init
  • Loading branch information
SkelNeXus authored Feb 21, 2024
1 parent efdb67f commit 3303cd7
Show file tree
Hide file tree
Showing 16 changed files with 106 additions and 88 deletions.
1 change: 1 addition & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .generated_models import *
2 changes: 2 additions & 0 deletions demo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .demo_text_conv import DemoTextConv
from .demo_txt_to_img import DemoTxtToImg
81 changes: 40 additions & 41 deletions demo/demo_text_conv.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import torch
from models.models_management import ModelsManagement
from models.model_text_conversation import ModelsTextConversation
from options.options import Devices
from options.options_text_conversation import OptionsTextConversation
from models import ModelsManagement, ModelsTextConversation
from options import Devices, OptionsTextConversation


class DemoTextConv:

def __init__(self):
model_name = "facebook/blenderbot-400M-distill"
model_path = "facebook/blenderbot-400M-distill"

options = OptionsTextConversation(
prompt="Hello",
Expand All @@ -19,52 +18,52 @@ def __init__(self):
)

model_management = ModelsManagement()
model = ModelsTextConversation(model_name)
model = ModelsTextConversation(model_name, model_path)

model_management.add_model(new_model=model, model_options=options)
model_management.load_model(model_name)

model_management.generate_prompt(options.prompt)
self.demo(model_management.loaded_model)
if model_management.loaded_model:
self.demo(model_management.loaded_model)

def demo(self, model: ModelsTextConversation):
print("Model name : ", model.model_name)
print("User : Hello")

if model:
new_user_input_ids = model.tokenizer.encode(
"Hello ! "
+ model.tokenizer.eos_token,
return_tensors='pt')
# append the new user input tokens to the chat history
bot_input_ids = torch.cat(
[model.chat_history_ids,
new_user_input_ids],
dim=-1) if model.conversation_step > 0 else (
new_user_input_ids)
new_user_input_ids = model.tokenizer.encode(
"Hello ! "
+ model.tokenizer.eos_token,
return_tensors='pt')
# append the new user input tokens to the chat history
bot_input_ids = torch.cat(
[model.chat_history_ids,
new_user_input_ids],
dim=-1) if model.conversation_step > 0 else (
new_user_input_ids)

# generated a response while limiting the total
# chat history to 1000 tokens,
chat_history_ids = model.pipeline.generate(
bot_input_ids, max_length=1000,
pad_token_id=model.tokenizer.eos_token_id)
print("Chatbot: {}".format(
model.tokenizer.decode(
chat_history_ids[:, bot_input_ids.shape[-1]:][0],
skip_special_tokens=True)))
print("User : How are you ?")
new_user_input_ids = model.tokenizer.encode(
"How are you ? " + model.tokenizer.eos_token,
return_tensors='pt')
bot_input_ids = torch.cat(
[model.chat_history_ids, new_user_input_ids],
dim=-1) if model.conversation_step > 0 else new_user_input_ids
chat_history_ids = model.pipeline.generate(
bot_input_ids,
max_length=1000,
pad_token_id=model.tokenizer.eos_token_id)
# generated a response while limiting the total
# chat history to 1000 tokens,
chat_history_ids = model.pipeline.generate(
bot_input_ids, max_length=1000,
pad_token_id=model.tokenizer.eos_token_id)
print("Chatbot: {}".format(
model.tokenizer.decode(
chat_history_ids[:, bot_input_ids.shape[-1]:][0],
skip_special_tokens=True)))
print("User : How are you ?")
new_user_input_ids = model.tokenizer.encode(
"How are you ? " + model.tokenizer.eos_token,
return_tensors='pt')
bot_input_ids = torch.cat(
[model.chat_history_ids, new_user_input_ids],
dim=-1) if model.conversation_step > 0 else new_user_input_ids
chat_history_ids = model.pipeline.generate(
bot_input_ids,
max_length=1000,
pad_token_id=model.tokenizer.eos_token_id)

print("Chatbot: {}".format(
model.tokenizer.decode(
chat_history_ids[:, bot_input_ids.shape[-1]:][0],
skip_special_tokens=True)))
print("Chatbot: {}".format(
model.tokenizer.decode(
chat_history_ids[:, bot_input_ids.shape[-1]:][0],
skip_special_tokens=True)))
15 changes: 7 additions & 8 deletions demo/demo_txt_to_img.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
from models.model_text_to_image import ModelTextToImage
from models.models_management import ModelsManagement
from models import ModelTextToImage, ModelsManagement
from options import Devices, OptionsTextToImage

from options.options import Devices
from options.options_text_to_image import OptionsTextToImage


class DemoMainTxtToImg():
class DemoTxtToImg:

def __init__(self):
options = OptionsTextToImage(
prompt="Astronaut in a jungle, cold color palette, "
"muted colors, detailed, 8k",
device=Devices.GPU,
image_width=512,
image_height=512
image_height=512,
)

model_stabilityai_name = "stabilityai/sdxl-turbo"
model_stabilityai_path = "stabilityai/sdxl-turbo"
model_management = ModelsManagement()
model_stabilityai = ModelTextToImage(model_stabilityai_name)
model_stabilityai = ModelTextToImage(model_stabilityai_name,
model_stabilityai_path)

model_management.add_model(new_model=model_stabilityai,
model_options=options)
Expand Down
3 changes: 3 additions & 0 deletions generated_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
DO NOT EDIT, THIS FILE WILL BE GENERATED AUTOMATICALLY
"""
Empty file added generated_models/__init__.py
Empty file.
8 changes: 4 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from demo.demo_text_conv import DemoTextConv
from demo.demo_txt_to_img import DemoMainTxtToImg
from demo import DemoTextConv, DemoTxtToImg


if __name__ == '__main__':

Expand All @@ -16,6 +16,6 @@
case 'TextConv':
DemoTextConv()
case ('TxtToImg'):
DemoMainTxtToImg()
DemoTxtToImg()
case _:
DemoMainTxtToImg()
DemoTxtToImg()
4 changes: 4 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .model_text_to_image import ModelTextToImage
from .models_management import ModelsManagement
from .model_text_conversation import ModelsTextConversation
from .model import Model
10 changes: 7 additions & 3 deletions models/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from options.options import Options
from options import Options
from typing import Optional


Expand All @@ -8,13 +8,16 @@ class Model:
Abstract base class for all models
"""
model_name: str
model_path: str

def __init__(self, model_name: str):
def __init__(self, model_name, model_path: str):
"""
Initializes the model with the given name
:param model_name: The name of the model
:param model_path: The path of the model
"""
self.model_name = model_name
self.model_path = model_path

@abstractmethod
def load_model(self, option: Options) -> bool:
Expand All @@ -25,5 +28,6 @@ def unload_model(self) -> bool:
raise NotImplementedError

@abstractmethod
def generate_prompt(self, prompt: Optional[str], option: Options):
def generate_prompt(self, prompt: Optional[str],
option: Options, **kwargs):
raise NotImplementedError
25 changes: 12 additions & 13 deletions models/model_text_conversation.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,26 @@
from typing import Optional

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
from typing import Optional
from transformers import (AutoModel, AutoTokenizer,
ConversationalPipeline, Conversation)
from models.model import Model
from options.options import Devices
from options.options_text_conversation import OptionsTextConversation
from .model import Model
from options import Devices, OptionsTextConversation


class ModelsTextConversation(Model):
pipeline: ConversationalPipeline
tokenizer: AutoTokenizer
model_name: str
loaded: bool
chat_bot: Conversation
conversation_step: int = 0
chat_history_token_ids = []

def __init__(self, model_name: str):
def __init__(self, model_name: str, model_path: str):
"""
Initializes the ModelsTextToImage class
:param model_name: The name of the model
:param model_path: The path of the model
"""
super().__init__(model_name)
super().__init__(model_name, model_path)
self.loaded = False
self.create_pipeline()

Expand All @@ -33,9 +31,9 @@ def create_pipeline(self):
if self.loaded:
return

self.pipeline = AutoModelForCausalLM.from_pretrained(self.model_name)
self.pipeline = AutoModel.from_pretrained(self.model_path)
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
self.model_path,
trust_remote_code=True,
padding_side='left')

Expand Down Expand Up @@ -69,6 +67,7 @@ def unload_model(self):

def generate_prompt(
self, prompt: Optional[str],
option: OptionsTextConversation):
option: OptionsTextConversation,
**kwargs):
prompt = prompt if prompt else option.prompt
return Conversation(prompt)
return Conversation(prompt, **kwargs)
20 changes: 11 additions & 9 deletions models/model_text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
import torch
from models.model import Model
from options.options_text_to_image import OptionsTextToImage, Devices
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from options import OptionsTextToImage, Devices
from typing import Optional
from .model import Model


class ModelTextToImage(Model):
"""
This class implements methods to generate images with a text prompt
"""
pipeline: StableDiffusionXLPipeline
model_name: str
loaded: bool

def __init__(self, model_name: str):
def __init__(self, model_name: str, model_path: str):
"""
Initializes the ModelsTextToImage class
:param model_name: The name of the model
:param model_path: The path of the model
"""
super().__init__(model_name)
super().__init__(model_name, model_path)
self.loaded = False
self.create_pipeline()

Expand All @@ -30,7 +30,7 @@ def create_pipeline(self):
return

self.pipeline = DiffusionPipeline.from_pretrained(
self.model_name,
self.model_path,
torch_dtype=torch.float16,
use_safetensors=True,
add_watermarker=False,
Expand Down Expand Up @@ -65,7 +65,8 @@ def unload_model(self) -> bool:
return True

def generate_prompt(self, prompt: Optional[str],
options: OptionsTextToImage):
options: OptionsTextToImage,
**kwargs):
"""
Generates the prompt with the given option
:param prompt: The optional prompt
Expand Down Expand Up @@ -108,5 +109,6 @@ def generate_prompt(self, prompt: Optional[str],
clip_skip=options.clip_skip,
callback_on_step_end=options.callback_on_step_end,
callback_on_step_end_tensor_inputs=(
options.callback_on_step_end_tensor_inputs)
options.callback_on_step_end_tensor_inputs),
**kwargs
).images[0]
16 changes: 9 additions & 7 deletions models/models_management.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, Dict
from models.model import Model
from options.options import Options
from .model import Model
from options import Options


class ModelsManagement:
Expand Down Expand Up @@ -88,7 +88,7 @@ def set_model_options(self, model_name: str, options: Options):
"""
self.options_models[model_name] = options

def generate_prompt(self, prompt: Optional[str] = None):
def generate_prompt(self, prompt: Optional[str] = None, **kwargs):
"""
Generates the prompt for the loaded model with his stored options
:param prompt: The prompt to generate (if the prompt is empty, the
Expand All @@ -100,10 +100,12 @@ def generate_prompt(self, prompt: Optional[str] = None):
return

return (
self.loaded_model.generate_prompt(prompt,
self.options_models[
self.loaded_model.model_name]
)
self.loaded_model.generate_prompt(
prompt,
self.options_models[
self.loaded_model.model_name],
**kwargs
)
)

def print_models(self):
Expand Down
3 changes: 3 additions & 0 deletions options/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .options_text_to_image import OptionsTextToImage
from .options_text_conversation import OptionsTextConversation
from .options import Options, Devices
2 changes: 1 addition & 1 deletion options/options_text_conversation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from options.options import Options, Devices
import torch
from .options import Options, Devices
from typing import Optional, Union, Dict, Any
from transformers import (PreTrainedModel, TFPreTrainedModel,
PreTrainedTokenizer,
Expand Down
4 changes: 2 additions & 2 deletions options/options_text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, Union, List, Dict, Tuple, Any, Callable
from diffusers.image_processor import PipelineImageInput
from options.options import Options, Devices
from .options import Options, Devices
import torch


Expand Down Expand Up @@ -43,7 +43,7 @@ class OptionsTextToImage(Options):

def __init__(
self,
device: Devices,
device: Union[str, Devices],
prompt: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
image_width: Optional[int] = None,
Expand Down
Empty file added tests/__init__.py
Empty file.

0 comments on commit 3303cd7

Please sign in to comment.