-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
78 sdk correction on the text conversation model (#81)
* Modification of the model text conversation, this code must be test * Correction of the naming * Creation of the model transformers class, modification of all model class and demo, need test * Fix on modules and args * fix issues on the text conv * Suppression of options class and some corrections * fix model.py * fix main.py * fix option.py * fix demo_text_conv
- Loading branch information
Showing
16 changed files
with
283 additions
and
916 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,21 @@ | ||
import argparse | ||
from sdk.demo import DemoTextConv, DemoTxtToImg | ||
from sdk.demo import DemoTextConv, DemoTextToImg | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
parser = argparse.ArgumentParser(description='Choose a Model type :' | ||
' TextConv or TxtToImg') | ||
' TextConv or TextToImg') | ||
subparser = parser.add_subparsers(dest='option') | ||
|
||
conv = subparser.add_parser('TextConv') | ||
img = subparser.add_parser('TxtToImg') | ||
img = subparser.add_parser('TextToImg') | ||
args = parser.parse_args() | ||
|
||
match args.option: | ||
case 'TextConv': | ||
DemoTextConv() | ||
case ('TxtToImg'): | ||
DemoTxtToImg() | ||
case ('TextToImg'): | ||
DemoTextToImg() | ||
case _: | ||
DemoTxtToImg() | ||
DemoTextToImg() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
# flake8: noqa F401 | ||
from .demo_text_conv import DemoTextConv | ||
from .demo_txt_to_img import DemoTxtToImg | ||
from .demo_text_to_img import DemoTextToImg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
30 changes: 16 additions & 14 deletions
30
sdk/demo/demo_txt_to_img.py → sdk/demo/demo_text_to_img.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,29 @@ | ||
import torch | ||
from sdk.models import ModelTextToImage, ModelsManagement | ||
from sdk.options import Devices, OptionsTextToImage | ||
from sdk.options import Devices | ||
|
||
|
||
class DemoTxtToImg: | ||
class DemoTextToImg: | ||
|
||
def __init__(self): | ||
options = OptionsTextToImage( | ||
prompt="Astronaut in a jungle, cold color palette, " | ||
"muted colors, detailed, 8k", | ||
device=Devices.GPU, | ||
image_width=512, | ||
image_height=512, | ||
) | ||
|
||
model_stabilityai_name = "stabilityai/sdxl-turbo" | ||
model_stabilityai_path = "stabilityai/sdxl-turbo" | ||
model_management = ModelsManagement() | ||
model_stabilityai = ModelTextToImage(model_stabilityai_name, | ||
model_stabilityai_path) | ||
model_stabilityai_path, | ||
Devices.GPU, | ||
torch_dtype=torch.float16, | ||
use_safetensors=True, | ||
add_watermarker=False, | ||
variant="fp16") | ||
|
||
model_management.add_model(new_model=model_stabilityai, | ||
model_options=options) | ||
model_management.add_model(new_model=model_stabilityai) | ||
model_management.load_model(model_stabilityai_name) | ||
|
||
image = model_management.generate_prompt() | ||
image = model_management.generate_prompt( | ||
prompt="Astronaut in a jungle, cold color palette, " | ||
"muted colors, detailed, 8k", | ||
image_width=512, | ||
image_height=512 | ||
) | ||
image.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# flake8: noqa F401 | ||
from sdk.models.model import Model | ||
from sdk.models.model_transformers import ModelTransformers | ||
from sdk.models.models_management import ModelsManagement | ||
from sdk.models.model_text_to_image import ModelTextToImage | ||
from sdk.models.model_text_conversation import ModelsTextConversation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.