-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #86 from easy-model-fusion/develop
- Loading branch information
Showing
19 changed files
with
332 additions
and
924 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,25 @@ | ||
import argparse | ||
from sdk.demo import DemoTextConv, DemoTxtToImg | ||
from sdk.demo import DemoTextConv, DemoTextToImg, DemoTextGen | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
parser = argparse.ArgumentParser(description='Choose a Model type :' | ||
' TextConv or TxtToImg') | ||
' TextConv, TextToImg ' | ||
'or TextGen') | ||
subparser = parser.add_subparsers(dest='option') | ||
|
||
conv = subparser.add_parser('TextConv') | ||
img = subparser.add_parser('TxtToImg') | ||
img = subparser.add_parser('TextToImg') | ||
gen = subparser.add_parser('TextGen') | ||
args = parser.parse_args() | ||
|
||
match args.option: | ||
case 'TextConv': | ||
DemoTextConv() | ||
case ('TxtToImg'): | ||
DemoTxtToImg() | ||
case 'TextToImg': | ||
DemoTextToImg() | ||
case 'TextGen': | ||
DemoTextGen() | ||
case _: | ||
DemoTxtToImg() | ||
DemoTextToImg() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
# flake8: noqa F401 | ||
from .demo_text_conv import DemoTextConv | ||
from .demo_txt_to_img import DemoTxtToImg | ||
from .demo_text_to_img import DemoTextToImg | ||
from .demo_text_generation import DemoTextGen |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from sdk.options import Devices | ||
from transformers import AutoTokenizer, AutoModelForCausalLM | ||
|
||
from sdk.models import ModelTransformers | ||
|
||
|
||
class DemoTextGen: | ||
""" | ||
This class demonstrates a text conversation using a chatbot model. | ||
""" | ||
|
||
def __init__(self): | ||
""" | ||
Initializes the DemoTextConv class with predefined options and models. | ||
""" | ||
model_path = "microsoft/phi-2" | ||
tokenizer_path = "microsoft/phi-2" | ||
|
||
model_transformers = ModelTransformers( | ||
model_name="model", | ||
model_path=model_path, | ||
tokenizer_path=tokenizer_path, | ||
task="text-generation", | ||
model_class=AutoModelForCausalLM, | ||
tokenizer_class=AutoTokenizer, | ||
device=Devices.GPU | ||
) | ||
|
||
model_transformers.load_model() | ||
|
||
result = model_transformers.generate_prompt( | ||
prompt="I'm looking for a movie - what's your favourite one?", | ||
max_length=300, | ||
truncation=True | ||
) | ||
|
||
print(result) |
30 changes: 16 additions & 14 deletions
30
sdk/demo/demo_txt_to_img.py → sdk/demo/demo_text_to_img.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,29 @@ | ||
import torch | ||
from sdk.models import ModelTextToImage, ModelsManagement | ||
from sdk.options import Devices, OptionsTextToImage | ||
from sdk.options import Devices | ||
|
||
|
||
class DemoTxtToImg: | ||
class DemoTextToImg: | ||
|
||
def __init__(self): | ||
options = OptionsTextToImage( | ||
prompt="Astronaut in a jungle, cold color palette, " | ||
"muted colors, detailed, 8k", | ||
device=Devices.GPU, | ||
image_width=512, | ||
image_height=512, | ||
) | ||
|
||
model_stabilityai_name = "stabilityai/sdxl-turbo" | ||
model_stabilityai_path = "stabilityai/sdxl-turbo" | ||
model_management = ModelsManagement() | ||
model_stabilityai = ModelTextToImage(model_stabilityai_name, | ||
model_stabilityai_path) | ||
model_stabilityai_path, | ||
Devices.GPU, | ||
torch_dtype=torch.float16, | ||
use_safetensors=True, | ||
add_watermarker=False, | ||
variant="fp16") | ||
|
||
model_management.add_model(new_model=model_stabilityai, | ||
model_options=options) | ||
model_management.add_model(new_model=model_stabilityai) | ||
model_management.load_model(model_stabilityai_name) | ||
|
||
image = model_management.generate_prompt() | ||
image = model_management.generate_prompt( | ||
prompt="Astronaut in a jungle, cold color palette, " | ||
"muted colors, detailed, 8k", | ||
image_width=512, | ||
image_height=512 | ||
) | ||
image.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# flake8: noqa F401 | ||
from sdk.models.model import Model | ||
from sdk.models.model_transformers import ModelTransformers | ||
from sdk.models.models_management import ModelsManagement | ||
from sdk.models.model_text_to_image import ModelTextToImage | ||
from sdk.models.model_text_conversation import ModelsTextConversation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.