Skip to content

Commit

Permalink
Add further comments + fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
multimodalart committed Aug 3, 2024
1 parent 618e64a commit e43ae0f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
15 changes: 9 additions & 6 deletions demo_gr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class FluxGenerator:
def __init__(self, model_name: str, device: str, offload: bool):
self.device = torch.device(device)
self.offload = offload
self.model_name = model_name
self.is_schnell = model_name == "flux-schnell"
self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
model_name,
Expand All @@ -49,6 +50,7 @@ def generate_image(
image2image_strength=0.0,
add_sampling_metadata=True,
):
seed = int(seed)
if seed == -1:
seed = None

Expand Down Expand Up @@ -152,9 +154,9 @@ def generate_image(

img.save(filename, format="png", quality=100)

return img, opts.seed, filename, None
return img, str(opts.seed), filename, None
else:
return None, opts.seed, None, "Your generated image may contain NSFW content."
return None, str(opts.seed), None, "Your generated image may contain NSFW content."

def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):
generator = FluxGenerator(model_name, device, offload)
Expand All @@ -166,15 +168,16 @@ def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt", value="a photo of a forest with mist swirling around the tree trunks. The word \"FLUX\" is painted over it in big, red brush strokes with visible texture")
do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell)
init_image = gr.Image(label="Input Image", visible=False)
image2image_strength = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False)

with gr.Accordion("Advanced Options", open=False):
width = gr.Slider(128, 8192, 1360, step=16, label="Width")
height = gr.Slider(128, 8192, 768, step=16, label="Height")
num_steps = gr.Slider(1, 50, 4 if is_schnell else 50, step=1, label="Number of steps")
guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell)
seed = gr.Number(-1, label="Seed (-1 for random)", precision=0)
do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell)
init_image = gr.Image(label="Input Image", visible=False)
image2image_strength = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False)
seed = gr.Textbox(-1, label="Seed (-1 for random)")
add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True)

generate_btn = gr.Button("Generate")
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ streamlit = [
"streamlit",
"streamlit-keyup",
]
gradio = [
"gradio",
]
all = [
"flux[streamlit]",
"flux[gradio]",
]

[project.scripts]
Expand Down

0 comments on commit e43ae0f

Please sign in to comment.