Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Studio 2.0) add Stable Diffusion features #2037

Merged
merged 28 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
dbacc36
(WIP): Studio2 app infra and SD API
monorimet Dec 8, 2023
cdf2eb5
Studio2/SD: Use more correct LoRA alpha calculation (#2034)
one-lithe-rune Dec 12, 2023
7a0017d
Studio2: Remove duplications from api/utils.py (#2035)
one-lithe-rune Dec 12, 2023
a43c559
Add test for SD
monorimet Jan 17, 2024
019ba70
Small cleanup
monorimet Feb 2, 2024
01575a8
Shark2/SD/UI: Respect ckpt_dir, share and server_port args (#2070)
one-lithe-rune Feb 7, 2024
25312cd
Add StreamingLLM support to studio2 chat (#2060)
monorimet Jan 19, 2024
230638a
HF-Reference LLM mode + Update test result to match latest Turbine. (…
raikonenfnu Feb 1, 2024
be4c49a
Add rest API endpoint from LanguageModel API
monorimet Feb 3, 2024
1541b21
Add StreamingLLM support to studio2 chat (#2060)
monorimet Jan 19, 2024
5f675e1
Formatting and init files.
monorimet Feb 5, 2024
a198934
Remove unused import.
monorimet Feb 5, 2024
39ebc45
Small fixes
monorimet Feb 12, 2024
75f4ed9
Studio2/SD/UI: Improve various parts of the UI for Stable Diffusion (…
one-lithe-rune Feb 18, 2024
6dc39e6
Studio2/SD: Fix sd pipeline up to "Windows not supported" (#2082)
one-lithe-rune Feb 18, 2024
c507f7d
Studio2/SD/UI: Further sd ui pipeline fixes (#2091)
one-lithe-rune Feb 19, 2024
92c11be
Merge branch 'main' into sd-studio2
monorimet Feb 19, 2024
60c013e
Tweak compile-time flags for SD submodels.
monorimet Feb 20, 2024
f7d1af4
Small fixes to sd, pin mpmath
monorimet Mar 1, 2024
ca69fd5
Add pyinstaller spec and imports script.
monorimet Mar 1, 2024
44ef35f
Fix the .exe (#2101)
gpetters-amd Mar 22, 2024
81fba10
Fix _IREE_TARGET_MAP (#2103) (#2108)
gpetters-amd Mar 25, 2024
1827bc3
Merge branch 'main' into sd-studio2
gpetters94 Mar 27, 2024
0ade2ec
Cleanup sd model map.
monorimet Mar 28, 2024
f0ebfb0
Update dependencies.
monorimet Mar 28, 2024
2996df7
Studio2/SD/UI: Update gradio to 4.19.2 (sd-studio2) (#2097)
one-lithe-rune Mar 28, 2024
73765cd
Merge branch 'main' into sd-studio2
monorimet Mar 28, 2024
9f59a16
fix formatting and disable explicit vulkan env settings.
monorimet Mar 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions apps/shark_studio/api/controlnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
import os
import PIL
import numpy as np
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
)
from datetime import datetime
from PIL import Image
from gradio.components.image_editor import (
EditorValue,
)


class control_adapter:
def __init__(
self,
model: str,
):
self.model = None

def export_control_adapter_model(model_keyword):
return None

def export_xl_control_adapter_model(model_keyword):
return None


class preprocessors:
def __init__(
self,
model: str,
):
self.model = None

def export_controlnet_model(model_keyword):
return None


control_adapter_map = {
"sd15": {
"canny": {"initializer": control_adapter.export_control_adapter_model},
"openpose": {"initializer": control_adapter.export_control_adapter_model},
"scribble": {"initializer": control_adapter.export_control_adapter_model},
"zoedepth": {"initializer": control_adapter.export_control_adapter_model},
},
"sdxl": {
"canny": {"initializer": control_adapter.export_xl_control_adapter_model},
},
}
preprocessor_model_map = {
"canny": {"initializer": preprocessors.export_controlnet_model},
"openpose": {"initializer": preprocessors.export_controlnet_model},
"scribble": {"initializer": preprocessors.export_controlnet_model},
"zoedepth": {"initializer": preprocessors.export_controlnet_model},
}


class PreprocessorModel:
def __init__(
self,
hf_model_id,
device="cpu",
):
self.model = hf_model_id
self.device = device

def compile(self):
print("compile not implemented for preprocessor.")
return

def run(self, inputs):
print("run not implemented for preprocessor.")
return inputs


def cnet_preview(model, input_image):
curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints")
if not os.path.exists(control_imgs_path):
os.mkdir(control_imgs_path)
img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png")
match model:
case "canny":
canny = PreprocessorModel("canny")
result = canny(
np.array(input_image),
100,
200,
)
Image.fromarray(result).save(fp=img_dest)
return result, img_dest
case "openpose":
openpose = PreprocessorModel("openpose")
result = openpose(np.array(input_image))
Image.fromarray(result[0]).save(fp=img_dest)
return result, img_dest
case "zoedepth":
zoedepth = PreprocessorModel("ZoeDepth")
result = zoedepth(np.array(input_image))
Image.fromarray(result).save(fp=img_dest)
return result, img_dest
case "scribble":
input_image.save(fp=img_dest)
return input_image, img_dest
case _:
return None, None
125 changes: 125 additions & 0 deletions apps/shark_studio/api/initializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import importlib
import os
import signal
import sys
import warnings
import json
from threading import Thread

from apps.shark_studio.modules.timer import startup_timer

from apps.shark_studio.web.utils.tmp_configs import (
config_tmp,
clear_tmp_mlir,
clear_tmp_imgs,
shark_tmp,
)


def imports():
import torch # noqa: F401

startup_timer.record("import torch")
warnings.filterwarnings(
action="ignore", category=DeprecationWarning, module="torch"
)
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")

import gradio # noqa: F401

startup_timer.record("import gradio")

import apps.shark_studio.web.utils.globals as global_obj

global_obj._init()
startup_timer.record("initialize globals")

from apps.shark_studio.modules import (
img_processing,
) # noqa: F401

startup_timer.record("other imports")


def initialize():
configure_sigint_handler()
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.

config_tmp()
# clear_tmp_mlir()
clear_tmp_imgs()

from apps.shark_studio.web.utils.file_utils import (
create_checkpoint_folders,
)

# Create custom models folders if they don't exist
create_checkpoint_folders()

import gradio as gr

# initialize_rest(reload_script_modules=False)


def initialize_rest(*, reload_script_modules=False):
"""
Called both from initialize() and when reloading the webui.
"""
# Keep this for adding reload options to the webUI.


def dumpstacks():
import threading
import traceback

id2name = {th.ident: th.name for th in threading.enumerate()}
code = []
for threadId, stack in sys._current_frames().items():
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
for filename, lineno, name, line in traceback.extract_stack(stack):
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
if line:
code.append(" " + line.strip())
with open(os.path.join(shark_tmp, "stack_dump.log"), "w") as f:
f.write("\n".join(code))


def setup_middleware(app):
from starlette.middleware.gzip import GZipMiddleware

app.middleware_stack = (
None # reset current middleware to allow modifying user provided list
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
configure_cors_middleware(app)
app.build_middleware_stack() # rebuild middleware stack on-the-fly


def configure_cors_middleware(app):
from starlette.middleware.cors import CORSMiddleware
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

cors_options = {
"allow_methods": ["*"],
"allow_headers": ["*"],
"allow_credentials": True,
}
if cmd_opts.api_accept_origin:
cors_options["allow_origins"] = cmd_opts.api_accept_origin.split(",")

app.add_middleware(CORSMiddleware, **cors_options)


def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f"Interrupted with signal {sig} in {frame}")

dumpstacks()

os._exit(0)

signal.signal(signal.SIGINT, sigint_handler)
Loading
Loading