Skip to content

Commit

Permalink
78 sdk correction on the text conversation model (#81)
Browse files Browse the repository at this point in the history
* Modification of the text conversation model; this code must be tested

* Correction of the naming

* Creation of the model transformers class; modification of all model classes and demos — needs testing

* Fix on modules and args

* fix issues on the text conv

* Removal of the options class and some corrections

* fix model.py

* fix main.py

* fix option.py

* fix demo_text_conv
  • Loading branch information
SkelNeXus authored Mar 20, 2024
1 parent 039d718 commit c1a9580
Show file tree
Hide file tree
Showing 16 changed files with 283 additions and 916 deletions.
12 changes: 6 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import argparse
from sdk.demo import DemoTextConv, DemoTxtToImg
from sdk.demo import DemoTextConv, DemoTextToImg


if __name__ == '__main__':

parser = argparse.ArgumentParser(description='Choose a Model type :'
' TextConv or TxtToImg')
' TextConv or TextToImg')
subparser = parser.add_subparsers(dest='option')

conv = subparser.add_parser('TextConv')
img = subparser.add_parser('TxtToImg')
img = subparser.add_parser('TextToImg')
args = parser.parse_args()

match args.option:
case 'TextConv':
DemoTextConv()
case ('TxtToImg'):
DemoTxtToImg()
case ('TextToImg'):
DemoTextToImg()
case _:
DemoTxtToImg()
DemoTextToImg()
2 changes: 1 addition & 1 deletion sdk/demo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# flake8: noqa F401
from .demo_text_conv import DemoTextConv
from .demo_txt_to_img import DemoTxtToImg
from .demo_text_to_img import DemoTextToImg
88 changes: 16 additions & 72 deletions sdk/demo/demo_text_conv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from sdk.models import ModelsManagement, ModelsTextConversation
from sdk.options import Devices, OptionsTextConversation
from sdk.options.options_tokenizer import OptionsTokenizer
from sdk.tokenizers.tokenizer import Tokenizer
from sdk.models import ModelsTextConversation
from sdk.options import Devices
from transformers import AutoTokenizer, AutoModelForCausalLM


class DemoTextConv:
Expand All @@ -13,75 +12,20 @@ def __init__(self):
"""
Initializes the DemoTextConv class with predefined options and models.
"""
# Define the model name and path
model_name = "Salesforce/codegen-350M-nl"
model_path = "Salesforce/codegen-350M-nl"
model_path = "microsoft/phi-2"
tokenizer_path = "microsoft/phi-2"

# Define options for text conversation
options = OptionsTextConversation(
prompt="Hello, what's 3 + 3 ?",
device=Devices.GPU,
model_name=model_name,
trust_remote_code=True,
max_length=100
)
model = ModelsTextConversation(model_name="model",
model_path=model_path,
tokenizer_path=tokenizer_path,
model_class=AutoModelForCausalLM,
tokenizer_class=AutoTokenizer,
device=Devices.GPU)

# Define tokenizer options
tokenizer_options = OptionsTokenizer(
device='cuda',
padding_side='left',
return_tensors="pt"
)
model.load_model()
model.create_new_conversation()

tokenizer = Tokenizer("Salesforce/codegen-350M-nl",
"Salesforce/codegen-350M-nl",
"Salesforce/codegen-350M-nl",
tokenizer_options)
result = model.generate_prompt(
"I'm looking for a movie - what's your favourite one?")

# Initialize the model management
model_management = ModelsManagement()

# Create and load the text conversation model
model = ModelsTextConversation(model_name, model_path, options)
model.tokenizer = tokenizer
model_management.add_model(new_model=model, model_options=options)
model_management.load_model(model_name)

# Print the response to the initial prompt
print(model_management.generate_prompt(options.prompt))

# Generate a response to a custom prompt
print(model_management.generate_prompt(
prompt="What did I say before ?"))

# Create a new conversation with a new prompt
options = OptionsTextConversation(
prompt="Hello, what's 6 + 6 ?",
device=options.device,
model_name=model_name,
batch_size=1,
chat_id_to_use=1,
minimum_tokens=50,
create_new_conv=True
)

# Switch to the new conversation
model_management.set_model_options(model_name=model_name,
options=options)
print(model_management.generate_prompt(options.prompt))

print(model_management.generate_prompt("Where is Japan"))

# Switch back to the initial conversation
options.chat_id_to_use = 0
# Create a new tokenizer and use it
options.tokenizer_id_to_use = 1
model_management.set_model_options(model_name=model_name,
options=options)
tokenizer_options = OptionsTokenizer(
device='cuda',
padding_side='right',
return_tensors='pt'
)
model.tokenizer_options = tokenizer_options
print(model_management.generate_prompt("Bye "))
print(result.messages[-1]["content"])
30 changes: 16 additions & 14 deletions sdk/demo/demo_txt_to_img.py → sdk/demo/demo_text_to_img.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
import torch
from sdk.models import ModelTextToImage, ModelsManagement
from sdk.options import Devices, OptionsTextToImage
from sdk.options import Devices


class DemoTxtToImg:
class DemoTextToImg:

def __init__(self):
options = OptionsTextToImage(
prompt="Astronaut in a jungle, cold color palette, "
"muted colors, detailed, 8k",
device=Devices.GPU,
image_width=512,
image_height=512,
)

model_stabilityai_name = "stabilityai/sdxl-turbo"
model_stabilityai_path = "stabilityai/sdxl-turbo"
model_management = ModelsManagement()
model_stabilityai = ModelTextToImage(model_stabilityai_name,
model_stabilityai_path)
model_stabilityai_path,
Devices.GPU,
torch_dtype=torch.float16,
use_safetensors=True,
add_watermarker=False,
variant="fp16")

model_management.add_model(new_model=model_stabilityai,
model_options=options)
model_management.add_model(new_model=model_stabilityai)
model_management.load_model(model_stabilityai_name)

image = model_management.generate_prompt()
image = model_management.generate_prompt(
prompt="Astronaut in a jungle, cold color palette, "
"muted colors, detailed, 8k",
image_width=512,
image_height=512
)
image.show()
1 change: 1 addition & 0 deletions sdk/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# flake8: noqa F401
from sdk.models.model import Model
from sdk.models.model_transformers import ModelTransformers
from sdk.models.models_management import ModelsManagement
from sdk.models.model_text_to_image import ModelTextToImage
from sdk.models.model_text_conversation import ModelsTextConversation
7 changes: 2 additions & 5 deletions sdk/models/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from abc import abstractmethod
from sdk.options import Options
from typing import Optional


class Model:
Expand All @@ -20,14 +18,13 @@ def __init__(self, model_name, model_path: str):
self.model_path = model_path

@abstractmethod
def load_model(self, option: Options) -> bool:
def load_model(self) -> bool:
raise NotImplementedError

@abstractmethod
def unload_model(self) -> bool:
raise NotImplementedError

@abstractmethod
def generate_prompt(self, prompt: Optional[str],
option: Options, **kwargs):
def generate_prompt(self, prompt: str, **kwargs):
raise NotImplementedError
Loading

0 comments on commit c1a9580

Please sign in to comment.