diff --git a/apps/shark_studio/api/controlnet.py b/apps/shark_studio/api/controlnet.py index 620f3d9bb9..2c8a8b566b 100644 --- a/apps/shark_studio/api/controlnet.py +++ b/apps/shark_studio/api/controlnet.py @@ -40,20 +40,12 @@ def export_controlnet_model(model_keyword): control_adapter_map = { "sd15": { "canny": {"initializer": control_adapter.export_control_adapter_model}, - "openpose": { - "initializer": control_adapter.export_control_adapter_model - }, - "scribble": { - "initializer": control_adapter.export_control_adapter_model - }, - "zoedepth": { - "initializer": control_adapter.export_control_adapter_model - }, + "openpose": {"initializer": control_adapter.export_control_adapter_model}, + "scribble": {"initializer": control_adapter.export_control_adapter_model}, + "zoedepth": {"initializer": control_adapter.export_control_adapter_model}, }, "sdxl": { - "canny": { - "initializer": control_adapter.export_xl_control_adapter_model - }, + "canny": {"initializer": control_adapter.export_xl_control_adapter_model}, }, } preprocessor_model_map = { @@ -84,9 +76,7 @@ def run(self, inputs): def cnet_preview(model, input_image): curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S") - control_imgs_path = os.path.join( - get_generated_imgs_path(), "control_hints" - ) + control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints") if not os.path.exists(control_imgs_path): os.mkdir(control_imgs_path) img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png") diff --git a/apps/shark_studio/api/initializers.py b/apps/shark_studio/api/initializers.py index e4593570db..ef9816cfca 100644 --- a/apps/shark_studio/api/initializers.py +++ b/apps/shark_studio/api/initializers.py @@ -8,10 +8,10 @@ from apps.shark_studio.modules.timer import startup_timer from apps.shark_studio.web.utils.tmp_configs import ( - config_tmp, - clear_tmp_mlir, - clear_tmp_imgs, - ) + config_tmp, + clear_tmp_mlir, + clear_tmp_imgs, +) def imports(): @@ -21,12 +21,8 @@ def imports(): warnings.filterwarnings( action="ignore", category=DeprecationWarning, module="torch" ) - warnings.filterwarnings( - action="ignore", category=UserWarning, module="torchvision" - ) - warnings.filterwarnings( - action="ignore", category=UserWarning, module="torch" - ) + warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") + warnings.filterwarnings(action="ignore", category=UserWarning, module="torch") import gradio # noqa: F401 @@ -57,6 +53,7 @@ def initialize(): from apps.shark_studio.web.utils.file_utils import ( create_checkpoint_folders, ) + # Create custom models folders if they don't exist create_checkpoint_folders() diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 80ad1e8edf..c911ab74b5 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -77,9 +77,7 @@ def __init__( use_auth_token=hf_auth_token, ) elif not os.path.exists(self.tempfile_name): - self.torch_ir, self.tokenizer = llm_model_map[model_name][ - "initializer" - ]( + self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"]( self.hf_model_name, hf_auth_token, compile_to="torch", @@ -142,9 +140,7 @@ def format_out(results): self.iree_module_dict["config"].device, input_tensor ) ] - token = self.iree_module_dict["vmfb"]["run_initialize"]( - *device_inputs - ) + token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs) else: device_inputs = [ ireert.asdevicearray( @@ -152,9 +148,7 @@ def format_out(results): token, ) ] - token = self.iree_module_dict["vmfb"]["run_forward"]( - *device_inputs - ) + token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs) total_time = time.time() - st_time history.append(format_out(token)) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 43c6a1830c..d2dfa12cd9 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -11,10 +11,16 @@ from turbine_models.custom_models.sd_inference import clip, unet, vae from apps.shark_studio.api.controlnet import control_adapter_map from apps.shark_studio.web.utils.state import status_label -from apps.shark_studio.web.utils.file_utils import safe_name, get_resource_path, get_checkpoints_path +from apps.shark_studio.web.utils.file_utils import ( + safe_name, + get_resource_path, + get_checkpoints_path, +) from apps.shark_studio.modules.pipeline import SharkPipelineBase from apps.shark_studio.modules.schedulers import get_schedulers -from apps.shark_studio.modules.prompt_encoding import get_weighted_text_embeddings +from apps.shark_studio.modules.prompt_encoding import ( + get_weighted_text_embeddings, +) from apps.shark_studio.modules.img_processing import ( resize_stencil, save_output_img, @@ -42,25 +48,26 @@ }, "unet": { "initializer": unet.export_unet_model, - "ireec_flags": ["--iree-flow-collapse-reduction-dims", - "--iree-opt-const-expr-hoisting=False", - "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807", + "ireec_flags": [ + "--iree-flow-collapse-reduction-dims", + "--iree-opt-const-expr-hoisting=False", + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807", ], "external_weight_file": None, }, "vae_decode": { "initializer": vae.export_vae_model, "external_weight_file": None, - "ireec_flags": ["--iree-flow-collapse-reduction-dims", - "--iree-opt-const-expr-hoisting=False", - "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807", + "ireec_flags": [ + "--iree-flow-collapse-reduction-dims", + "--iree-opt-const-expr-hoisting=False", + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807", ], }, } class StableDiffusion(SharkPipelineBase): - # This class is responsible for executing image generation and creating # /managing a set of compiled modules to run Stable Diffusion. The init # aims to be as general as possible, and the class will infer and compile @@ -73,7 +80,6 @@ class StableDiffusion(SharkPipelineBase): # embeddings: a dict of embedding checkpoints or model IDs to use when # initializing the compiled modules. - def __init__( self, base_model_id, @@ -99,10 +105,12 @@ def __init__( "clip": {"hf_model_name": base_model_id}, "unet": { "hf_model_name": base_model_id, - "unet_model": unet.UnetModel(hf_model_name=base_model_id, hf_auth_token=None), + "unet_model": unet.UnetModel( + hf_model_name=base_model_id, hf_auth_token=None + ), "batch_size": batch_size, - #"is_controlled": is_controlled, - #"num_loras": num_loras, + # "is_controlled": is_controlled, + # "num_loras": num_loras, "height": height, "width": width, "precision": precision, @@ -110,7 +118,9 @@ def __init__( }, "vae_encode": { "hf_model_name": custom_vae if custom_vae else base_model_id, - "vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None), + "vae_model": vae.VaeModel( + hf_model_name=base_model_id, hf_auth_token=None + ), "batch_size": batch_size, "height": height, "width": width, @@ -118,16 +128,16 @@ def __init__( }, "vae_decode": { "hf_model_name": custom_vae if custom_vae else base_model_id, - "vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None), + "vae_model": vae.VaeModel( + hf_model_name=base_model_id, hf_auth_token=None + ), "batch_size": batch_size, "height": height, "width": width, "precision": precision, }, } - super().__init__( - sd_model_map, base_model_id, static_kwargs, device, import_ir - ) + super().__init__(sd_model_map, base_model_id, static_kwargs, device, import_ir) pipe_id_list = [ safe_name(base_model_id), str(batch_size), @@ -135,7 +145,7 @@ def __init__( precision, ] if num_loras > 0: - pipe_id_list.append(str(num_loras)+"lora") + pipe_id_list.append(str(num_loras) + "lora") if is_controlled: pipe_id_list.append("controlled") if custom_vae: @@ -145,7 +155,6 @@ def __init__( del static_kwargs gc.collect() - def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings, is_img2img): print( f"\n[LOG] Preparing pipeline with scheduler {scheduler}" @@ -165,15 +174,16 @@ def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings, is_img2i for i in self.model_map: self.model_map[i]["external_weights_file"] = None elif custom_weights: - print(f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?") + print( + f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?" + ) self.static_kwargs["pipe"] = { - # "external_weight_path": self.weights_path, -# "external_weights": "safetensors", + # "external_weight_path": self.weights_path, + # "external_weights": "safetensors", } self.get_compiled_map(pipe_id=self.pipe_id) print("\n[LOG] Pipeline successfully prepared for runtime.") return - def generate_images( self, @@ -191,24 +201,24 @@ def generate_images( control_mode, hints, ): - #TODO: Batched args + # TODO: Batched args self.ondemand = ondemand if self.is_img2img: image, _ = self.process_sd_init_image(image, resample_type) - else: + else: image = None print("\n[LOG] Generating images...") - batched_args=[ + batched_args = [ prompt, negative_prompt, - #steps, - #strength, - #guidance_scale, - #seed, - #resample_type, - #control_mode, - #hints, + # steps, + # strength, + # guidance_scale, + # seed, + # resample_type, + # control_mode, + # hints, ] for arg in batched_args: if not isinstance(arg, list): @@ -222,7 +232,7 @@ def generate_images( prompt, negative_prompt, ) - + uint32_info = np.iinfo(np.uint32) uint32_min, uint32_max = uint32_info.min, uint32_info.max if seed < uint32_min or seed >= uint32_max: @@ -242,7 +252,7 @@ def generate_images( text_embeddings=text_embeddings, guidance_scale=guidance_scale, total_timesteps=final_timesteps, - cpu_scheduling=True, # until we have schedulers through Turbine + cpu_scheduling=True, # until we have schedulers through Turbine ) # Img latents -> PIL images @@ -260,7 +270,6 @@ def generate_images( return all_imgs - def encode_prompts_weight( self, prompt, @@ -275,13 +284,10 @@ def encode_prompts_weight( ) clip_inf_start = time.time() - text_embeddings, uncond_embeddings = get_weighted_text_embeddings( pipe=self, prompt=prompt, - uncond_prompt=negative_prompt - if do_classifier_free_guidance - else None, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, ) if do_classifier_free_guidance: @@ -300,7 +306,6 @@ def encode_prompts_weight( return text_embeddings.numpy().astype(np.float16) - def prepare_latents( self, generator, @@ -318,7 +323,7 @@ def prepare_latents( generator=generator, dtype=self.dtype, ).to("cpu") - + self.scheduler.set_timesteps(num_inference_steps) if self.is_img2img: init_timestep = min( @@ -327,16 +332,13 @@ def prepare_latents( t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start:] latents = self.encode_image(image) - latents = self.scheduler.add_noise( - latents, noise, timesteps[0].repeat(1) - ) + latents = self.scheduler.add_noise(latents, noise, timesteps[0].repeat(1)) return latents, [timesteps] else: self.scheduler.is_scale_input_called = True latents = noise * self.scheduler.init_noise_sigma return latents, self.scheduler.timesteps - def encode_image(self, input_image): self.load_submodels(["vae_encode"]) vae_encode_start = time.time() @@ -348,7 +350,6 @@ def encode_image(self, input_image): return latents - def produce_img_latents( self, latents, @@ -370,11 +371,15 @@ def produce_img_latents( for i, t in tqdm(enumerate(total_timesteps)): step_start_time = time.time() timestep = torch.tensor([t]).to(self.dtype).detach().numpy() - latent_model_input = self.scheduler.scale_model_input(latents, t).to(self.dtype) + latent_model_input = self.scheduler.scale_model_input(latents, t).to( + self.dtype + ) if mask is not None and masked_image_latents is not None: latent_model_input = torch.cat( [ - torch.from_numpy(np.asarray(latent_model_input)).to(torch.float16), + torch.from_numpy(np.asarray(latent_model_input)).to( + torch.float16 + ), mask, masked_image_latents, ], @@ -398,9 +403,7 @@ def produce_img_latents( if cpu_scheduling: noise_pred = torch.from_numpy(noise_pred.to_host()) - latents = self.scheduler.step( - noise_pred, t, latents - ).prev_sample + latents = self.scheduler.step(noise_pred, t, latents).prev_sample else: latents = self.run("scheduler_step", (noise_pred, t, latents)) @@ -411,7 +414,7 @@ def produce_img_latents( # ) step_time_sum += step_time - #if self.status == SD_STATE_CANCEL: + # if self.status == SD_STATE_CANCEL: # break if self.ondemand: @@ -426,7 +429,6 @@ def produce_img_latents( all_latents = torch.cat(latent_history, dim=0) return all_latents - def decode_latents(self, latents, use_base_vae, cpu_scheduling): if use_base_vae: latents = 1 / 0.18215 * latents @@ -435,11 +437,11 @@ def decode_latents(self, latents, use_base_vae, cpu_scheduling): if cpu_scheduling: latents_numpy = latents.detach().numpy() - #profile_device = start_profiling(file_path="vae.rdc") + # profile_device = start_profiling(file_path="vae.rdc") vae_start = time.time() images = self.run("vae_decode", latents_numpy).to_host() vae_inf_time = (time.time() - vae_start) * 1000 - #end_profiling(profile_device) + # end_profiling(profile_device) print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}") if use_base_vae: @@ -451,7 +453,6 @@ def decode_latents(self, latents, use_base_vae, cpu_scheduling): pil_images = [Image.fromarray(image) for image in images.numpy()] return pil_images - def process_sd_init_image(self, sd_init_image, resample_type): if isinstance(sd_init_image, list): images = [] @@ -463,7 +464,9 @@ def process_sd_init_image(self, sd_init_image, resample_type): if isinstance(sd_init_image, str): if os.path.isfile(sd_init_image): sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB") - image, is_img2img = self.process_sd_init_image(sd_init_image, resample_type) + image, is_img2img = self.process_sd_init_image( + sd_init_image, resample_type + ) else: image = None is_img2img = False @@ -536,7 +539,6 @@ def shark_sd_fn( sd_kwargs = locals() is_img2img = True if sd_init_image[0] is not None else False - print("\n[LOG] Performing Stable Diffusion Pipeline setup...") from apps.shark_studio.modules.shared_cmd_opts import cmd_opts @@ -553,20 +555,20 @@ def shark_sd_fn( for i, model in enumerate(controlnets["model"]): if "xl" not in base_model_id.lower(): adapters[f"control_adapter_{model}"] = { - "hf_id": control_adapter_map[ - "runwayml/stable-diffusion-v1-5" - ][model], + "hf_id": control_adapter_map["runwayml/stable-diffusion-v1-5"][ + model + ], "strength": controlnets["strength"][i], } else: adapters[f"control_adapter_{model}"] = { - "hf_id": control_adapter_map[ - "stabilityai/stable-diffusion-xl-1.0" - ][model], + "hf_id": control_adapter_map["stabilityai/stable-diffusion-xl-1.0"][ + model + ], "strength": controlnets["strength"][i], } if model is not None: - is_controlled=True + is_controlled = True control_mode = controlnets["control_mode"] for i in controlnets["hint"]: hints.append[i] @@ -659,13 +661,13 @@ def view_json_file(file_path): return content - if __name__ == "__main__": from apps.shark_studio.modules.shared_cmd_opts import cmd_opts - import apps.shark_studio.web.utils.globals as global_obj + import apps.shark_studio.web.utils.globals as global_obj + global_obj._init() sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json")) sd_kwargs = json.loads(sd_json) for i in shark_sd_fn_dict_input(sd_kwargs): - print(i) \ No newline at end of file + print(i) diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 430f074894..e9268aa83b 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -46,9 +46,7 @@ def get_devices_by_name(driver_name): if len(device_list_dict) == 1: device_list.append(f"{device_name} => {driver_name}") else: - device_list.append( - f"{device_name} => {driver_name}://{i}" - ) + device_list.append(f"{device_name} => {driver_name}://{i}") return device_list set_iree_runtime_flags() @@ -259,9 +257,7 @@ def get_devices_by_name(driver_name): if len(device_list_dict) == 1: device_list.append(f"{device_name} => {driver_name}") else: - device_list.append( - f"{device_name} => {driver_name}://{i}" - ) + device_list.append(f"{device_name} => {driver_name}://{i}") return device_list set_iree_runtime_flags() @@ -316,9 +312,7 @@ def parse_seed_input(seed_input: str | list | int): if isinstance(seed_input, int): return [seed_input] - if isinstance(seed_input, list) and all( - type(seed) is int for seed in seed_input - ): + if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input): return seed_input raise TypeError( diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py index 25edd3109c..08681f6c56 100644 --- a/apps/shark_studio/modules/ckpt_processing.py +++ b/apps/shark_studio/modules/ckpt_processing.py @@ -42,9 +42,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False): # TODO: Add an option `--ema` (`--no-ema`) for users to specify if # they want to go for EMA weight extraction or not. extract_ema = False - print( - "Loading diffusers' pipeline from original stable diffusion checkpoint" - ) + print("Loading diffusers' pipeline from original stable diffusion checkpoint") num_in_channels = 9 if is_inpaint else 4 pipe = download_from_original_stable_diffusion_ckpt( checkpoint_path_or_dict=custom_weights, @@ -69,9 +67,7 @@ def convert_original_vae(vae_checkpoint): original_config = OmegaConf.load(original_config_file) vae_config = create_vae_diffusers_config(original_config, image_size=512) - converted_vae_checkpoint = convert_ldm_vae_checkpoint( - vae_state_dict, vae_config - ) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, vae_config) return converted_vae_checkpoint @@ -89,9 +85,7 @@ def process_custom_pipe_weights(custom_weights): assert custom_weights.lower().endswith( (".ckpt", ".safetensors") ), "checkpoint files supported can be any of [.ckpt, .safetensors] type" - custom_weights_tgt = get_path_to_diffusers_checkpoint( - custom_weights - ) + custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights) custom_weights_params = custom_weights return custom_weights_params, custom_weights_tgt @@ -104,15 +98,11 @@ def get_civitai_checkpoint(url: str): base_filename = re.findall( '"([^"]*)"', response.headers["Content-Disposition"] )[0] - destination_path = ( - Path.cwd() / (cmd_opts.ckpt_dir or "models") / base_filename - ) + destination_path = Path.cwd() / (cmd_opts.ckpt_dir or "models") / base_filename # we don't have this model downloaded yet if not destination_path.is_file(): - print( - f"downloading civitai model from {url} to {destination_path}" - ) + print(f"downloading civitai model from {url} to {destination_path}") size = int(response.headers["content-length"], 0) progress_bar = tqdm(total=size, unit="iB", unit_scale=True) diff --git a/apps/shark_studio/modules/embeddings.py b/apps/shark_studio/modules/embeddings.py index b35839c8e5..87924c819e 100644 --- a/apps/shark_studio/modules/embeddings.py +++ b/apps/shark_studio/modules/embeddings.py @@ -76,22 +76,14 @@ def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75): scale = lora_weight.alpha * lora_strength if len(weight.size()) == 2: if len(lora_weight.up.shape) == 4: - weight_up = ( - lora_weight.up.squeeze(3).squeeze(2).to(torch.float32) - ) - weight_down = ( - lora_weight.down.squeeze(3).squeeze(2).to(torch.float32) - ) - change = ( - torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) - ) + weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32) + weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32) + change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) else: change = torch.mm(lora_weight.up, lora_weight.down) elif lora_weight.down.size()[2:4] == (1, 1): weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32) - weight_down = ( - lora_weight.down.squeeze(3).squeeze(2).to(torch.float32) - ) + weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32) change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) else: change = torch.nn.functional.conv2d( @@ -166,9 +158,7 @@ def get_lora_metadata(lora_filename): # get a figure for the total number of images processed for this dataset # either then number actually listed or in its dataset_dir entry or # the highest frequency's number if that doesn't exist - img_count = dataset_dirs.get(dir, {}).get( - "img_count", frequencies[0][1] - ) + img_count = dataset_dirs.get(dir, {}).get("img_count", frequencies[0][1]) # add the dataset frequencies to the overall frequencies replacing the # frequency counts on the tags with a percentage/ratio diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py index 6c78515cca..1b435db40d 100644 --- a/apps/shark_studio/modules/pipeline.py +++ b/apps/shark_studio/modules/pipeline.py @@ -43,7 +43,6 @@ def __init__( self.iree_module_dict = {} self.tempfiles = {} - def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None: # First checks whether we have .vmfbs precompiled, then populates the map # with the precompiled executables and fetches executables for the rest of the map. @@ -52,13 +51,15 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None: # and your model map is populated with any IR - unique model IDs and their static params, # call this method to get the artifacts associated with your map. self.pipe_id = self.safe_name(pipe_id) - self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(".."), self.pipe_id)) + self.pipe_vmfb_path = Path( + os.path.join(get_checkpoints_path(".."), self.pipe_id) + ) self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True) if submodel == "None": print("\n[LOG] Gathering any pre-compiled artifacts....") for key in self.model_map: self.get_compiled_map(pipe_id, submodel=key) - else: + else: self.get_precompiled(pipe_id, submodel) ireec_flags = [] if submodel in self.iree_module_dict: @@ -68,18 +69,22 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None: elif "vmfb_path" in self.model_map[submodel]: return elif submodel not in self.tempfiles: - print(f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR...") + print( + f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..." + ) if submodel in self.static_kwargs: init_kwargs = self.static_kwargs[submodel] for key in self.static_kwargs["pipe"]: if key not in init_kwargs: init_kwargs[key] = self.static_kwargs["pipe"][key] - self.import_torch_ir( - submodel, init_kwargs - ) + self.import_torch_ir(submodel, init_kwargs) self.get_compiled_map(pipe_id, submodel) - else: - ireec_flags = self.model_map[submodel]["ireec_flags"] if "ireec_flags" in self.model_map[submodel] else [] + else: + ireec_flags = ( + self.model_map[submodel]["ireec_flags"] + if "ireec_flags" in self.model_map[submodel] + else [] + ) if "external_weights_file" in self.model_map[submodel]: weights_path = self.model_map[submodel]["external_weights_file"] @@ -92,11 +97,10 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None: mmap=True, external_weight_file=weights_path, extra_args=ireec_flags, - write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb") + write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"), ) return - def get_precompiled(self, pipe_id, submodel="None"): if submodel == "None": for model in self.model_map: @@ -112,7 +116,6 @@ def get_precompiled(self, pipe_id, submodel="None"): self.model_map[submodel]["vmfb_path"] = os.path.join(vmfbs_path, file) return - def import_torch_ir(self, submodel, kwargs): torch_ir = self.model_map[submodel]["initializer"]( **self.safe_dict(kwargs), compile_to="torch" @@ -120,17 +123,16 @@ def import_torch_ir(self, submodel, kwargs): if submodel == "clip": # clip.export_clip_model returns (torch_ir, tokenizer) torch_ir = torch_ir[0] - self.tempfiles[submodel] = get_resource_path(os.path.join( - "..", "shark_tmp", f"{submodel}.torch.tempfile" - )) - + self.tempfiles[submodel] = get_resource_path( + os.path.join("..", "shark_tmp", f"{submodel}.torch.tempfile") + ) + with open(self.tempfiles[submodel], "w+") as f: f.write(torch_ir) del torch_ir gc.collect() return - def load_submodels(self, submodels: list): for submodel in submodels: if submodel in self.iree_module_dict: @@ -149,13 +151,14 @@ def load_submodels(self, submodels: list): self.device, device_idx=0, rt_flags=[], - external_weight_file=self.model_map[submodel]['external_weight_file'], + external_weight_file=self.model_map[submodel][ + "external_weight_file" + ], ) else: self.get_compiled_map(self.pipe_id, submodel) return - def unload_submodels(self, submodels: list): for submodel in submodels: if submodel in self.iree_module_dict: @@ -163,18 +166,20 @@ def unload_submodels(self, submodels: list): gc.collect() return - def run(self, submodel, inputs): if not isinstance(inputs, list): inputs = [inputs] - inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, input) for input in inputs] - return self.iree_module_dict[submodel]['vmfb']['main'](*inp) - + inp = [ + ireert.asdevicearray( + self.iree_module_dict[submodel]["config"].device, input + ) + for input in inputs + ] + return self.iree_module_dict[submodel]["vmfb"]["main"](*inp) def safe_name(self, name): return name.replace("/", "_").replace("-", "_").replace("\\", "_") - def safe_dict(self, kwargs: dict): flat_args = {} for i in kwargs: @@ -183,4 +188,4 @@ def safe_dict(self, kwargs: dict): else: flat_args[i] = kwargs[i] - return flat_args + return flat_args diff --git a/apps/shark_studio/modules/prompt_encoding.py b/apps/shark_studio/modules/prompt_encoding.py index b2a5e8a27e..3dc61aba08 100644 --- a/apps/shark_studio/modules/prompt_encoding.py +++ b/apps/shark_studio/modules/prompt_encoding.py @@ -1,4 +1,3 @@ - from typing import List, Optional, Union from iree import runtime as ireert import re @@ -112,9 +111,7 @@ def multiply_range(start_position, multiplier): return res -def get_prompts_with_weights( - pipe, prompt: List[str], max_length: int -): +def get_prompts_with_weights(pipe, prompt: List[str], max_length: int): r""" Tokenize a list of prompts and return its tokens with weights of each token. No padding, starting or ending token is included. @@ -164,18 +161,12 @@ def pad_tokens_and_weights( """ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) weights_length = ( - max_length - if no_boseos_middle - else max_embeddings_multiples * chunk_length + max_length if no_boseos_middle else max_embeddings_multiples * chunk_length ) for i in range(len(tokens)): - tokens[i] = ( - [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) - ) + tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) if no_boseos_middle: - weights[i] = ( - [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - ) + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) else: w = [] if len(weights[i]) == 0: @@ -195,6 +186,7 @@ def pad_tokens_and_weights( return tokens, weights + def get_unweighted_text_embeddings( pipe, text_input, @@ -242,7 +234,6 @@ def get_unweighted_text_embeddings( return text_embeddings - # This function deals with NoneType values occuring in tokens after padding # It switches out None with 49407 as truncating None values causes matrix dimension errors, def filter_nonetype_tokens(tokens: List[List]): @@ -290,9 +281,7 @@ def get_weighted_text_embeddings( # round up the longest length of tokens to a multiple of (model_max_length - 2) max_length = max([len(token) for token in prompt_tokens]) if uncond_prompt is not None: - max_length = max( - max_length, max([len(token) for token in uncond_tokens]) - ) + max_length = max(max_length, max([len(token) for token in uncond_tokens])) max_embeddings_multiples = min( max_embeddings_multiples, (max_length - 1) // (pipe.model_max_length - 2) + 1, @@ -334,9 +323,7 @@ def get_weighted_text_embeddings( uncond_tokens = filter_nonetype_tokens(uncond_tokens) # uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) - uncond_tokens = torch.tensor( - uncond_tokens, dtype=torch.long, device="cpu" - ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device="cpu") # get the embeddings text_embeddings = get_unweighted_text_embeddings( @@ -346,9 +333,7 @@ def get_weighted_text_embeddings( no_boseos_middle=no_boseos_middle, ) # prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) - prompt_weights = torch.tensor( - prompt_weights, dtype=torch.float, device="cpu" - ) + prompt_weights = torch.tensor(prompt_weights, dtype=torch.float, device="cpu") if uncond_prompt is not None: uncond_embeddings = get_unweighted_text_embeddings( pipe, @@ -357,27 +342,19 @@ def get_weighted_text_embeddings( no_boseos_middle=no_boseos_middle, ) # uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) - uncond_weights = torch.tensor( - uncond_weights, dtype=torch.float, device="cpu" - ) + uncond_weights = torch.tensor(uncond_weights, dtype=torch.float, device="cpu") # assign weights to the prompts and normalize in the sense of mean # TODO: should we normalize by chunk or in a whole (current implementation)? if (not skip_parsing) and (not skip_weighting): previous_mean = ( - text_embeddings.float() - .mean(axis=[-2, -1]) - .to(text_embeddings.dtype) + text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) ) text_embeddings *= prompt_weights.unsqueeze(-1) current_mean = ( - text_embeddings.float() - .mean(axis=[-2, -1]) - .to(text_embeddings.dtype) - ) - text_embeddings *= ( - (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) ) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) if uncond_prompt is not None: previous_mean = ( uncond_embeddings.float() diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py index 7a42338b1a..8c2413c638 100644 --- a/apps/shark_studio/modules/schedulers.py +++ b/apps/shark_studio/modules/schedulers.py @@ -17,7 +17,7 @@ def get_schedulers(model_id): - #TODO: switch over to turbine and run all on GPU + # TODO: switch over to turbine and run all on GPU print(f"\n[LOG] Initializing schedulers from model id: {model_id}") schedulers = dict() schedulers["PNDM"] = PNDMScheduler.from_pretrained( @@ -44,14 +44,10 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) - schedulers[ - "DPMSolverMultistep" - ] = DPMSolverMultistepScheduler.from_pretrained( + schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained( model_id, subfolder="scheduler", algorithm_type="dpmsolver" ) - schedulers[ - "DPMSolverMultistep++" - ] = DPMSolverMultistepScheduler.from_pretrained( + schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained( model_id, subfolder="scheduler", algorithm_type="dpmsolver++" ) schedulers[ @@ -83,9 +79,7 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) - schedulers[ - "DPMSolverSinglestep" - ] = DPMSolverSinglestepScheduler.from_pretrained( + schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained( model_id, subfolder="scheduler", ) @@ -108,24 +102,16 @@ def export_scheduler_model(model): scheduler_model_map = { "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"), - "EulerAncestralDiscrete": export_scheduler_model( - "EulerAncestralDiscreteScheduler" - ), + "EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"), "LCM": export_scheduler_model("LCMScheduler"), "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"), "PNDM": export_scheduler_model("PNDMScheduler"), "DDPM": export_scheduler_model("DDPMScheduler"), "DDIM": export_scheduler_model("DDIMScheduler"), - "DPMSolverMultistep": export_scheduler_model( - "DPMSolverMultistepScheduler" - ), + "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"), "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"), "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"), - "DPMSolverSinglestep": export_scheduler_model( - "DPMSolverSingleStepScheduler" - ), - "KDPM2AncestralDiscrete": export_scheduler_model( - "KDPM2AncestralDiscreteScheduler" - ), + "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"), + "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"), "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"), } diff --git a/apps/shark_studio/modules/shared_cmd_opts.py b/apps/shark_studio/modules/shared_cmd_opts.py index 535e5d2c7f..dd871383a7 100644 --- a/apps/shark_studio/modules/shared_cmd_opts.py +++ b/apps/shark_studio/modules/shared_cmd_opts.py @@ -130,8 +130,7 @@ def is_valid_file(arg): "--strength", type=float, default=0.8, - help="The strength of change applied on the given input image for " - "img2img.", + help="The strength of change applied on the given input image for " "img2img.", ) p.add_argument( @@ -290,9 +289,7 @@ def is_valid_file(arg): # Model Config and Usage Params ############################################################################## -p.add_argument( - "--device", type=str, default="vulkan", help="Device to run the model." -) +p.add_argument("--device", type=str, default="vulkan", help="Device to run the model.") p.add_argument( "--precision", type=str, default="fp16", help="Precision to run the model." @@ -350,8 +347,7 @@ def is_valid_file(arg): "--batch_count", type=int, default=1, - help="Number of batches to be generated with random seeds in " - "single execution.", + help="Number of batches to be generated with random seeds in " "single execution.", ) p.add_argument( @@ -416,8 +412,7 @@ def is_valid_file(arg): "--use_lora", type=str, default="", - help="Use standalone LoRA weight using a HF ID or a checkpoint " - "file (~3 MB).", + help="Use standalone LoRA weight using a HF ID or a checkpoint " "file (~3 MB).", ) p.add_argument( @@ -493,8 +488,7 @@ def is_valid_file(arg): "--dump_isa", default=False, action="store_true", - help="When enabled call amdllpc to get ISA dumps. " - "Use with dispatch benchmarks.", + help="When enabled call amdllpc to get ISA dumps. " "Use with dispatch benchmarks.", ) p.add_argument( @@ -515,8 +509,7 @@ def is_valid_file(arg): "--enable_rgp", default=False, action=argparse.BooleanOptionalAction, - help="Flag for inserting debug frames between iterations " - "for use with rgp.", + help="Flag for inserting debug frames between iterations " "for use with rgp.", ) p.add_argument( @@ -602,8 +595,7 @@ def is_valid_file(arg): "--progress_bar", default=True, action=argparse.BooleanOptionalAction, - help="Flag for removing the progress bar animation during " - "image generation.", + help="Flag for removing the progress bar animation during " "image generation.", ) p.add_argument( diff --git a/apps/shark_studio/modules/timer.py b/apps/shark_studio/modules/timer.py index 8fd1e6a7df..d6918e9c8c 100644 --- a/apps/shark_studio/modules/timer.py +++ b/apps/shark_studio/modules/timer.py @@ -11,9 +11,7 @@ def __init__(self, timer, category): def __enter__(self): self.start = time.time() - self.timer.base_category = ( - self.original_base_category + self.category + "/" - ) + self.timer.base_category = self.original_base_category + self.category + "/" self.timer.subcategory_level += 1 if self.timer.print_log: @@ -82,10 +80,7 @@ def summary(self): res += " (" res += ", ".join( - [ - f"{category}: {time_taken:.1f}s" - for category, time_taken in additions - ] + [f"{category}: {time_taken:.1f}s" for category, time_taken in additions] ) res += ")" diff --git a/apps/shark_studio/web/api/compat.py b/apps/shark_studio/web/api/compat.py index 80399505c4..3f92c41d02 100644 --- a/apps/shark_studio/web/api/compat.py +++ b/apps/shark_studio/web/api/compat.py @@ -30,17 +30,13 @@ def decode_base64_to_image(encoding): status_code=500, detail="Request to local resource not allowed" ) - headers = ( - {"user-agent": opts.api_useragent} if opts.api_useragent else {} - ) + headers = {"user-agent": opts.api_useragent} if opts.api_useragent else {} response = requests.get(encoding, timeout=30, headers=headers) try: image = Image.open(BytesIO(response.content)) return image except Exception as e: - raise HTTPException( - status_code=500, detail="Invalid image url" - ) from e + raise HTTPException(status_code=500, detail="Invalid image url") from e if encoding.startswith("data:image/"): encoding = encoding.split(";")[1].split(",")[1] @@ -48,9 +44,7 @@ def decode_base64_to_image(encoding): image = Image.open(BytesIO(base64.b64decode(encoding))) return image except Exception as e: - raise HTTPException( - status_code=500, detail="Invalid encoded image" - ) from e + raise HTTPException(status_code=500, detail="Invalid encoded image") from e def encode_pil_to_base64(image): diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py index 0d5de9a839..05a9bc363d 100644 --- a/apps/shark_studio/web/index.py +++ b/apps/shark_studio/web/index.py @@ -130,9 +130,7 @@ def webui(): def resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) dark_theme = resource_path("ui/css/sd_dark_theme.css") diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 3c5497215a..917ac870bf 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -88,17 +88,11 @@ def llm_chat_api(InputData: dict): # print(f"prompt : {InputData['prompt']}") # print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now global vicuna_model - model_name = ( - InputData["model"] if "model" in InputData.keys() else "codegen" - ) + model_name = InputData["model"] if "model" in InputData.keys() else "codegen" model_path = llm_model_map[model_name] device = "cpu-task" precision = "fp16" - max_toks = ( - None - if "max_tokens" not in InputData.keys() - else InputData["max_tokens"] - ) + max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"] if max_toks is None: max_toks = 128 if model_name == "codegen" else 512 @@ -135,9 +129,7 @@ def llm_chat_api(InputData: dict): # TODO: add role dict for different models if is_chat_completion_api: # TODO: add funtionality for multiple messages - prompt = create_prompt( - model_name, [(InputData["messages"][0]["content"], "")] - ) + prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")]) else: prompt = InputData["prompt"] print("prompt = ", prompt) @@ -170,9 +162,7 @@ def llm_chat_api(InputData: dict): end_time = dt.now().strftime("%Y%m%d%H%M%S%f") return { "id": end_time, - "object": "chat.completion" - if is_chat_completion_api - else "text_completion", + "object": "chat.completion" if is_chat_completion_api else "text_completion", "created": int(end_time), "choices": choices, } @@ -248,9 +238,7 @@ def view_json_file(file_obj): with gr.Row(visible=False): with gr.Group(): - config_file = gr.File( - label="Upload sharding configuration", visible=False - ) + config_file = gr.File(label="Upload sharding configuration", visible=False) json_view_button = gr.Button("View as JSON", visible=False) json_view = gr.JSON(visible=False) json_view_button.click( diff --git a/apps/shark_studio/web/ui/common_events.py b/apps/shark_studio/web/ui/common_events.py index 9adf7dd61b..7dda8ba268 100644 --- a/apps/shark_studio/web/ui/common_events.py +++ b/apps/shark_studio/web/ui/common_events.py @@ -13,7 +13,9 @@ def lora_changed(lora_files): # tag frequency percentage, above which a tag is displayed TAG_DISPLAY_THRESHOLD = 0.65 # template for the html used to display a tag - TAG_HTML_TEMPLATE = '{tag}' + TAG_HTML_TEMPLATE = ( + '{tag}' + ) output = [] for lora_file in lora_files: if lora_file == "": diff --git a/apps/shark_studio/web/ui/outputgallery.py b/apps/shark_studio/web/ui/outputgallery.py index 77e60be819..a3de6f7b57 100644 --- a/apps/shark_studio/web/ui/outputgallery.py +++ b/apps/shark_studio/web/ui/outputgallery.py @@ -22,8 +22,7 @@ def outputgallery_filenames(subdir) -> list[str]: new_dir_path = os.path.join(output_dir, subdir) if os.path.exists(new_dir_path): filenames = [ - glob.glob(new_dir_path + "/" + ext) - for ext in ("*.png", "*.jpg", "*.jpeg") + glob.glob(new_dir_path + "/" + ext) for ext in ("*.png", "*.jpg", "*.jpeg") ] return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True) @@ -52,11 +51,7 @@ def output_subdirs() -> list[str]: [path for path in relative_paths if path.isnumeric()], reverse=True ) result_paths = generated_paths + sorted( - [ - path - for path in relative_paths - if (not path.isnumeric()) and path != "." - ] + [path for path in relative_paths if (not path.isnumeric()) and path != "."] ) return result_paths @@ -184,9 +179,7 @@ def on_image_columns_change(columns): def on_select_subdir(subdir) -> list: # evt.value is the subdirectory name new_images = outputgallery_filenames(subdir) - new_label = ( - f"{len(new_images)} images in {os.path.join(output_dir, subdir)}" - ) + new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)}" return [ new_images, gr.Gallery( @@ -223,8 +216,7 @@ def on_refresh(current_subdir: str) -> list: ) new_images = outputgallery_filenames(new_subdir) new_label = ( - f"{len(new_images)} images in " - f"{os.path.join(output_dir, new_subdir)}" + f"{len(new_images)} images in " f"{os.path.join(output_dir, new_subdir)}" ) return [ @@ -234,9 +226,7 @@ def on_refresh(current_subdir: str) -> list: ), refreshed_subdirs, new_images, - gr.Gallery( - value=new_images, label=new_label, visible=len(new_images) > 0 - ), + gr.Gallery(value=new_images, label=new_label, visible=len(new_images) > 0), gr.Image( label=new_label, visible=len(new_images) == 0, diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index 66ec452d0b..49eac6d821 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -50,6 +50,7 @@ "stabilityai/sdxl-turbo", ] + def view_json_file(file_path): content = "" with open(file_path, "r") as fopen: @@ -149,7 +150,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str): else: sd_json = new_sd_config for i in sd_json["sd_init_image"]: - if i is not None: + if i is not None: if os.path.isfile(i): sd_image = [Image.open(i, mode="r")] else: @@ -338,9 +339,9 @@ def import_original(original_img, width, height): choices=["None"] + get_checkpoints(base_model_id), ) # with gr.Column(scale=2): - sd_vae_info = ( - str(get_checkpoints_path("vae")) - ).replace("\\", "\n\\") + sd_vae_info = (str(get_checkpoints_path("vae"))).replace( + "\\", "\n\\" + ) sd_vae_info = f"VAE Path: {sd_vae_info}" custom_vae = gr.Dropdown( label=f"Custom VAE Models", @@ -396,12 +397,10 @@ def import_original(original_img, width, height): height=300, interactive=True, ) - with gr.Accordion( - label="Embeddings options", open=True, render=True - ): - sd_lora_info = ( - str(get_checkpoints_path("loras")) - ).replace("\\", "\n\\") + with gr.Accordion(label="Embeddings options", open=True, render=True): + sd_lora_info = (str(get_checkpoints_path("loras"))).replace( + "\\", "\n\\" + ) with gr.Row(): embeddings_config = gr.JSON(min_width=50, scale=1) lora_opt = gr.Dropdown( @@ -651,7 +650,7 @@ def import_original(original_img, width, height): with gr.Column(scale=3, min_width=600): with gr.Group(): sd_gallery = gr.Gallery( - label="Generated images", + label="Generated images", show_label=False, elem_id="gallery", columns=2, @@ -666,9 +665,7 @@ def import_original(original_img, width, height): elem_id="std_output", show_label=False, ) - sd_element.load( - logger.read_sd_logs, None, std_output, every=1 - ) + sd_element.load(logger.read_sd_logs, None, std_output, every=1) sd_status = gr.Textbox(visible=False) with gr.Row(): stable_diffusion = gr.Button("Generate Image(s)") @@ -696,9 +693,7 @@ def import_original(original_img, width, height): value="Clear Config", size="sm", components=sd_json ) with gr.Row(): - save_sd_config = gr.Button( - value="Save Config", size="sm" - ) + save_sd_config = gr.Button(value="Save Config", size="sm") sd_config_name = gr.Textbox( value="Config Name", info="Name of the file this config will be saved to.", @@ -796,9 +791,7 @@ def import_original(original_img, width, height): ) prompt_submit = prompt.submit(**status_kwargs).then(**pull_kwargs) - neg_prompt_submit = negative_prompt.submit(**status_kwargs).then( - **pull_kwargs - ) + neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(**pull_kwargs) generate_click = ( stable_diffusion.click(**status_kwargs) .then(**pull_kwargs) diff --git a/apps/shark_studio/web/ui/utils.py b/apps/shark_studio/web/ui/utils.py index ba62e5adc0..34a94fa014 100644 --- a/apps/shark_studio/web/ui/utils.py +++ b/apps/shark_studio/web/ui/utils.py @@ -6,9 +6,7 @@ def resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) diff --git a/apps/shark_studio/web/utils/file_utils.py b/apps/shark_studio/web/utils/file_utils.py index e7b8fd72c4..cae925f5e2 100644 --- a/apps/shark_studio/web/utils/file_utils.py +++ b/apps/shark_studio/web/utils/file_utils.py @@ -23,9 +23,7 @@ def get_path_stem(path): def get_resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) result = Path(os.path.join(base_path, relative_path)).resolve(strict=False) return result diff --git a/apps/shark_studio/web/utils/metadata/csv_metadata.py b/apps/shark_studio/web/utils/metadata/csv_metadata.py index d617e802bf..d515234083 100644 --- a/apps/shark_studio/web/utils/metadata/csv_metadata.py +++ b/apps/shark_studio/web/utils/metadata/csv_metadata.py @@ -29,9 +29,7 @@ def parse_csv(image_filename: str): has_header = csv.Sniffer().has_header(csv_file.read(2048)) csv_file.seek(0) - reader = ( - csv.DictReader(csv_file) if has_header else csv.reader(csv_file) - ) + reader = csv.DictReader(csv_file) if has_header else csv.reader(csv_file) matches = [ # we rely on humanize and humanizable to work out the parsing of the individual .csv rows diff --git a/apps/shark_studio/web/utils/metadata/format.py b/apps/shark_studio/web/utils/metadata/format.py index f097dab54f..308d9f8e8b 100644 --- a/apps/shark_studio/web/utils/metadata/format.py +++ b/apps/shark_studio/web/utils/metadata/format.py @@ -92,15 +92,11 @@ def compact(metadata: dict) -> dict: result["Hires resize"] = f"{hires_y}x{hires_x}" # remove VAE if it exists and is empty - if (result.keys() & {"VAE"}) and ( - not result["VAE"] or result["VAE"] == "None" - ): + if (result.keys() & {"VAE"}) and (not result["VAE"] or result["VAE"] == "None"): result.pop("VAE") # remove LoRA if it exists and is empty - if (result.keys() & {"LoRA"}) and ( - not result["LoRA"] or result["LoRA"] == "None" - ): + if (result.keys() & {"LoRA"}) and (not result["LoRA"] or result["LoRA"] == "None"): result.pop("LoRA") return result diff --git a/apps/shark_studio/web/utils/metadata/png_metadata.py b/apps/shark_studio/web/utils/metadata/png_metadata.py index d9091afdf4..72f663f246 100644 --- a/apps/shark_studio/web/utils/metadata/png_metadata.py +++ b/apps/shark_studio/web/utils/metadata/png_metadata.py @@ -66,9 +66,7 @@ def parse_generation_parameters(x: str): return res -def try_find_model_base_from_png_metadata( - file: str, folder: str = "models" -) -> str: +def try_find_model_base_from_png_metadata(file: str, folder: str = "models") -> str: custom = "" # Remove extension from file info @@ -101,16 +99,13 @@ def find_model_from_png_metadata( # No matching model was found if not png_custom and not png_hf_id: print( - "Import PNG info: Unable to find a matching model for %s" - % model_file + "Import PNG info: Unable to find a matching model for %s" % model_file ) return png_custom, png_hf_id -def find_vae_from_png_metadata( - key: str, metadata: dict[str, str | int] -) -> str: +def find_vae_from_png_metadata(key: str, metadata: dict[str, str | int]) -> str: vae_custom = "" if key in metadata: diff --git a/apps/shark_studio/web/utils/state.py b/apps/shark_studio/web/utils/state.py index 626d4ce53f..350012c381 100644 --- a/apps/shark_studio/web/utils/state.py +++ b/apps/shark_studio/web/utils/state.py @@ -18,8 +18,7 @@ def get_generation_text_info(seeds, device): text_output = f"prompt={cfg_dump['prompts']}" text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}" text_output += ( - f"\nmodel_id={cfg_dump['hf_model_id']}, " - f"ckpt_loc={cfg_dump['ckpt_loc']}" + f"\nmodel_id={cfg_dump['hf_model_id']}, " f"ckpt_loc={cfg_dump['ckpt_loc']}" ) text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}" text_output += ( diff --git a/apps/shark_studio/web/utils/tmp_configs.py b/apps/shark_studio/web/utils/tmp_configs.py index 3e6ba46bfe..4415276ea3 100644 --- a/apps/shark_studio/web/utils/tmp_configs.py +++ b/apps/shark_studio/web/utils/tmp_configs.py @@ -7,9 +7,7 @@ def clear_tmp_mlir(): cleanup_start = time() - print( - "Clearing .mlir temporary files from a prior run. This may take some time..." - ) + print("Clearing .mlir temporary files from a prior run. This may take some time...") mlir_files = [ filename for filename in os.listdir(shark_tmp) @@ -18,9 +16,7 @@ def clear_tmp_mlir(): ] for filename in mlir_files: os.remove(shark_tmp + filename) - print( - f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds." - ) + print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.") def clear_tmp_imgs():