Skip to content

Commit

Permalink
[SD][web] Add openjourney and dreamlike in SD web UI
Browse files Browse the repository at this point in the history
Signed-Off-by: Gaurav Shukla <[email protected]>
  • Loading branch information
Shukla-Gaurav committed Dec 25, 2022
1 parent d11cf42 commit 45af40f
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 11 deletions.
2 changes: 1 addition & 1 deletion shark/examples/shark_inference/stable_diffusion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def set_init_device_flags():

# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
if (
args.variant == "openjourney"
args.variant in ["openjourney", "dreamlike"]
or args.precision != "fp16"
or "vulkan" not in args.device
or "rdna3" not in args.iree_vulkan_target_triple
Expand Down
4 changes: 3 additions & 1 deletion web/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,11 @@
label="Model Variant",
value="stablediffusion",
choices=[
"stablediffusion",
"anythingv3",
"analogdiffusion",
"stablediffusion",
"openjourney",
"dreamlike",
],
)
scheduler_key = gr.Dropdown(
Expand Down
31 changes: 26 additions & 5 deletions web/models/stable_diffusion/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

# clip has 2 variants of max length 77 or 64.
model_clip_max_length = 64 if args.max_length == 64 else 77
if args.variant in ["anythingv3", "analogdiffusion"]:
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
model_clip_max_length = 77
elif args.variant == "openjourney":
model_clip_max_length = 64
Expand Down Expand Up @@ -64,6 +64,7 @@
"anythingv3": "diffusers",
"analogdiffusion": "main",
"openjourney": "main",
"dreamlike": "main",
}


Expand All @@ -78,7 +79,12 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
model_config[args.version], subfolder="text_encoder"
)

elif args.variant in ["anythingv3", "analogdiffusion", "openjourney"]:
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
text_encoder = CLIPTextModel.from_pretrained(
model_variant[args.variant],
subfolder="text_encoder",
Expand Down Expand Up @@ -133,7 +139,12 @@ def forward(self, input):
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in ["anythingv3", "analogdiffusion", "openjourney"]:
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
Expand Down Expand Up @@ -184,7 +195,12 @@ def forward(self, input):
)
else:
inputs = model_input[args.version]["vae"]
elif args.variant in ["anythingv3", "analogdiffusion", "openjourney"]:
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
vae = vae.half().cuda()
inputs = tuple(
Expand Down Expand Up @@ -242,7 +258,12 @@ def forward(self, latent, timestep, text_embedding, guidance_scale):
)
else:
inputs = model_input[args.version]["unet"]
elif args.variant in ["anythingv3", "analogdiffusion", "openjourney"]:
elif args.variant in [
"anythingv3",
"analogdiffusion",
"openjourney",
"dreamlike",
]:
if args.precision == "fp16":
unet = unet.half().cuda()
inputs = tuple(
Expand Down
12 changes: 10 additions & 2 deletions web/models/stable_diffusion/resources/model_db.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
"openjourney/tuned":"gs://shark_tank/sd_tuned"
"openjourney/tuned":"gs://shark_tank/sd_tuned",
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
},
{
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
Expand Down Expand Up @@ -55,6 +56,13 @@
"openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
"openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64"
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
"dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
"dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
"dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
"dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
}
]
4 changes: 2 additions & 2 deletions web/models/stable_diffusion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,14 @@ def set_init_device_flags():
args.device = "cpu"

# set max_length based on availability.
if args.variant in ["anythingv3", "analogdiffusion"]:
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
args.max_length = 77
elif args.variant == "openjourney":
args.max_length = 64

# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
if (
args.variant == "openjourney"
args.variant in ["openjourney", "dreamlike"]
or args.precision != "fp16"
or "vulkan" not in args.device
or "rdna3" not in args.iree_vulkan_target_triple
Expand Down

0 comments on commit 45af40f

Please sign in to comment.