diff --git a/main.py b/main.py
index a2e4712..8c490c6 100644
--- a/main.py
+++ b/main.py
@@ -1,17 +1,19 @@
 import argparse
-from sdk.demo import DemoTextConv, DemoTextToImg, DemoTextGen
+from sdk.demo import DemoTextConv, DemoTextToImg, DemoTextGen, DemoTextToVideo
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Choose a Model type :'
-                                                 ' TextConv, TextToImg '
+                                                 ' TextConv, TextToImg, '
+                                                 'TextToVideo '
                                                  'or TextGen')
 
     subparser = parser.add_subparsers(dest='option')
 
     conv = subparser.add_parser('TextConv')
     img = subparser.add_parser('TextToImg')
     gen = subparser.add_parser('TextGen')
+    video = subparser.add_parser('TextToVideo')
 
     args = parser.parse_args()
     match args.option:
@@ -21,5 +23,7 @@
             DemoTextToImg()
         case 'TextGen':
             DemoTextGen()
+        case 'TextToVideo':
+            DemoTextToVideo()
         case _:
             DemoTextToImg()
diff --git a/sdk/demo/__init__.py b/sdk/demo/__init__.py
index c5f7a15..74c2371 100644
--- a/sdk/demo/__init__.py
+++ b/sdk/demo/__init__.py
@@ -2,3 +2,4 @@
 from .demo_text_conv import DemoTextConv
 from .demo_text_to_img import DemoTextToImg
 from .demo_text_generation import DemoTextGen
+from .demo_text_to_video import DemoTextToVideo
diff --git a/sdk/demo/demo_text_to_img.py b/sdk/demo/demo_text_to_img.py
index bb253dd..c602aa6 100644
--- a/sdk/demo/demo_text_to_img.py
+++ b/sdk/demo/demo_text_to_img.py
@@ -1,5 +1,6 @@
 import torch
-from sdk.models import ModelTextToImage, ModelsManagement
+from diffusers import DiffusionPipeline
+from sdk.models import ModelDiffusers, ModelsManagement
 from sdk.options import Devices
 
 
@@ -8,14 +9,21 @@ class DemoTextToImg:
     def __init__(self):
         model_stabilityai_name = "stabilityai/sdxl-turbo"
         model_stabilityai_path = "stabilityai/sdxl-turbo"
+
+        model_options = {
+            'torch_dtype': torch.float16,
+            'use_safetensors': True,
+            'add_watermarker': False,
+            'variant': "fp16"
+        }
+
         model_management = ModelsManagement()
-        model_stabilityai = ModelTextToImage(model_stabilityai_name,
-                                             model_stabilityai_path,
-                                             Devices.GPU,
-                                             torch_dtype=torch.float16,
-                                             use_safetensors=True,
-                                             add_watermarker=False,
-                                             variant="fp16")
+        model_stabilityai = ModelDiffusers(
+            model_name=model_stabilityai_name,
+            model_path=model_stabilityai_path,
+            model_class=DiffusionPipeline,
+            device=Devices.GPU,
+            **model_options)
         model_management.add_model(new_model=model_stabilityai)
         model_management.load_model(model_stabilityai_name)
 
@@ -25,5 +33,5 @@ def __init__(self):
                 "muted colors, detailed, 8k",
             image_width=512,
             image_height=512
-        )
+        ).images[0]
         image.show()
diff --git a/sdk/demo/demo_text_to_video.py b/sdk/demo/demo_text_to_video.py
new file mode 100644
index 0000000..13d92ba
--- /dev/null
+++ b/sdk/demo/demo_text_to_video.py
@@ -0,0 +1,43 @@
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.utils import export_to_video
+
+from sdk.models import ModelDiffusers, ModelsManagement
+from sdk.options import Devices
+
+
+class DemoTextToVideo:
+
+    def __init__(self):
+        model_name = "damo-vilab/text-to-video-ms-1.7b"
+        model_path = "damo-vilab/text-to-video-ms-1.7b"
+
+        model_options = {
+            'torch_dtype': torch.float16,
+            'use_safetensors': True,
+            'add_watermarker': False,
+            'variant': "fp16"
+        }
+
+        model_management = ModelsManagement()
+        model = ModelDiffusers(
+            model_name=model_name,
+            model_path=model_path,
+            model_class=DiffusionPipeline,
+            device=Devices.GPU,
+            **model_options)
+
+        model_management.add_model(new_model=model)
+        model_management.load_model(model_name)
+
+        prompt = ("Astronaut in a jungle, cold color palette,"
+                  " muted colors, detailed, 8k")
+
+        video_frames = model_management.generate_prompt(
+            prompt=prompt,
+            num_inference_steps=25
+        ).frames
+
+        video_path = export_to_video(video_frames)
+
+        print(video_path)
diff --git a/sdk/models/__init__.py b/sdk/models/__init__.py
index 86152b0..a1b47cb 100644
--- a/sdk/models/__init__.py
+++ b/sdk/models/__init__.py
@@ -2,5 +2,5 @@
 from sdk.models.model import Model
 from sdk.models.model_transformers import ModelTransformers
 from sdk.models.models_management import ModelsManagement
-from sdk.models.model_text_to_image import ModelTextToImage
+from sdk.models.model_diffusers import ModelDiffusers
 from sdk.models.model_text_conversation import ModelsTextConversation
diff --git a/sdk/models/model_text_to_image.py b/sdk/models/model_diffusers.py
similarity index 87%
rename from sdk/models/model_text_to_image.py
rename to sdk/models/model_diffusers.py
index 91e8ed6..7a6ff8d 100644
--- a/sdk/models/model_text_to_image.py
+++ b/sdk/models/model_diffusers.py
@@ -1,19 +1,20 @@
 import torch
-from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
 from sdk.options import Devices
-from typing import Union
+from typing import Union, Any
 from sdk.models import Model
 
-class ModelTextToImage(Model):
+class ModelDiffusers(Model):
     """
     This class implements methods to generate images with a text prompt
     """
-    pipeline: StableDiffusionXLPipeline
+    model_class: Any
+    pipeline: Any
     loaded: bool
     device: Union[str, Devices]
 
     def __init__(self, model_name: str, model_path: str,
+                 model_class: Any,
                  device: Union[str, Devices],
                  **kwargs):
         """
         Initializes the ModelsTextToImage class
@@ -21,6 +22,7 @@ def __init__(self, model_name: str, model_path: str,
         :param model_path: The path of the model
         """
         super().__init__(model_name, model_path)
+        self.model_class = model_class
         self.device = device
         self.loaded = False
         self.create_pipeline(**kwargs)
@@ -32,7 +34,7 @@ def create_pipeline(self, **kwargs):
         """
         if self.loaded:
             return
-        self.pipeline = DiffusionPipeline.from_pretrained(
+        self.pipeline = self.model_class.from_pretrained(
             self.model_path,
             **kwargs
         )
@@ -59,7 +61,7 @@ def unload_model(self) -> bool:
         """
         if not self.loaded:
             return False
-        self.pipeline.to(device=Devices.RESET)
+        self.pipeline.to(device=Devices.RESET.value)
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
         self.loaded = False
@@ -75,4 +77,4 @@ def generate_prompt(self, prompt: str,
         return self.pipeline(
             prompt=prompt,
             **kwargs
-        ).images[0]
+        )