From e43ae0fda01d53ebca4f53cb549f778528d8467b Mon Sep 17 00:00:00 2001 From: multimodalart Date: Sat, 3 Aug 2024 11:47:50 +0300 Subject: [PATCH] Add further comments + fixes --- demo_gr.py | 15 +++++++++------ pyproject.toml | 4 ++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/demo_gr.py b/demo_gr.py index 28c55ce1..03634cef 100644 --- a/demo_gr.py +++ b/demo_gr.py @@ -28,6 +28,7 @@ class FluxGenerator: def __init__(self, model_name: str, device: str, offload: bool): self.device = torch.device(device) self.offload = offload + self.model_name = model_name self.is_schnell = model_name == "flux-schnell" self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models( model_name, @@ -49,6 +50,7 @@ def generate_image( image2image_strength=0.0, add_sampling_metadata=True, ): + seed = int(seed) if seed == -1: seed = None @@ -152,9 +154,9 @@ def generate_image( img.save(filename, format="png", quality=100) - return img, opts.seed, filename, None + return img, str(opts.seed), filename, None else: - return None, opts.seed, None, "Your generated image may contain NSFW content." + return None, str(opts.seed), None, "Your generated image may contain NSFW content." def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False): generator = FluxGenerator(model_name, device, offload) @@ -166,15 +168,16 @@ def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available with gr.Row(): with gr.Column(): prompt = gr.Textbox(label="Prompt", value="a photo of a forest with mist swirling around the tree trunks. The word \"FLUX\" is painted over it in big, red brush strokes with visible texture") + do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell) + init_image = gr.Image(label="Input Image", visible=False) + image2image_strength = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False) + with gr.Accordion("Advanced Options", open=False): width = gr.Slider(128, 8192, 1360, step=16, label="Width") height = gr.Slider(128, 8192, 768, step=16, label="Height") num_steps = gr.Slider(1, 50, 4 if is_schnell else 50, step=1, label="Number of steps") guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell) - seed = gr.Number(-1, label="Seed (-1 for random)", precision=0) - do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell) - init_image = gr.Image(label="Input Image", visible=False) - image2image_strength = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False) + seed = gr.Textbox(-1, label="Seed (-1 for random)") add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True) generate_btn = gr.Button("Generate") diff --git a/pyproject.toml b/pyproject.toml index edaf4323..72f921b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,8 +28,12 @@ streamlit = [ "streamlit", "streamlit-keyup", ] +gradio = [ + "gradio", +] all = [ "flux[streamlit]", + "flux[gradio]", ] [project.scripts]