Skip to content

Commit

Permalink
fix formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Dec 18, 2023
1 parent 5813043 commit 1a9812a
Show file tree
Hide file tree
Showing 25 changed files with 200 additions and 343 deletions.
20 changes: 5 additions & 15 deletions apps/shark_studio/api/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,12 @@ def export_controlnet_model(model_keyword):
control_adapter_map = {
"sd15": {
"canny": {"initializer": control_adapter.export_control_adapter_model},
"openpose": {
"initializer": control_adapter.export_control_adapter_model
},
"scribble": {
"initializer": control_adapter.export_control_adapter_model
},
"zoedepth": {
"initializer": control_adapter.export_control_adapter_model
},
"openpose": {"initializer": control_adapter.export_control_adapter_model},
"scribble": {"initializer": control_adapter.export_control_adapter_model},
"zoedepth": {"initializer": control_adapter.export_control_adapter_model},
},
"sdxl": {
"canny": {
"initializer": control_adapter.export_xl_control_adapter_model
},
"canny": {"initializer": control_adapter.export_xl_control_adapter_model},
},
}
preprocessor_model_map = {
Expand Down Expand Up @@ -84,9 +76,7 @@ def run(self, inputs):

def cnet_preview(model, input_image):
curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
control_imgs_path = os.path.join(
get_generated_imgs_path(), "control_hints"
)
control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints")
if not os.path.exists(control_imgs_path):
os.mkdir(control_imgs_path)
img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png")
Expand Down
17 changes: 7 additions & 10 deletions apps/shark_studio/api/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from apps.shark_studio.modules.timer import startup_timer
from apps.shark_studio.web.utils.tmp_configs import (
config_tmp,
clear_tmp_mlir,
clear_tmp_imgs,
)
config_tmp,
clear_tmp_mlir,
clear_tmp_imgs,
)


def imports():
Expand All @@ -21,12 +21,8 @@ def imports():
warnings.filterwarnings(
action="ignore", category=DeprecationWarning, module="torch"
)
warnings.filterwarnings(
action="ignore", category=UserWarning, module="torchvision"
)
warnings.filterwarnings(
action="ignore", category=UserWarning, module="torch"
)
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")

import gradio # noqa: F401

Expand Down Expand Up @@ -57,6 +53,7 @@ def initialize():
from apps.shark_studio.web.utils.file_utils import (
create_checkpoint_folders,
)

# Create custom models folders if they don't exist
create_checkpoint_folders()

Expand Down
12 changes: 3 additions & 9 deletions apps/shark_studio/api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def __init__(
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[model_name][
"initializer"
](
self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
Expand Down Expand Up @@ -142,19 +140,15 @@ def format_out(results):
self.iree_module_dict["config"].device, input_tensor
)
]
token = self.iree_module_dict["vmfb"]["run_initialize"](
*device_inputs
)
token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs)
else:
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device,
token,
)
]
token = self.iree_module_dict["vmfb"]["run_forward"](
*device_inputs
)
token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs)

total_time = time.time() - st_time
history.append(format_out(token))
Expand Down
Loading

0 comments on commit 1a9812a

Please sign in to comment.