# [modular] Stable Diffusion XL ControlNet Union #10509

Merged

## Conversation

**hlky** (Collaborator) commented on Jan 9, 2025

## What does this PR do?

Adds Stable Diffusion XL ControlNet Union support to Modular Diffusers.

### Modular

```python
import torch
import numpy as np
from PIL import Image
from diffusers import ModularPipeline, ControlNetUnionModel, AutoencoderKL
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import (
    StableDiffusionXLTextEncoderStep,
    StableDiffusionXLDecodeLatentsStep,
    StableDiffusionXLInputStep,
    StableDiffusionXLAutoSetTimestepsStep,
    StableDiffusionXLAutoPrepareLatentsStep,
    StableDiffusionXLAutoPrepareAdditionalConditioningStep,
    StableDiffusionXLControlNetUnionDenoiseStep,
    StableDiffusionXLAutoVaeEncoderStep,
)
from diffusers.utils import load_image
from controlnet_aux import LineartAnimeDetector


class StableDiffusionXLMainSteps(SequentialPipelineBlocks):
    block_classes = [
        StableDiffusionXLInputStep,
        StableDiffusionXLAutoSetTimestepsStep,
        StableDiffusionXLAutoPrepareLatentsStep,
        StableDiffusionXLAutoPrepareAdditionalConditioningStep,
        StableDiffusionXLControlNetUnionDenoiseStep,
    ]
    block_names = [
        "input",
        "set_timesteps",
        "prepare_latents",
        "prepare_add_cond",
        "denoise",
    ]


text_block = StableDiffusionXLTextEncoderStep()
sdxl_main_block = StableDiffusionXLMainSteps()
decoder_block = StableDiffusionXLDecodeLatentsStep()
encoder_block = StableDiffusionXLAutoVaeEncoderStep()

text_node = ModularPipeline.from_block(text_block)
sdxl_node = ModularPipeline.from_block(sdxl_main_block)
decoder_node = ModularPipeline.from_block(decoder_block)
encoder_node = ModularPipeline.from_block(encoder_block)

components = ComponentsManager()
components.add_from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
controlnet = ControlNetUnionModel.from_pretrained(
    "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
)

# Wire the shared components into each node
text_node.update_states(
    **components.get(["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"])
)
decoder_node.update_states(vae=vae)
encoder_node.update_states(vae=vae)
sdxl_node.update_states(**components.get(["unet", "scheduler"]))
sdxl_node.update_states(controlnet=controlnet)

text_node = text_node.to("cuda")
sdxl_node = sdxl_node.to("cuda")
decoder_node = decoder_node.to("cuda")
encoder_node = encoder_node.to("cuda")

# Encode the prompt once; text_state.intermediates is reused by every example below
text_state = text_node(prompt="A cat")

generator = torch.Generator().manual_seed(0)
```
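For reference, the examples below select the conditioning type with `control_mode` indices. The sketch below shows the index layout assumed here, based on the controlnet-union-sdxl-1.0-promax model card; verify it against the checkpoint you actually load.

```python
# Assumed control_mode layout for controlnet-union-sdxl-1.0 (promax variant);
# not part of this PR's code -- check the model card before relying on it.
CONTROL_MODES = {
    0: "openpose",
    1: "depth",
    2: "hed / pidi / scribble / ted",
    3: "canny / lineart / anime lineart / mlsd",  # lineart example below
    4: "normal",
    5: "segment",
    6: "tile",     # promax only; tiled upscale example below
    7: "repaint",  # promax only; inpaint example below
}
```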

### Text-to-image

```python
image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
).resize((1024, 1024))

processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_img = processor(image, output_type="pil")
latents = sdxl_node(
    **text_state.intermediates,
    control_image=[controlnet_img],
    control_mode=[3],
    generator=generator,
    num_inference_steps=20,
    guidance_scale=5.0,
    width=1024,
    height=1024,
    output="latents",
)
image = decoder_node(latents=latents, output="images")[0]
image.save("lineart.png")

*(attached result image: lineart)*
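Under the mode mapping assumed above, mode 3 covers several line-style conditions, so a different preprocessor can be swapped in without changing `control_mode`. A minimal sketch using `controlnet_aux`'s `CannyDetector`, following the same call pattern as the lineart example (not part of this PR):

```python
# Sketch: condition on Canny edges instead of anime lineart (still mode 3)
from controlnet_aux import CannyDetector

canny = CannyDetector()
controlnet_img = canny(image)  # PIL image in, PIL edge map out
latents = sdxl_node(
    **text_state.intermediates,
    control_image=[controlnet_img],
    control_mode=[3],
    generator=generator,
    num_inference_steps=20,
    guidance_scale=5.0,
    width=1024,
    height=1024,
    output="latents",
)
image = decoder_node(latents=latents, output="images")[0]
```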

### Image-to-image

```python
image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
height = image.height
width = image.width
ratio = np.sqrt(1024.0 * 1024.0 / (width * height))
# For a 3x3 tiled upscale the working size must be a multiple of 16 * 3;
# a 2x2 upscale needs a multiple of 16 * 2, and so on.
scale_image_factor = 3
base_factor = 16
factor = scale_image_factor * base_factor
W, H = int(width * ratio) // factor * factor, int(height * ratio) // factor * factor
image = image.resize((W, H))
target_width = W // scale_image_factor
target_height = H // scale_image_factor
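# Worked example (assuming a 768x512 input): ratio = sqrt(1024*1024 / (768*512)) ~= 1.633,
# so W = int(768 * 1.633) // 48 * 48 = 1248 and H = int(512 * 1.633) // 48 * 48 = 816,
# i.e. 3x3 tiles of 416x272, each re-rendered at the full 1248x816 working size.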
images = []
# Per-tile crop coordinates, apparently carried over from the upstream
# ControlNet Union tile-upscale example; note the loop below never uses them.
crops_coords_list = [
    (0, 0),
    (0, width // 2),
    (height // 2, 0),
    (width // 2, height // 2),
    0,
    0,
    0,
    0,
    0,
]
# Crop a 3x3 grid of tiles and resize each tile up to the full working size
for i in range(scale_image_factor):
    for j in range(scale_image_factor):
        left = j * target_width
        top = i * target_height
        right = left + target_width
        bottom = top + target_height
        cropped_image = image.crop((left, top, right, bottom))
        cropped_image = cropped_image.resize((W, H))
        images.append(cropped_image)
result_latents = []
for sub_img, crops_coords in zip(images, crops_coords_list):
    # crops_coords is unpacked but unused; each tile is rendered at the full
    # working size, with SDXL micro-conditioning hinting at a 2x larger original
    new_width, new_height = W, H
    latents = sdxl_node(
        **text_state.intermediates,
        image=sub_img,
        control_image=[sub_img],
        control_mode=[6],
        width=new_width,
        height=new_height,
        num_inference_steps=30,
        crops_coords_top_left=(W, H),
        target_size=(W, H),
        original_size=(W * 2, H * 2),
        generator=generator,
        output="latents",
    )
    result_latents.append(latents)

result_images = []
for latents in result_latents:
    image = decoder_node(latents=latents, output="images")[0]
    result_images.append(image)

# Reassemble the 3x3 grid of upscaled tiles (row-major, matching the crop order)
new_im = Image.new(
    "RGB", (new_width * scale_image_factor, new_height * scale_image_factor)
)
for idx, tile in enumerate(result_images):
    x = (idx % scale_image_factor) * new_width
    y = (idx // scale_image_factor) * new_height
    new_im.paste(tile, (x, y))
new_im.save("upscaled.png")
```

*(attached result image: upscaled)*

### Inpaint

```python
image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo.png"
).resize((1024, 1024))
mask = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
).resize((1024, 1024))

# Zero out the masked region of the control image so repaint mode fills it in
controlnet_img = image.copy()
controlnet_img_np = np.array(controlnet_img)
mask_np = np.array(mask)
controlnet_img_np[mask_np > 0] = 0
controlnet_img = Image.fromarray(controlnet_img_np)
latents = sdxl_node(
    **text_state.intermediates,
    image=image,
    mask_image=mask,
    control_image=[controlnet_img],
    control_mode=[7],
    generator=generator,
    num_inference_steps=20,
    guidance_scale=5.0,
    width=1024,
    height=1024,
    output="latents",
)
image = decoder_node(latents=latents, output="images")[0]
image.save("inpaint.png")

*(attached result image: inpaint)*

Note: there appears to be a masking issue of some kind with the inpaint output above.
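A possible post-hoc mitigation (a sketch, not part of this PR): composite the generated result back over the source photo through the mask, so pixels outside the masked region stay identical to the original. This assumes the source photo was kept in a separate variable, hypothetically named `init_image` here, since the example above overwrites `image` with the decoded output.

```python
# Sketch of a workaround for the masking artifact. Assumes init_image still
# holds the original photo; Image.composite keeps `image` where the mask is
# white and falls back to init_image elsewhere.
blended = Image.composite(image, init_image, mask.convert("L"))
blended.save("inpaint_blended.png")
```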

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@yiyixuxu


**yiyixuxu** (Collaborator) commented on Jan 9, 2025

wow awesome!

cc @asomoza here, we only have ip-adapter left now

**yiyixuxu** merged commit 7a34832 into huggingface:modular-diffusers on Jan 9, 2025 (1 check passed)
**yiyixuxu** (Collaborator) commented on Jan 9, 2025

ohh inpaint still has an issue? we can look into that

**asomoza** (Member) commented on Jan 10, 2025

Thanks, this was the missing piece for what I'm working on. I'll use it right away and test it.
