diff --git a/apps/shark_studio/api/controlnet.py b/apps/shark_studio/api/controlnet.py
index 620f3d9bb9..2c8a8b566b 100644
--- a/apps/shark_studio/api/controlnet.py
+++ b/apps/shark_studio/api/controlnet.py
@@ -40,20 +40,12 @@ def export_controlnet_model(model_keyword):
control_adapter_map = {
"sd15": {
"canny": {"initializer": control_adapter.export_control_adapter_model},
- "openpose": {
- "initializer": control_adapter.export_control_adapter_model
- },
- "scribble": {
- "initializer": control_adapter.export_control_adapter_model
- },
- "zoedepth": {
- "initializer": control_adapter.export_control_adapter_model
- },
+ "openpose": {"initializer": control_adapter.export_control_adapter_model},
+ "scribble": {"initializer": control_adapter.export_control_adapter_model},
+ "zoedepth": {"initializer": control_adapter.export_control_adapter_model},
},
"sdxl": {
- "canny": {
- "initializer": control_adapter.export_xl_control_adapter_model
- },
+ "canny": {"initializer": control_adapter.export_xl_control_adapter_model},
},
}
preprocessor_model_map = {
@@ -84,9 +76,7 @@ def run(self, inputs):
def cnet_preview(model, input_image):
curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
- control_imgs_path = os.path.join(
- get_generated_imgs_path(), "control_hints"
- )
+ control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints")
if not os.path.exists(control_imgs_path):
os.mkdir(control_imgs_path)
img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png")
diff --git a/apps/shark_studio/api/initializers.py b/apps/shark_studio/api/initializers.py
index e4593570db..ef9816cfca 100644
--- a/apps/shark_studio/api/initializers.py
+++ b/apps/shark_studio/api/initializers.py
@@ -8,10 +8,10 @@
from apps.shark_studio.modules.timer import startup_timer
from apps.shark_studio.web.utils.tmp_configs import (
- config_tmp,
- clear_tmp_mlir,
- clear_tmp_imgs,
- )
+ config_tmp,
+ clear_tmp_mlir,
+ clear_tmp_imgs,
+)
def imports():
@@ -21,12 +21,8 @@ def imports():
warnings.filterwarnings(
action="ignore", category=DeprecationWarning, module="torch"
)
- warnings.filterwarnings(
- action="ignore", category=UserWarning, module="torchvision"
- )
- warnings.filterwarnings(
- action="ignore", category=UserWarning, module="torch"
- )
+ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
+ warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
import gradio # noqa: F401
@@ -57,6 +53,7 @@ def initialize():
from apps.shark_studio.web.utils.file_utils import (
create_checkpoint_folders,
)
+
# Create custom models folders if they don't exist
create_checkpoint_folders()
diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py
index 80ad1e8edf..c911ab74b5 100644
--- a/apps/shark_studio/api/llm.py
+++ b/apps/shark_studio/api/llm.py
@@ -77,9 +77,7 @@ def __init__(
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
- self.torch_ir, self.tokenizer = llm_model_map[model_name][
- "initializer"
- ](
+ self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
@@ -142,9 +140,7 @@ def format_out(results):
self.iree_module_dict["config"].device, input_tensor
)
]
- token = self.iree_module_dict["vmfb"]["run_initialize"](
- *device_inputs
- )
+ token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs)
else:
device_inputs = [
ireert.asdevicearray(
@@ -152,9 +148,7 @@ def format_out(results):
token,
)
]
- token = self.iree_module_dict["vmfb"]["run_forward"](
- *device_inputs
- )
+ token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs)
total_time = time.time() - st_time
history.append(format_out(token))
diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py
index 43c6a1830c..d2dfa12cd9 100644
--- a/apps/shark_studio/api/sd.py
+++ b/apps/shark_studio/api/sd.py
@@ -11,10 +11,16 @@
from turbine_models.custom_models.sd_inference import clip, unet, vae
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.web.utils.state import status_label
-from apps.shark_studio.web.utils.file_utils import safe_name, get_resource_path, get_checkpoints_path
+from apps.shark_studio.web.utils.file_utils import (
+ safe_name,
+ get_resource_path,
+ get_checkpoints_path,
+)
from apps.shark_studio.modules.pipeline import SharkPipelineBase
from apps.shark_studio.modules.schedulers import get_schedulers
-from apps.shark_studio.modules.prompt_encoding import get_weighted_text_embeddings
+from apps.shark_studio.modules.prompt_encoding import (
+ get_weighted_text_embeddings,
+)
from apps.shark_studio.modules.img_processing import (
resize_stencil,
save_output_img,
@@ -42,25 +48,26 @@
},
"unet": {
"initializer": unet.export_unet_model,
- "ireec_flags": ["--iree-flow-collapse-reduction-dims",
- "--iree-opt-const-expr-hoisting=False",
- "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
+ "ireec_flags": [
+ "--iree-flow-collapse-reduction-dims",
+ "--iree-opt-const-expr-hoisting=False",
+ "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
],
"external_weight_file": None,
},
"vae_decode": {
"initializer": vae.export_vae_model,
"external_weight_file": None,
- "ireec_flags": ["--iree-flow-collapse-reduction-dims",
- "--iree-opt-const-expr-hoisting=False",
- "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
+ "ireec_flags": [
+ "--iree-flow-collapse-reduction-dims",
+ "--iree-opt-const-expr-hoisting=False",
+ "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
],
},
}
class StableDiffusion(SharkPipelineBase):
-
# This class is responsible for executing image generation and creating
# /managing a set of compiled modules to run Stable Diffusion. The init
# aims to be as general as possible, and the class will infer and compile
@@ -73,7 +80,6 @@ class StableDiffusion(SharkPipelineBase):
# embeddings: a dict of embedding checkpoints or model IDs to use when
# initializing the compiled modules.
-
def __init__(
self,
base_model_id,
@@ -99,10 +105,12 @@ def __init__(
"clip": {"hf_model_name": base_model_id},
"unet": {
"hf_model_name": base_model_id,
- "unet_model": unet.UnetModel(hf_model_name=base_model_id, hf_auth_token=None),
+ "unet_model": unet.UnetModel(
+ hf_model_name=base_model_id, hf_auth_token=None
+ ),
"batch_size": batch_size,
- #"is_controlled": is_controlled,
- #"num_loras": num_loras,
+ # "is_controlled": is_controlled,
+ # "num_loras": num_loras,
"height": height,
"width": width,
"precision": precision,
@@ -110,7 +118,9 @@ def __init__(
},
"vae_encode": {
"hf_model_name": custom_vae if custom_vae else base_model_id,
- "vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None),
+ "vae_model": vae.VaeModel(
+ hf_model_name=base_model_id, hf_auth_token=None
+ ),
"batch_size": batch_size,
"height": height,
"width": width,
@@ -118,16 +128,16 @@ def __init__(
},
"vae_decode": {
"hf_model_name": custom_vae if custom_vae else base_model_id,
- "vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None),
+ "vae_model": vae.VaeModel(
+ hf_model_name=base_model_id, hf_auth_token=None
+ ),
"batch_size": batch_size,
"height": height,
"width": width,
"precision": precision,
},
}
- super().__init__(
- sd_model_map, base_model_id, static_kwargs, device, import_ir
- )
+ super().__init__(sd_model_map, base_model_id, static_kwargs, device, import_ir)
pipe_id_list = [
safe_name(base_model_id),
str(batch_size),
@@ -135,7 +145,7 @@ def __init__(
precision,
]
if num_loras > 0:
- pipe_id_list.append(str(num_loras)+"lora")
+ pipe_id_list.append(str(num_loras) + "lora")
if is_controlled:
pipe_id_list.append("controlled")
if custom_vae:
@@ -145,7 +155,6 @@ def __init__(
del static_kwargs
gc.collect()
-
def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings, is_img2img):
print(
f"\n[LOG] Preparing pipeline with scheduler {scheduler}"
@@ -165,15 +174,16 @@ def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings, is_img2i
for i in self.model_map:
self.model_map[i]["external_weights_file"] = None
elif custom_weights:
- print(f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?")
+ print(
+ f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?"
+ )
self.static_kwargs["pipe"] = {
- # "external_weight_path": self.weights_path,
-# "external_weights": "safetensors",
+ # "external_weight_path": self.weights_path,
+ # "external_weights": "safetensors",
}
self.get_compiled_map(pipe_id=self.pipe_id)
print("\n[LOG] Pipeline successfully prepared for runtime.")
return
-
def generate_images(
self,
@@ -191,24 +201,24 @@ def generate_images(
control_mode,
hints,
):
- #TODO: Batched args
+ # TODO: Batched args
self.ondemand = ondemand
if self.is_img2img:
image, _ = self.process_sd_init_image(image, resample_type)
- else:
+ else:
image = None
print("\n[LOG] Generating images...")
- batched_args=[
+ batched_args = [
prompt,
negative_prompt,
- #steps,
- #strength,
- #guidance_scale,
- #seed,
- #resample_type,
- #control_mode,
- #hints,
+ # steps,
+ # strength,
+ # guidance_scale,
+ # seed,
+ # resample_type,
+ # control_mode,
+ # hints,
]
for arg in batched_args:
if not isinstance(arg, list):
@@ -222,7 +232,7 @@ def generate_images(
prompt,
negative_prompt,
)
-
+
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
@@ -242,7 +252,7 @@ def generate_images(
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
- cpu_scheduling=True, # until we have schedulers through Turbine
+ cpu_scheduling=True, # until we have schedulers through Turbine
)
# Img latents -> PIL images
@@ -260,7 +270,6 @@ def generate_images(
return all_imgs
-
def encode_prompts_weight(
self,
prompt,
@@ -275,13 +284,10 @@ def encode_prompts_weight(
)
clip_inf_start = time.time()
-
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
- uncond_prompt=negative_prompt
- if do_classifier_free_guidance
- else None,
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
)
if do_classifier_free_guidance:
@@ -300,7 +306,6 @@ def encode_prompts_weight(
return text_embeddings.numpy().astype(np.float16)
-
def prepare_latents(
self,
generator,
@@ -318,7 +323,7 @@ def prepare_latents(
generator=generator,
dtype=self.dtype,
).to("cpu")
-
+
self.scheduler.set_timesteps(num_inference_steps)
if self.is_img2img:
init_timestep = min(
@@ -327,16 +332,13 @@ def prepare_latents(
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]
latents = self.encode_image(image)
- latents = self.scheduler.add_noise(
- latents, noise, timesteps[0].repeat(1)
- )
+ latents = self.scheduler.add_noise(latents, noise, timesteps[0].repeat(1))
return latents, [timesteps]
else:
self.scheduler.is_scale_input_called = True
latents = noise * self.scheduler.init_noise_sigma
return latents, self.scheduler.timesteps
-
def encode_image(self, input_image):
self.load_submodels(["vae_encode"])
vae_encode_start = time.time()
@@ -348,7 +350,6 @@ def encode_image(self, input_image):
return latents
-
def produce_img_latents(
self,
latents,
@@ -370,11 +371,15 @@ def produce_img_latents(
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(self.dtype).detach().numpy()
- latent_model_input = self.scheduler.scale_model_input(latents, t).to(self.dtype)
+ latent_model_input = self.scheduler.scale_model_input(latents, t).to(
+ self.dtype
+ )
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
- torch.from_numpy(np.asarray(latent_model_input)).to(torch.float16),
+ torch.from_numpy(np.asarray(latent_model_input)).to(
+ torch.float16
+ ),
mask,
masked_image_latents,
],
@@ -398,9 +403,7 @@ def produce_img_latents(
if cpu_scheduling:
noise_pred = torch.from_numpy(noise_pred.to_host())
- latents = self.scheduler.step(
- noise_pred, t, latents
- ).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
else:
latents = self.run("scheduler_step", (noise_pred, t, latents))
@@ -411,7 +414,7 @@ def produce_img_latents(
# )
step_time_sum += step_time
- #if self.status == SD_STATE_CANCEL:
+ # if self.status == SD_STATE_CANCEL:
# break
if self.ondemand:
@@ -426,7 +429,6 @@ def produce_img_latents(
all_latents = torch.cat(latent_history, dim=0)
return all_latents
-
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
if use_base_vae:
latents = 1 / 0.18215 * latents
@@ -435,11 +437,11 @@ def decode_latents(self, latents, use_base_vae, cpu_scheduling):
if cpu_scheduling:
latents_numpy = latents.detach().numpy()
- #profile_device = start_profiling(file_path="vae.rdc")
+ # profile_device = start_profiling(file_path="vae.rdc")
vae_start = time.time()
images = self.run("vae_decode", latents_numpy).to_host()
vae_inf_time = (time.time() - vae_start) * 1000
- #end_profiling(profile_device)
+ # end_profiling(profile_device)
print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}")
if use_base_vae:
@@ -451,7 +453,6 @@ def decode_latents(self, latents, use_base_vae, cpu_scheduling):
pil_images = [Image.fromarray(image) for image in images.numpy()]
return pil_images
-
def process_sd_init_image(self, sd_init_image, resample_type):
if isinstance(sd_init_image, list):
images = []
@@ -463,7 +464,9 @@ def process_sd_init_image(self, sd_init_image, resample_type):
if isinstance(sd_init_image, str):
if os.path.isfile(sd_init_image):
sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
- image, is_img2img = self.process_sd_init_image(sd_init_image, resample_type)
+ image, is_img2img = self.process_sd_init_image(
+ sd_init_image, resample_type
+ )
else:
image = None
is_img2img = False
@@ -536,7 +539,6 @@ def shark_sd_fn(
sd_kwargs = locals()
is_img2img = True if sd_init_image[0] is not None else False
-
print("\n[LOG] Performing Stable Diffusion Pipeline setup...")
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
@@ -553,20 +555,20 @@ def shark_sd_fn(
for i, model in enumerate(controlnets["model"]):
if "xl" not in base_model_id.lower():
adapters[f"control_adapter_{model}"] = {
- "hf_id": control_adapter_map[
- "runwayml/stable-diffusion-v1-5"
- ][model],
+ "hf_id": control_adapter_map["runwayml/stable-diffusion-v1-5"][
+ model
+ ],
"strength": controlnets["strength"][i],
}
else:
adapters[f"control_adapter_{model}"] = {
- "hf_id": control_adapter_map[
- "stabilityai/stable-diffusion-xl-1.0"
- ][model],
+ "hf_id": control_adapter_map["stabilityai/stable-diffusion-xl-1.0"][
+ model
+ ],
"strength": controlnets["strength"][i],
}
if model is not None:
- is_controlled=True
+ is_controlled = True
control_mode = controlnets["control_mode"]
for i in controlnets["hint"]:
                 hints.append(i)
@@ -659,13 +661,13 @@ def view_json_file(file_path):
return content
-
if __name__ == "__main__":
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
- import apps.shark_studio.web.utils.globals as global_obj
+ import apps.shark_studio.web.utils.globals as global_obj
+
global_obj._init()
sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
sd_kwargs = json.loads(sd_json)
for i in shark_sd_fn_dict_input(sd_kwargs):
- print(i)
\ No newline at end of file
+ print(i)
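For orientation, here is a minimal usage sketch of the StableDiffusion pipeline class touched above. It assumes the constructor accepts the keyword names referenced in its body (base_model_id, height, width, precision, batch_size, device); the model ID, scheduler, and device values are illustrative, and the full generate_images argument list is not shown in this patch, so it is only referenced in a comment. This is a sketch under those assumptions, not part of the patch itself.

# Hypothetical driver for the StableDiffusion pipeline; signatures are inferred
# from the hunks above and may differ from the actual implementation.
from apps.shark_studio.api.sd import StableDiffusion

sd = StableDiffusion(
    base_model_id="runwayml/stable-diffusion-v1-5",  # any HF model ID or checkpoint
    height=512,
    width=512,
    precision="fp16",
    batch_size=1,
    device="vulkan",  # matches the default in shared_cmd_opts
)
# Compile or fetch submodels and pick a scheduler before generating.
sd.prepare_pipe(
    scheduler="EulerDiscrete",
    custom_weights=None,
    adapters=None,
    embeddings=None,
    is_img2img=False,
)
# sd.generate_images(...) then returns PIL images; its full argument list
# (prompt, negative_prompt, seed, guidance_scale, hints, ...) is defined in sd.py.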
diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py
index 430f074894..e9268aa83b 100644
--- a/apps/shark_studio/api/utils.py
+++ b/apps/shark_studio/api/utils.py
@@ -46,9 +46,7 @@ def get_devices_by_name(driver_name):
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
- device_list.append(
- f"{device_name} => {driver_name}://{i}"
- )
+ device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
@@ -259,9 +257,7 @@ def get_devices_by_name(driver_name):
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
- device_list.append(
- f"{device_name} => {driver_name}://{i}"
- )
+ device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
@@ -316,9 +312,7 @@ def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, int):
return [seed_input]
- if isinstance(seed_input, list) and all(
- type(seed) is int for seed in seed_input
- ):
+ if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
return seed_input
raise TypeError(
diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py
index 25edd3109c..08681f6c56 100644
--- a/apps/shark_studio/modules/ckpt_processing.py
+++ b/apps/shark_studio/modules/ckpt_processing.py
@@ -42,9 +42,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if
# they want to go for EMA weight extraction or not.
extract_ema = False
- print(
- "Loading diffusers' pipeline from original stable diffusion checkpoint"
- )
+ print("Loading diffusers' pipeline from original stable diffusion checkpoint")
num_in_channels = 9 if is_inpaint else 4
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict=custom_weights,
@@ -69,9 +67,7 @@ def convert_original_vae(vae_checkpoint):
original_config = OmegaConf.load(original_config_file)
vae_config = create_vae_diffusers_config(original_config, image_size=512)
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(
- vae_state_dict, vae_config
- )
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
return converted_vae_checkpoint
@@ -89,9 +85,7 @@ def process_custom_pipe_weights(custom_weights):
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
- custom_weights_tgt = get_path_to_diffusers_checkpoint(
- custom_weights
- )
+ custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights)
custom_weights_params = custom_weights
return custom_weights_params, custom_weights_tgt
@@ -104,15 +98,11 @@ def get_civitai_checkpoint(url: str):
base_filename = re.findall(
'"([^"]*)"', response.headers["Content-Disposition"]
)[0]
- destination_path = (
- Path.cwd() / (cmd_opts.ckpt_dir or "models") / base_filename
- )
+ destination_path = Path.cwd() / (cmd_opts.ckpt_dir or "models") / base_filename
# we don't have this model downloaded yet
if not destination_path.is_file():
- print(
- f"downloading civitai model from {url} to {destination_path}"
- )
+ print(f"downloading civitai model from {url} to {destination_path}")
size = int(response.headers["content-length"], 0)
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
diff --git a/apps/shark_studio/modules/embeddings.py b/apps/shark_studio/modules/embeddings.py
index b35839c8e5..87924c819e 100644
--- a/apps/shark_studio/modules/embeddings.py
+++ b/apps/shark_studio/modules/embeddings.py
@@ -76,22 +76,14 @@ def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75):
scale = lora_weight.alpha * lora_strength
if len(weight.size()) == 2:
if len(lora_weight.up.shape) == 4:
- weight_up = (
- lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
- )
- weight_down = (
- lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
- )
- change = (
- torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
- )
+ weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
+ weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
+ change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
change = torch.mm(lora_weight.up, lora_weight.down)
elif lora_weight.down.size()[2:4] == (1, 1):
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
- weight_down = (
- lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
- )
+ weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
change = torch.nn.functional.conv2d(
@@ -166,9 +158,7 @@ def get_lora_metadata(lora_filename):
# get a figure for the total number of images processed for this dataset
# either then number actually listed or in its dataset_dir entry or
# the highest frequency's number if that doesn't exist
- img_count = dataset_dirs.get(dir, {}).get(
- "img_count", frequencies[0][1]
- )
+ img_count = dataset_dirs.get(dir, {}).get("img_count", frequencies[0][1])
# add the dataset frequencies to the overall frequencies replacing the
# frequency counts on the tags with a percentage/ratio
diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py
index 6c78515cca..1b435db40d 100644
--- a/apps/shark_studio/modules/pipeline.py
+++ b/apps/shark_studio/modules/pipeline.py
@@ -43,7 +43,6 @@ def __init__(
self.iree_module_dict = {}
self.tempfiles = {}
-
def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
# First checks whether we have .vmfbs precompiled, then populates the map
# with the precompiled executables and fetches executables for the rest of the map.
@@ -52,13 +51,15 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
# and your model map is populated with any IR - unique model IDs and their static params,
# call this method to get the artifacts associated with your map.
self.pipe_id = self.safe_name(pipe_id)
- self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(".."), self.pipe_id))
+ self.pipe_vmfb_path = Path(
+ os.path.join(get_checkpoints_path(".."), self.pipe_id)
+ )
self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True)
if submodel == "None":
print("\n[LOG] Gathering any pre-compiled artifacts....")
for key in self.model_map:
self.get_compiled_map(pipe_id, submodel=key)
- else:
+ else:
self.get_precompiled(pipe_id, submodel)
ireec_flags = []
if submodel in self.iree_module_dict:
@@ -68,18 +69,22 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
elif "vmfb_path" in self.model_map[submodel]:
return
elif submodel not in self.tempfiles:
- print(f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
+ print(
+ f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..."
+ )
if submodel in self.static_kwargs:
init_kwargs = self.static_kwargs[submodel]
for key in self.static_kwargs["pipe"]:
if key not in init_kwargs:
init_kwargs[key] = self.static_kwargs["pipe"][key]
- self.import_torch_ir(
- submodel, init_kwargs
- )
+ self.import_torch_ir(submodel, init_kwargs)
self.get_compiled_map(pipe_id, submodel)
- else:
- ireec_flags = self.model_map[submodel]["ireec_flags"] if "ireec_flags" in self.model_map[submodel] else []
+ else:
+ ireec_flags = (
+ self.model_map[submodel]["ireec_flags"]
+ if "ireec_flags" in self.model_map[submodel]
+ else []
+ )
if "external_weights_file" in self.model_map[submodel]:
weights_path = self.model_map[submodel]["external_weights_file"]
@@ -92,11 +97,10 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
mmap=True,
external_weight_file=weights_path,
extra_args=ireec_flags,
- write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb")
+ write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
)
return
-
def get_precompiled(self, pipe_id, submodel="None"):
if submodel == "None":
for model in self.model_map:
@@ -112,7 +116,6 @@ def get_precompiled(self, pipe_id, submodel="None"):
self.model_map[submodel]["vmfb_path"] = os.path.join(vmfbs_path, file)
return
-
def import_torch_ir(self, submodel, kwargs):
torch_ir = self.model_map[submodel]["initializer"](
**self.safe_dict(kwargs), compile_to="torch"
@@ -120,17 +123,16 @@ def import_torch_ir(self, submodel, kwargs):
if submodel == "clip":
# clip.export_clip_model returns (torch_ir, tokenizer)
torch_ir = torch_ir[0]
- self.tempfiles[submodel] = get_resource_path(os.path.join(
- "..", "shark_tmp", f"{submodel}.torch.tempfile"
- ))
-
+ self.tempfiles[submodel] = get_resource_path(
+ os.path.join("..", "shark_tmp", f"{submodel}.torch.tempfile")
+ )
+
with open(self.tempfiles[submodel], "w+") as f:
f.write(torch_ir)
del torch_ir
gc.collect()
return
-
def load_submodels(self, submodels: list):
for submodel in submodels:
if submodel in self.iree_module_dict:
@@ -149,13 +151,14 @@ def load_submodels(self, submodels: list):
self.device,
device_idx=0,
rt_flags=[],
- external_weight_file=self.model_map[submodel]['external_weight_file'],
+ external_weight_file=self.model_map[submodel][
+ "external_weight_file"
+ ],
)
else:
self.get_compiled_map(self.pipe_id, submodel)
return
-
def unload_submodels(self, submodels: list):
for submodel in submodels:
if submodel in self.iree_module_dict:
@@ -163,18 +166,20 @@ def unload_submodels(self, submodels: list):
gc.collect()
return
-
def run(self, submodel, inputs):
if not isinstance(inputs, list):
inputs = [inputs]
- inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, input) for input in inputs]
- return self.iree_module_dict[submodel]['vmfb']['main'](*inp)
-
+ inp = [
+ ireert.asdevicearray(
+ self.iree_module_dict[submodel]["config"].device, input
+ )
+ for input in inputs
+ ]
+ return self.iree_module_dict[submodel]["vmfb"]["main"](*inp)
def safe_name(self, name):
return name.replace("/", "_").replace("-", "_").replace("\\", "_")
-
def safe_dict(self, kwargs: dict):
flat_args = {}
for i in kwargs:
@@ -183,4 +188,4 @@ def safe_dict(self, kwargs: dict):
else:
flat_args[i] = kwargs[i]
- return flat_args
+ return flat_args
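To make the compile-then-run contract described in the get_compiled_map comments above concrete, here is a minimal sketch of how a SharkPipelineBase model map might be populated and exercised. The positional constructor order follows the super().__init__ call in sd.py; export_my_model, the pipe ID, and the input shape are placeholders and not part of this patch.

# Hypothetical SharkPipelineBase flow; argument order and map keys are inferred
# from the hunks above and may differ from the actual implementation.
import numpy as np

from apps.shark_studio.modules.pipeline import SharkPipelineBase


def export_my_model(hf_model_name, compile_to="torch", **kwargs):
    # Placeholder initializer: would return torch/MLIR IR text for the submodel,
    # which import_torch_ir writes to a tempfile before compilation.
    raise NotImplementedError


model_map = {
    "unet": {
        "initializer": export_my_model,
        "external_weight_file": None,
        "ireec_flags": [],
    },
}
static_kwargs = {
    "unet": {"hf_model_name": "runwayml/stable-diffusion-v1-5"},
    "pipe": {},  # shared kwargs merged into each submodel's init_kwargs
}

pipe = SharkPipelineBase(
    model_map, "runwayml/stable-diffusion-v1-5", static_kwargs, "vulkan", True
)
pipe.get_compiled_map(pipe_id="example_pipe")  # gather or compile .vmfb artifacts
pipe.load_submodels(["unet"])                  # mount compiled modules into iree_module_dict
out = pipe.run("unet", [np.zeros((1, 4, 64, 64), dtype=np.float16)])  # calls the module's "main"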
diff --git a/apps/shark_studio/modules/prompt_encoding.py b/apps/shark_studio/modules/prompt_encoding.py
index b2a5e8a27e..3dc61aba08 100644
--- a/apps/shark_studio/modules/prompt_encoding.py
+++ b/apps/shark_studio/modules/prompt_encoding.py
@@ -1,4 +1,3 @@
-
from typing import List, Optional, Union
from iree import runtime as ireert
import re
@@ -112,9 +111,7 @@ def multiply_range(start_position, multiplier):
return res
-def get_prompts_with_weights(
- pipe, prompt: List[str], max_length: int
-):
+def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
@@ -164,18 +161,12 @@ def pad_tokens_and_weights(
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = (
- max_length
- if no_boseos_middle
- else max_embeddings_multiples * chunk_length
+ max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
)
for i in range(len(tokens)):
- tokens[i] = (
- [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
- )
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
if no_boseos_middle:
- weights[i] = (
- [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
- )
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
else:
w = []
if len(weights[i]) == 0:
@@ -195,6 +186,7 @@ def pad_tokens_and_weights(
return tokens, weights
+
def get_unweighted_text_embeddings(
pipe,
text_input,
@@ -242,7 +234,6 @@ def get_unweighted_text_embeddings(
return text_embeddings
-
 # This function deals with NoneType values occurring in tokens after padding.
 # It switches out None with 49407, as truncating None values causes matrix dimension errors.
def filter_nonetype_tokens(tokens: List[List]):
@@ -290,9 +281,7 @@ def get_weighted_text_embeddings(
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
- max_length = max(
- max_length, max([len(token) for token in uncond_tokens])
- )
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (pipe.model_max_length - 2) + 1,
@@ -334,9 +323,7 @@ def get_weighted_text_embeddings(
uncond_tokens = filter_nonetype_tokens(uncond_tokens)
# uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
- uncond_tokens = torch.tensor(
- uncond_tokens, dtype=torch.long, device="cpu"
- )
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device="cpu")
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
@@ -346,9 +333,7 @@ def get_weighted_text_embeddings(
no_boseos_middle=no_boseos_middle,
)
# prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
- prompt_weights = torch.tensor(
- prompt_weights, dtype=torch.float, device="cpu"
- )
+ prompt_weights = torch.tensor(prompt_weights, dtype=torch.float, device="cpu")
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe,
@@ -357,27 +342,19 @@ def get_weighted_text_embeddings(
no_boseos_middle=no_boseos_middle,
)
# uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
- uncond_weights = torch.tensor(
- uncond_weights, dtype=torch.float, device="cpu"
- )
+ uncond_weights = torch.tensor(uncond_weights, dtype=torch.float, device="cpu")
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = (
- text_embeddings.float()
- .mean(axis=[-2, -1])
- .to(text_embeddings.dtype)
+ text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = (
- text_embeddings.float()
- .mean(axis=[-2, -1])
- .to(text_embeddings.dtype)
- )
- text_embeddings *= (
- (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+ text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
)
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = (
uncond_embeddings.float()
diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py
index 7a42338b1a..8c2413c638 100644
--- a/apps/shark_studio/modules/schedulers.py
+++ b/apps/shark_studio/modules/schedulers.py
@@ -17,7 +17,7 @@
def get_schedulers(model_id):
- #TODO: switch over to turbine and run all on GPU
+ # TODO: switch over to turbine and run all on GPU
print(f"\n[LOG] Initializing schedulers from model id: {model_id}")
schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
@@ -44,14 +44,10 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
- schedulers[
- "DPMSolverMultistep"
- ] = DPMSolverMultistepScheduler.from_pretrained(
+ schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver"
)
- schedulers[
- "DPMSolverMultistep++"
- ] = DPMSolverMultistepScheduler.from_pretrained(
+ schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
)
schedulers[
@@ -83,9 +79,7 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
- schedulers[
- "DPMSolverSinglestep"
- ] = DPMSolverSinglestepScheduler.from_pretrained(
+ schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
@@ -108,24 +102,16 @@ def export_scheduler_model(model):
scheduler_model_map = {
"EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
- "EulerAncestralDiscrete": export_scheduler_model(
- "EulerAncestralDiscreteScheduler"
- ),
+ "EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
"LCM": export_scheduler_model("LCMScheduler"),
"LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
"PNDM": export_scheduler_model("PNDMScheduler"),
"DDPM": export_scheduler_model("DDPMScheduler"),
"DDIM": export_scheduler_model("DDIMScheduler"),
- "DPMSolverMultistep": export_scheduler_model(
- "DPMSolverMultistepScheduler"
- ),
+ "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
"KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
"DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
- "DPMSolverSinglestep": export_scheduler_model(
- "DPMSolverSingleStepScheduler"
- ),
- "KDPM2AncestralDiscrete": export_scheduler_model(
- "KDPM2AncestralDiscreteScheduler"
- ),
+ "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
+ "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
"HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
}
diff --git a/apps/shark_studio/modules/shared_cmd_opts.py b/apps/shark_studio/modules/shared_cmd_opts.py
index 535e5d2c7f..dd871383a7 100644
--- a/apps/shark_studio/modules/shared_cmd_opts.py
+++ b/apps/shark_studio/modules/shared_cmd_opts.py
@@ -130,8 +130,7 @@ def is_valid_file(arg):
"--strength",
type=float,
default=0.8,
- help="The strength of change applied on the given input image for "
- "img2img.",
+ help="The strength of change applied on the given input image for " "img2img.",
)
p.add_argument(
@@ -290,9 +289,7 @@ def is_valid_file(arg):
# Model Config and Usage Params
##############################################################################
-p.add_argument(
- "--device", type=str, default="vulkan", help="Device to run the model."
-)
+p.add_argument("--device", type=str, default="vulkan", help="Device to run the model.")
p.add_argument(
"--precision", type=str, default="fp16", help="Precision to run the model."
@@ -350,8 +347,7 @@ def is_valid_file(arg):
"--batch_count",
type=int,
default=1,
- help="Number of batches to be generated with random seeds in "
- "single execution.",
+ help="Number of batches to be generated with random seeds in " "single execution.",
)
p.add_argument(
@@ -416,8 +412,7 @@ def is_valid_file(arg):
"--use_lora",
type=str,
default="",
- help="Use standalone LoRA weight using a HF ID or a checkpoint "
- "file (~3 MB).",
+ help="Use standalone LoRA weight using a HF ID or a checkpoint " "file (~3 MB).",
)
p.add_argument(
@@ -493,8 +488,7 @@ def is_valid_file(arg):
"--dump_isa",
default=False,
action="store_true",
- help="When enabled call amdllpc to get ISA dumps. "
- "Use with dispatch benchmarks.",
+ help="When enabled call amdllpc to get ISA dumps. " "Use with dispatch benchmarks.",
)
p.add_argument(
@@ -515,8 +509,7 @@ def is_valid_file(arg):
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
- help="Flag for inserting debug frames between iterations "
- "for use with rgp.",
+ help="Flag for inserting debug frames between iterations " "for use with rgp.",
)
p.add_argument(
@@ -602,8 +595,7 @@ def is_valid_file(arg):
"--progress_bar",
default=True,
action=argparse.BooleanOptionalAction,
- help="Flag for removing the progress bar animation during "
- "image generation.",
+ help="Flag for removing the progress bar animation during " "image generation.",
)
p.add_argument(
diff --git a/apps/shark_studio/modules/timer.py b/apps/shark_studio/modules/timer.py
index 8fd1e6a7df..d6918e9c8c 100644
--- a/apps/shark_studio/modules/timer.py
+++ b/apps/shark_studio/modules/timer.py
@@ -11,9 +11,7 @@ def __init__(self, timer, category):
def __enter__(self):
self.start = time.time()
- self.timer.base_category = (
- self.original_base_category + self.category + "/"
- )
+ self.timer.base_category = self.original_base_category + self.category + "/"
self.timer.subcategory_level += 1
if self.timer.print_log:
@@ -82,10 +80,7 @@ def summary(self):
res += " ("
res += ", ".join(
- [
- f"{category}: {time_taken:.1f}s"
- for category, time_taken in additions
- ]
+ [f"{category}: {time_taken:.1f}s" for category, time_taken in additions]
)
res += ")"
diff --git a/apps/shark_studio/web/api/compat.py b/apps/shark_studio/web/api/compat.py
index 80399505c4..3f92c41d02 100644
--- a/apps/shark_studio/web/api/compat.py
+++ b/apps/shark_studio/web/api/compat.py
@@ -30,17 +30,13 @@ def decode_base64_to_image(encoding):
status_code=500, detail="Request to local resource not allowed"
)
- headers = (
- {"user-agent": opts.api_useragent} if opts.api_useragent else {}
- )
+ headers = {"user-agent": opts.api_useragent} if opts.api_useragent else {}
response = requests.get(encoding, timeout=30, headers=headers)
try:
image = Image.open(BytesIO(response.content))
return image
except Exception as e:
- raise HTTPException(
- status_code=500, detail="Invalid image url"
- ) from e
+ raise HTTPException(status_code=500, detail="Invalid image url") from e
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
@@ -48,9 +44,7 @@ def decode_base64_to_image(encoding):
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception as e:
- raise HTTPException(
- status_code=500, detail="Invalid encoded image"
- ) from e
+ raise HTTPException(status_code=500, detail="Invalid encoded image") from e
def encode_pil_to_base64(image):
diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py
index 0d5de9a839..05a9bc363d 100644
--- a/apps/shark_studio/web/index.py
+++ b/apps/shark_studio/web/index.py
@@ -130,9 +130,7 @@ def webui():
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
- base_path = getattr(
- sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
- )
+ base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py
index 3c5497215a..917ac870bf 100644
--- a/apps/shark_studio/web/ui/chat.py
+++ b/apps/shark_studio/web/ui/chat.py
@@ -88,17 +88,11 @@ def llm_chat_api(InputData: dict):
# print(f"prompt : {InputData['prompt']}")
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
global vicuna_model
- model_name = (
- InputData["model"] if "model" in InputData.keys() else "codegen"
- )
+ model_name = InputData["model"] if "model" in InputData.keys() else "codegen"
model_path = llm_model_map[model_name]
device = "cpu-task"
precision = "fp16"
- max_toks = (
- None
- if "max_tokens" not in InputData.keys()
- else InputData["max_tokens"]
- )
+ max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"]
if max_toks is None:
max_toks = 128 if model_name == "codegen" else 512
@@ -135,9 +129,7 @@ def llm_chat_api(InputData: dict):
# TODO: add role dict for different models
if is_chat_completion_api:
         # TODO: add functionality for multiple messages
- prompt = create_prompt(
- model_name, [(InputData["messages"][0]["content"], "")]
- )
+ prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")])
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
@@ -170,9 +162,7 @@ def llm_chat_api(InputData: dict):
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
- "object": "chat.completion"
- if is_chat_completion_api
- else "text_completion",
+ "object": "chat.completion" if is_chat_completion_api else "text_completion",
"created": int(end_time),
"choices": choices,
}
@@ -248,9 +238,7 @@ def view_json_file(file_obj):
with gr.Row(visible=False):
with gr.Group():
- config_file = gr.File(
- label="Upload sharding configuration", visible=False
- )
+ config_file = gr.File(label="Upload sharding configuration", visible=False)
json_view_button = gr.Button("View as JSON", visible=False)
json_view = gr.JSON(visible=False)
json_view_button.click(
diff --git a/apps/shark_studio/web/ui/common_events.py b/apps/shark_studio/web/ui/common_events.py
index 9adf7dd61b..7dda8ba268 100644
--- a/apps/shark_studio/web/ui/common_events.py
+++ b/apps/shark_studio/web/ui/common_events.py
@@ -13,7 +13,9 @@ def lora_changed(lora_files):
# tag frequency percentage, above which a tag is displayed
TAG_DISPLAY_THRESHOLD = 0.65
# template for the html used to display a tag
- TAG_HTML_TEMPLATE = '{tag}'
+ TAG_HTML_TEMPLATE = (
+ '{tag}'
+ )
output = []
for lora_file in lora_files:
if lora_file == "":
diff --git a/apps/shark_studio/web/ui/outputgallery.py b/apps/shark_studio/web/ui/outputgallery.py
index 77e60be819..a3de6f7b57 100644
--- a/apps/shark_studio/web/ui/outputgallery.py
+++ b/apps/shark_studio/web/ui/outputgallery.py
@@ -22,8 +22,7 @@ def outputgallery_filenames(subdir) -> list[str]:
new_dir_path = os.path.join(output_dir, subdir)
if os.path.exists(new_dir_path):
filenames = [
- glob.glob(new_dir_path + "/" + ext)
- for ext in ("*.png", "*.jpg", "*.jpeg")
+ glob.glob(new_dir_path + "/" + ext) for ext in ("*.png", "*.jpg", "*.jpeg")
]
return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True)
@@ -52,11 +51,7 @@ def output_subdirs() -> list[str]:
[path for path in relative_paths if path.isnumeric()], reverse=True
)
result_paths = generated_paths + sorted(
- [
- path
- for path in relative_paths
- if (not path.isnumeric()) and path != "."
- ]
+ [path for path in relative_paths if (not path.isnumeric()) and path != "."]
)
return result_paths
@@ -184,9 +179,7 @@ def on_image_columns_change(columns):
def on_select_subdir(subdir) -> list:
# evt.value is the subdirectory name
new_images = outputgallery_filenames(subdir)
- new_label = (
- f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
- )
+ new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)}"
return [
new_images,
gr.Gallery(
@@ -223,8 +216,7 @@ def on_refresh(current_subdir: str) -> list:
)
new_images = outputgallery_filenames(new_subdir)
new_label = (
- f"{len(new_images)} images in "
- f"{os.path.join(output_dir, new_subdir)}"
+ f"{len(new_images)} images in " f"{os.path.join(output_dir, new_subdir)}"
)
return [
@@ -234,9 +226,7 @@ def on_refresh(current_subdir: str) -> list:
),
refreshed_subdirs,
new_images,
- gr.Gallery(
- value=new_images, label=new_label, visible=len(new_images) > 0
- ),
+ gr.Gallery(value=new_images, label=new_label, visible=len(new_images) > 0),
gr.Image(
label=new_label,
visible=len(new_images) == 0,
diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py
index 66ec452d0b..49eac6d821 100644
--- a/apps/shark_studio/web/ui/sd.py
+++ b/apps/shark_studio/web/ui/sd.py
@@ -50,6 +50,7 @@
"stabilityai/sdxl-turbo",
]
+
def view_json_file(file_path):
content = ""
with open(file_path, "r") as fopen:
@@ -149,7 +150,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
else:
sd_json = new_sd_config
for i in sd_json["sd_init_image"]:
- if i is not None:
+ if i is not None:
if os.path.isfile(i):
sd_image = [Image.open(i, mode="r")]
else:
@@ -338,9 +339,9 @@ def import_original(original_img, width, height):
choices=["None"] + get_checkpoints(base_model_id),
) #
with gr.Column(scale=2):
- sd_vae_info = (
- str(get_checkpoints_path("vae"))
- ).replace("\\", "\n\\")
+ sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
+ "\\", "\n\\"
+ )
sd_vae_info = f"VAE Path: {sd_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom VAE Models",
@@ -396,12 +397,10 @@ def import_original(original_img, width, height):
height=300,
interactive=True,
)
- with gr.Accordion(
- label="Embeddings options", open=True, render=True
- ):
- sd_lora_info = (
- str(get_checkpoints_path("loras"))
- ).replace("\\", "\n\\")
+ with gr.Accordion(label="Embeddings options", open=True, render=True):
+ sd_lora_info = (str(get_checkpoints_path("loras"))).replace(
+ "\\", "\n\\"
+ )
with gr.Row():
embeddings_config = gr.JSON(min_width=50, scale=1)
lora_opt = gr.Dropdown(
@@ -651,7 +650,7 @@ def import_original(original_img, width, height):
with gr.Column(scale=3, min_width=600):
with gr.Group():
sd_gallery = gr.Gallery(
- label="Generated images",
+ label="Generated images",
show_label=False,
elem_id="gallery",
columns=2,
@@ -666,9 +665,7 @@ def import_original(original_img, width, height):
elem_id="std_output",
show_label=False,
)
- sd_element.load(
- logger.read_sd_logs, None, std_output, every=1
- )
+ sd_element.load(logger.read_sd_logs, None, std_output, every=1)
sd_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
@@ -696,9 +693,7 @@ def import_original(original_img, width, height):
value="Clear Config", size="sm", components=sd_json
)
with gr.Row():
- save_sd_config = gr.Button(
- value="Save Config", size="sm"
- )
+ save_sd_config = gr.Button(value="Save Config", size="sm")
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
@@ -796,9 +791,7 @@ def import_original(original_img, width, height):
)
prompt_submit = prompt.submit(**status_kwargs).then(**pull_kwargs)
- neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
- **pull_kwargs
- )
+ neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(**pull_kwargs)
generate_click = (
stable_diffusion.click(**status_kwargs)
.then(**pull_kwargs)
diff --git a/apps/shark_studio/web/ui/utils.py b/apps/shark_studio/web/ui/utils.py
index ba62e5adc0..34a94fa014 100644
--- a/apps/shark_studio/web/ui/utils.py
+++ b/apps/shark_studio/web/ui/utils.py
@@ -6,9 +6,7 @@
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
- base_path = getattr(
- sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
- )
+ base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
return os.path.join(base_path, relative_path)
diff --git a/apps/shark_studio/web/utils/file_utils.py b/apps/shark_studio/web/utils/file_utils.py
index e7b8fd72c4..cae925f5e2 100644
--- a/apps/shark_studio/web/utils/file_utils.py
+++ b/apps/shark_studio/web/utils/file_utils.py
@@ -23,9 +23,7 @@ def get_path_stem(path):
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
- base_path = getattr(
- sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
- )
+ base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
result = Path(os.path.join(base_path, relative_path)).resolve(strict=False)
return result
diff --git a/apps/shark_studio/web/utils/metadata/csv_metadata.py b/apps/shark_studio/web/utils/metadata/csv_metadata.py
index d617e802bf..d515234083 100644
--- a/apps/shark_studio/web/utils/metadata/csv_metadata.py
+++ b/apps/shark_studio/web/utils/metadata/csv_metadata.py
@@ -29,9 +29,7 @@ def parse_csv(image_filename: str):
has_header = csv.Sniffer().has_header(csv_file.read(2048))
csv_file.seek(0)
- reader = (
- csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
- )
+ reader = csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
matches = [
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
diff --git a/apps/shark_studio/web/utils/metadata/format.py b/apps/shark_studio/web/utils/metadata/format.py
index f097dab54f..308d9f8e8b 100644
--- a/apps/shark_studio/web/utils/metadata/format.py
+++ b/apps/shark_studio/web/utils/metadata/format.py
@@ -92,15 +92,11 @@ def compact(metadata: dict) -> dict:
result["Hires resize"] = f"{hires_y}x{hires_x}"
# remove VAE if it exists and is empty
- if (result.keys() & {"VAE"}) and (
- not result["VAE"] or result["VAE"] == "None"
- ):
+ if (result.keys() & {"VAE"}) and (not result["VAE"] or result["VAE"] == "None"):
result.pop("VAE")
# remove LoRA if it exists and is empty
- if (result.keys() & {"LoRA"}) and (
- not result["LoRA"] or result["LoRA"] == "None"
- ):
+ if (result.keys() & {"LoRA"}) and (not result["LoRA"] or result["LoRA"] == "None"):
result.pop("LoRA")
return result
diff --git a/apps/shark_studio/web/utils/metadata/png_metadata.py b/apps/shark_studio/web/utils/metadata/png_metadata.py
index d9091afdf4..72f663f246 100644
--- a/apps/shark_studio/web/utils/metadata/png_metadata.py
+++ b/apps/shark_studio/web/utils/metadata/png_metadata.py
@@ -66,9 +66,7 @@ def parse_generation_parameters(x: str):
return res
-def try_find_model_base_from_png_metadata(
- file: str, folder: str = "models"
-) -> str:
+def try_find_model_base_from_png_metadata(file: str, folder: str = "models") -> str:
custom = ""
# Remove extension from file info
@@ -101,16 +99,13 @@ def find_model_from_png_metadata(
# No matching model was found
if not png_custom and not png_hf_id:
print(
- "Import PNG info: Unable to find a matching model for %s"
- % model_file
+ "Import PNG info: Unable to find a matching model for %s" % model_file
)
return png_custom, png_hf_id
-def find_vae_from_png_metadata(
- key: str, metadata: dict[str, str | int]
-) -> str:
+def find_vae_from_png_metadata(key: str, metadata: dict[str, str | int]) -> str:
vae_custom = ""
if key in metadata:
diff --git a/apps/shark_studio/web/utils/state.py b/apps/shark_studio/web/utils/state.py
index 626d4ce53f..350012c381 100644
--- a/apps/shark_studio/web/utils/state.py
+++ b/apps/shark_studio/web/utils/state.py
@@ -18,8 +18,7 @@ def get_generation_text_info(seeds, device):
text_output = f"prompt={cfg_dump['prompts']}"
text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}"
text_output += (
- f"\nmodel_id={cfg_dump['hf_model_id']}, "
- f"ckpt_loc={cfg_dump['ckpt_loc']}"
+ f"\nmodel_id={cfg_dump['hf_model_id']}, " f"ckpt_loc={cfg_dump['ckpt_loc']}"
)
text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}"
text_output += (
diff --git a/apps/shark_studio/web/utils/tmp_configs.py b/apps/shark_studio/web/utils/tmp_configs.py
index 3e6ba46bfe..4415276ea3 100644
--- a/apps/shark_studio/web/utils/tmp_configs.py
+++ b/apps/shark_studio/web/utils/tmp_configs.py
@@ -7,9 +7,7 @@
def clear_tmp_mlir():
cleanup_start = time()
- print(
- "Clearing .mlir temporary files from a prior run. This may take some time..."
- )
+ print("Clearing .mlir temporary files from a prior run. This may take some time...")
mlir_files = [
filename
for filename in os.listdir(shark_tmp)
@@ -18,9 +16,7 @@ def clear_tmp_mlir():
]
for filename in mlir_files:
os.remove(shark_tmp + filename)
- print(
- f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds."
- )
+ print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.")
def clear_tmp_imgs():