
Commit

Merge branch 'develop' of github.com:easy-model-fusion/sdk
SkelNeXus committed Mar 22, 2024
2 parents 02b72e3 + 8c36040 commit 5fd59aa
Showing 6 changed files with 77 additions and 19 deletions.
main.py (6 additions, 2 deletions)
@@ -1,17 +1,19 @@
 import argparse
-from sdk.demo import DemoTextConv, DemoTextToImg, DemoTextGen
+from sdk.demo import DemoTextConv, DemoTextToImg, DemoTextGen, DemoTextToVideo


 if __name__ == '__main__':

     parser = argparse.ArgumentParser(description='Choose a Model type :'
-                                                 ' TextConv, TextToImg '
+                                                 ' TextConv, TextToImg, '
+                                                 'TextToVideo '
                                                  'or TextGen')
     subparser = parser.add_subparsers(dest='option')

     conv = subparser.add_parser('TextConv')
     img = subparser.add_parser('TextToImg')
     gen = subparser.add_parser('TextGen')
+    video = subparser.add_parser('TextToVideo')
     args = parser.parse_args()

     match args.option:
@@ -21,5 +23,7 @@
             DemoTextToImg()
         case 'TextGen':
             DemoTextGen()
+        case 'TextToVideo':
+            DemoTextToVideo()
         case _:
             DemoTextToImg()
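
Note: the demo classes do all of their work in __init__, so the new option can be exercised either through the CLI (e.g. python main.py TextToVideo, assuming main.py stays at the repository root) or directly from Python. A minimal sketch, not part of the commit:

from sdk.demo import DemoTextToVideo

# Instantiating the class runs the whole text-to-video demo, mirroring the
# 'TextToVideo' case added above.
DemoTextToVideo()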
sdk/demo/__init__.py (1 addition, 0 deletions)
@@ -2,3 +2,4 @@
 from .demo_text_conv import DemoTextConv
 from .demo_text_to_img import DemoTextToImg
 from .demo_text_generation import DemoTextGen
+from .demo_text_to_video import DemoTextToVideo
sdk/demo/demo_text_to_img.py (17 additions, 9 deletions)
@@ -1,5 +1,6 @@
 import torch
-from sdk.models import ModelTextToImage, ModelsManagement
+from diffusers import DiffusionPipeline
+from sdk.models import ModelDiffusers, ModelsManagement
 from sdk.options import Devices


@@ -8,14 +9,21 @@ class DemoTextToImg:
     def __init__(self):
         model_stabilityai_name = "stabilityai/sdxl-turbo"
         model_stabilityai_path = "stabilityai/sdxl-turbo"

+        model_options = {
+            'torch_dtype': torch.float16,
+            'use_safetensors': True,
+            'add_watermarker': False,
+            'variant': "fp16"
+        }
+
         model_management = ModelsManagement()
-        model_stabilityai = ModelTextToImage(model_stabilityai_name,
-                                             model_stabilityai_path,
-                                             Devices.GPU,
-                                             torch_dtype=torch.float16,
-                                             use_safetensors=True,
-                                             add_watermarker=False,
-                                             variant="fp16")
+        model_stabilityai = ModelDiffusers(
+            model_name=model_stabilityai_name,
+            model_path=model_stabilityai_path,
+            model_class=DiffusionPipeline,
+            device=Devices.GPU,
+            **model_options)
+
         model_management.add_model(new_model=model_stabilityai)
         model_management.load_model(model_stabilityai_name)
@@ -25,5 +33,5 @@ def __init__(self):
             "muted colors, detailed, 8k",
             image_width=512,
             image_height=512
-        )
+        ).images[0]
         image.show()
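
Note: as the model_diffusers.py diff further down shows, generate_prompt() now returns the raw diffusers pipeline output rather than pre-selecting an image, which is why .images[0] moves to the call site here. A minimal sketch of the resulting calling pattern (names taken from this diff, otherwise hypothetical):

result = model_management.generate_prompt(
    "Astronaut in a jungle, cold color palette, "
    "muted colors, detailed, 8k",
    image_width=512,
    image_height=512
)
image = result.images[0]  # text-to-image pipelines expose the generated images here
image.show()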
sdk/demo/demo_text_to_video.py (new file, 43 additions)
@@ -0,0 +1,43 @@
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.utils import export_to_video
+
+from sdk.models import ModelDiffusers, ModelsManagement
+from sdk.options import Devices
+
+
+class DemoTextToVideo:
+
+    def __init__(self):
+        model_name = "damo-vilab/text-to-video-ms-1.7b"
+        model_path = "damo-vilab/text-to-video-ms-1.7b"
+
+        model_options = {
+            'torch_dtype': torch.float16,
+            'use_safetensors': True,
+            'add_watermarker': False,
+            'variant': "fp16"
+        }
+
+        model_management = ModelsManagement()
+        model = ModelDiffusers(
+            model_name=model_name,
+            model_path=model_path,
+            model_class=DiffusionPipeline,
+            device=Devices.GPU,
+            **model_options)
+
+        model_management.add_model(new_model=model)
+        model_management.load_model(model_name)
+
+        prompt = ("Astronaut in a jungle, cold color palette,"
+                  " muted colors, detailed, 8k")
+
+        video_frames = model_management.generate_prompt(
+            prompt=prompt,
+            num_inference_steps=25
+        ).frames
+
+        video_path = export_to_video(video_frames)
+
+        print(video_path)
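
Note: export_to_video from diffusers.utils encodes the frames to an .mp4 and returns its path; with no output path it writes to a temporary file, which is what gets printed here. An explicit target can be supplied instead, for example (hypothetical filename):

video_path = export_to_video(video_frames, "astronaut_jungle.mp4")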
sdk/models/__init__.py (1 addition, 1 deletion)
@@ -2,5 +2,5 @@
 from sdk.models.model import Model
 from sdk.models.model_transformers import ModelTransformers
 from sdk.models.models_management import ModelsManagement
-from sdk.models.model_text_to_image import ModelTextToImage
+from sdk.models.model_diffusers import ModelDiffusers
 from sdk.models.model_text_conversation import ModelsTextConversation
sdk/models/model_text_to_image.py → sdk/models/model_diffusers.py (9 additions, 7 deletions)
@@ -1,26 +1,28 @@
 import torch
-from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
 from sdk.options import Devices
-from typing import Union
+from typing import Union, Any
 from sdk.models import Model


-class ModelTextToImage(Model):
+class ModelDiffusers(Model):
     """
     This class implements methods to generate images with a text prompt
     """
-    pipeline: StableDiffusionXLPipeline
+    model_class: Any
+    pipeline: str
     loaded: bool
     device: Union[str, Devices]

     def __init__(self, model_name: str, model_path: str,
+                 model_class: Any,
                  device: Union[str, Devices], **kwargs):
         """
         Initializes the ModelsTextToImage class
         :param model_name: The name of the model
         :param model_path: The path of the model
         """
         super().__init__(model_name, model_path)
+        self.model_class = model_class
         self.device = device
         self.loaded = False
         self.create_pipeline(**kwargs)
@@ -32,7 +34,7 @@ def create_pipeline(self, **kwargs):
         if self.loaded:
             return

-        self.pipeline = DiffusionPipeline.from_pretrained(
+        self.pipeline = self.model_class.from_pretrained(
             self.model_path,
             **kwargs
         )
@@ -59,7 +61,7 @@ def unload_model(self) -> bool:
         """
         if not self.loaded:
             return False
-        self.pipeline.to(device=Devices.RESET)
+        self.pipeline.to(device=Devices.RESET.value)
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
         self.loaded = False
@@ -75,4 +77,4 @@ def generate_prompt(self, prompt: str,
         return self.pipeline(
             prompt=prompt,
             **kwargs
-        ).images[0]
+        )
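
Note: with model_class injected, the renamed ModelDiffusers wrapper is no longer tied to a single pipeline type; create_pipeline() simply calls from_pretrained() on whatever class it is given. A hypothetical sketch (not part of the commit) wrapping a different diffusers pipeline class:

import torch
from diffusers import StableDiffusionXLPipeline  # any class exposing from_pretrained()

from sdk.models import ModelDiffusers, ModelsManagement
from sdk.options import Devices

# Hypothetical usage: the wrapper forwards these kwargs to
# StableDiffusionXLPipeline.from_pretrained() when building the pipeline.
sdxl = ModelDiffusers(
    model_name="stabilityai/sdxl-turbo",
    model_path="stabilityai/sdxl-turbo",
    model_class=StableDiffusionXLPipeline,
    device=Devices.GPU,
    torch_dtype=torch.float16,
    variant="fp16")

manager = ModelsManagement()
manager.add_model(new_model=sdxl)
manager.load_model("stabilityai/sdxl-turbo")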
