diff --git a/.gitignore b/.gitignore
index 38b396529..eeaeb1ed4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,7 +6,6 @@ __pycache__
 lena.png
 lena_result.png
 lena_test.py
-!taesdxl_decoder.pth
 /repositories
 /venv
 /tmp
diff --git a/fooocus_version.py b/fooocus_version.py
index d37de15b6..ffee39863 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.0.19'
+version = '2.0.50'
diff --git a/launch.py b/launch.py
index a5d5d8f57..4ce6a163d 100644
--- a/launch.py
+++ b/launch.py
@@ -9,7 +9,7 @@
 from modules.launch_util import is_installed, run, python, \
     run_pip, repo_dir, git_clone, requirements_met, script_path, dir_repos
 from modules.model_loader import load_file_from_url
-from modules.path import modelfile_path, lorafile_path, vae_approx_path, fooocus_expansion_path
+from modules.path import modelfile_path, lorafile_path, vae_approx_path, fooocus_expansion_path, upscale_models_path
 
 REINSTALL_ALL = False
 
@@ -67,8 +67,14 @@ def prepare_environment():
 ]
 
 vae_approx_filenames = [
-    ('taesdxl_decoder.pth',
-     'https://huggingface.co/lllyasviel/misc/resolve/main/taesdxl_decoder.pth')
+    ('xlvaeapp.pth',
+     'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth')
+]
+
+
+upscaler_filenames = [
+    ('fooocus_upscaler_s409985e5.bin',
+     'https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_upscaler_s409985e5.bin')
 ]
 
 
@@ -79,6 +85,8 @@ def download_models():
         load_file_from_url(url=url, model_dir=lorafile_path, file_name=file_name)
     for file_name, url in vae_approx_filenames:
         load_file_from_url(url=url, model_dir=vae_approx_path, file_name=file_name)
+    for file_name, url in upscaler_filenames:
+        load_file_from_url(url=url, model_dir=upscale_models_path, file_name=file_name)
 
     load_file_from_url(
         url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_expansion.bin',
diff --git a/models/vae_approx/taesdxl_decoder.pth b/models/vae_approx/taesdxl_decoder.pth
deleted file mode 100644
index f2b34452a..000000000
Binary files a/models/vae_approx/taesdxl_decoder.pth and /dev/null differ
diff --git a/modules/async_worker.py b/modules/async_worker.py
index 43dfcc901..44b683780 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -1,7 +1,6 @@
 import threading
 import torch
 
-
 buffer = []
 outputs = []
 
@@ -14,14 +13,18 @@ def worker():
     import random
     import copy
     import modules.default_pipeline as pipeline
+    import modules.core as core
+    import modules.flags as flags
     import modules.path
     import modules.patch
     import modules.virtual_memory as virtual_memory
+    import comfy.model_management
 
     from modules.sdxl_styles import apply_style, aspect_ratios, fooocus_expansion
     from modules.private_logger import log
     from modules.expansion import safe_str
-    from modules.util import join_prompts, remove_empty_str
+    from modules.util import join_prompts, remove_empty_str, HWC3, resize_image
+    from modules.upscaler import perform_upscale
 
     try:
         async_gradio_app = shared.gradio_root
@@ -37,16 +40,21 @@
     def progressbar(number, text):
         outputs.append(['preview', (number, text, None)])
 
     @torch.no_grad()
+    @torch.inference_mode()
     def handler(task):
         prompt, negative_prompt, style_selections, performance_selction, \
-        aspect_ratios_selction, image_number, image_seed, sharpness, \
-        base_model_name, refiner_model_name, \
-        l1, w1, l2, w2, l3, w3, l4, w4, l5, w5 = task
+            aspect_ratios_selction, image_number, image_seed, sharpness, \
+            base_model_name, refiner_model_name, \
+            l1, w1, l2, w2, l3, w3, l4, w4, l5, w5, \
+            input_image_checkbox, \
+            uov_method, uov_input_image = task
 
         loras = [(l1, w1), (l2, w2), (l3, w3), (l4, w4), (l5, w5)]
 
         raw_style_selections = copy.deepcopy(style_selections)
+        uov_method = uov_method.lower()
+
         if fooocus_expansion in style_selections:
             use_expansion = True
             style_selections.remove(fooocus_expansion)
@@ -54,8 +62,80 @@ def handler(task):
             use_expansion = False
 
         use_style = len(style_selections) > 0
-
         modules.patch.sharpness = sharpness
+        initial_latent = None
+        denoising_strength = 1.0
+        tiled = False
+
+        if performance_selction == 'Speed':
+            steps = 30
+            switch = 20
+        else:
+            steps = 60
+            switch = 40
+
+        pipeline.clear_all_caches()  # save memory
+
+        width, height = aspect_ratios[aspect_ratios_selction]
+
+        if input_image_checkbox:
+            progressbar(0, 'Image processing ...')
+            if uov_method != flags.disabled and uov_input_image is not None:
+                uov_input_image = HWC3(uov_input_image)
+                H, W, C = uov_input_image.shape
+                if 'vary' in uov_method:
+                    if H * W + 8 < width * height or float(abs(H * width - W * height)) > 1.5 * float(max(H, W, width, height)):
+                        uov_input_image = resize_image(uov_input_image, width=width, height=height)
+                        print(f'Aspect ratio corrected - users are uploading their own images.')
+                    if 'subtle' in uov_method:
+                        denoising_strength = 0.5
+                    if 'strong' in uov_method:
+                        denoising_strength = 0.85
+                    initial_pixels = core.numpy_to_pytorch(uov_input_image)
+                    progressbar(0, 'VAE encoding ...')
+                    initial_latent = core.encode_vae(vae=pipeline.xl_base_patched.vae, pixels=initial_pixels)
+                    B, C, H, W = initial_latent['samples'].shape
+                    width = W * 8
+                    height = H * 8
+                    print(f'Final resolution is {str((height, width))}.')
+                elif 'upscale' in uov_method:
+                    if '1.5x' in uov_method:
+                        f = 1.5
+                    elif '2x' in uov_method:
+                        f = 2.0
+                    else:
+                        f = 1.0
+
+                    width = int(W * f)
+                    height = int(H * f)
+                    image_is_super_large = width * height > 2800 * 2800
+                    progressbar(0, f'Upscaling image from {str((H, W))} to {str((height, width))}...')
+
+                    uov_input_image = core.numpy_to_pytorch(uov_input_image)
+                    uov_input_image = perform_upscale(uov_input_image)
+                    uov_input_image = core.pytorch_to_numpy(uov_input_image)[0]
+                    uov_input_image = resize_image(uov_input_image, width=width, height=height)
+                    print(f'Image upscaled.')
+
+                    if 'fast' in uov_method or image_is_super_large:
+                        if 'fast' not in uov_method:
+                            print('Image is too large. Directly returned the SR image. '
+                                  'Usually directly return SR image at 4K resolution '
+                                  'yields better results than SDXL diffusion.')
+                        outputs.append(['results', [uov_input_image]])
+                        return
+
+                    tiled = True
+                    denoising_strength = 1.0 - 0.618
+                    steps = int(steps * 0.618)
+                    switch = int(steps * 0.67)
+                    initial_pixels = core.numpy_to_pytorch(uov_input_image)
+                    progressbar(0, 'VAE encoding ...')
+                    initial_latent = core.encode_vae(vae=pipeline.xl_base_patched.vae, pixels=initial_pixels, tiled=True)
+                    B, C, H, W = initial_latent['samples'].shape
+                    width = W * 8
+                    height = H * 8
+                    print(f'Final resolution is {str((height, width))}.')
 
         progressbar(1, 'Initializing ...')
 
@@ -152,16 +232,6 @@ def handler(task):
 
         virtual_memory.try_move_to_virtual_memory(pipeline.xl_refiner.clip.cond_stage_model)
 
-        if performance_selction == 'Speed':
-            steps = 30
-            switch = 20
-        else:
-            steps = 60
-            switch = 40
-
-        pipeline.clear_all_caches()  # save memory
-
-        width, height = aspect_ratios[aspect_ratios_selction]
-
         results = []
         all_steps = steps * image_number
 
@@ -174,35 +244,43 @@ def callback(step, x0, x, total_steps, y):
 
         outputs.append(['preview', (13, 'Starting tasks ...', None)])
         for current_task_id, task in enumerate(tasks):
-            imgs = pipeline.process_diffusion(
-                positive_cond=task['c'],
-                negative_cond=task['uc'],
-                steps=steps,
-                switch=switch,
-                width=width,
-                height=height,
-                image_seed=task['task_seed'],
-                callback=callback)
-
-            for x in imgs:
-                d = [
-                    ('Prompt', raw_prompt),
-                    ('Negative Prompt', raw_negative_prompt),
-                    ('Fooocus V2 Expansion', task['expansion']),
-                    ('Styles', str(raw_style_selections)),
-                    ('Performance', performance_selction),
-                    ('Resolution', str((width, height))),
-                    ('Sharpness', sharpness),
-                    ('Base Model', base_model_name),
-                    ('Refiner Model', refiner_model_name),
-                    ('Seed', task['task_seed'])
-                ]
-                for n, w in loras:
-                    if n != 'None':
-                        d.append((f'LoRA [{n}] weight', w))
-                log(x, d, single_line_number=3)
-
-            results += imgs
+            try:
+                imgs = pipeline.process_diffusion(
+                    positive_cond=task['c'],
+                    negative_cond=task['uc'],
+                    steps=steps,
+                    switch=switch,
+                    width=width,
+                    height=height,
+                    image_seed=task['task_seed'],
+                    callback=callback,
+                    latent=initial_latent,
+                    denoise=denoising_strength,
+                    tiled=tiled
+                )
+
+                for x in imgs:
+                    d = [
+                        ('Prompt', raw_prompt),
+                        ('Negative Prompt', raw_negative_prompt),
+                        ('Fooocus V2 Expansion', task['expansion']),
+                        ('Styles', str(raw_style_selections)),
+                        ('Performance', performance_selction),
+                        ('Resolution', str((width, height))),
+                        ('Sharpness', sharpness),
+                        ('Base Model', base_model_name),
+                        ('Refiner Model', refiner_model_name),
+                        ('Seed', task['task_seed'])
+                    ]
+                    for n, w in loras:
+                        if n != 'None':
+                            d.append((f'LoRA [{n}] weight', w))
+                    log(x, d, single_line_number=3)
+
+                results += imgs
+            except comfy.model_management.InterruptProcessingException as e:
+                print('User stopped')
+                break
 
         outputs.append(['results', results])
         return
diff --git a/modules/core.py b/modules/core.py
index d43454aad..d36e906f7 100644
--- a/modules/core.py
+++ b/modules/core.py
@@ -8,7 +8,7 @@
 import comfy.utils
 
 from comfy.sd import load_checkpoint_guess_config
-from nodes import VAEDecode, EmptyLatentImage
+from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled
 from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
 from comfy.model_base import SDXLRefiner
 from modules.samplers_advanced import KSampler, KSamplerWithRefiner
@@ -18,6 +18,9 @@
 patch_all()
 
 opEmptyLatentImage = EmptyLatentImage()
 opVAEDecode = VAEDecode()
+opVAEEncode = VAEEncode()
+opVAEDecodeTiled = VAEDecodeTiled()
+opVAEEncodeTiled = VAEEncodeTiled()
 
 
 class StableDiffusionModel:
@@ -45,12 +48,14 @@ def to_meta(self):
 
 
 @torch.no_grad()
+@torch.inference_mode()
 def load_model(ckpt_filename):
     unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename)
     return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, model_filename=ckpt_filename)
 
 
 @torch.no_grad()
+@torch.inference_mode()
 def load_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
     if strength_model == 0 and strength_clip == 0:
         return model
@@ -61,40 +66,87 @@ def load_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
 
 
 @torch.no_grad()
+@torch.inference_mode()
 def generate_empty_latent(width=1024, height=1024, batch_size=1):
     return opEmptyLatentImage.generate(width=width, height=height, batch_size=batch_size)[0]
 
 
 @torch.no_grad()
-def decode_vae(vae, latent_image):
-    return opVAEDecode.decode(samples=latent_image, vae=vae)[0]
+@torch.inference_mode()
+def decode_vae(vae, latent_image, tiled=False):
+    return (opVAEDecodeTiled if tiled else opVAEDecode).decode(samples=latent_image, vae=vae)[0]
 
 
-def get_previewer(device, latent_format):
-    from latent_preview import TAESD, TAESDPreviewerImpl
-    taesd_decoder_path = os.path.abspath(os.path.realpath(os.path.join("models", "vae_approx",
-                                                                       latent_format.taesd_decoder_name)))
+@torch.no_grad()
+@torch.inference_mode()
+def encode_vae(vae, pixels, tiled=False):
+    return (opVAEEncodeTiled if tiled else opVAEEncode).encode(pixels=pixels, vae=vae)[0]
+
+
+class VAEApprox(torch.nn.Module):
+    def __init__(self):
+        super(VAEApprox, self).__init__()
+        self.conv1 = torch.nn.Conv2d(4, 8, (7, 7))
+        self.conv2 = torch.nn.Conv2d(8, 16, (5, 5))
+        self.conv3 = torch.nn.Conv2d(16, 32, (3, 3))
+        self.conv4 = torch.nn.Conv2d(32, 64, (3, 3))
+        self.conv5 = torch.nn.Conv2d(64, 32, (3, 3))
+        self.conv6 = torch.nn.Conv2d(32, 16, (3, 3))
+        self.conv7 = torch.nn.Conv2d(16, 8, (3, 3))
+        self.conv8 = torch.nn.Conv2d(8, 3, (3, 3))
+        self.current_type = None
+
+    def forward(self, x):
+        extra = 11
+        x = torch.nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+        x = torch.nn.functional.pad(x, (extra, extra, extra, extra))
+        for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8]:
+            x = layer(x)
+            x = torch.nn.functional.leaky_relu(x, 0.1)
+        return x
+
 
-    if not os.path.exists(taesd_decoder_path):
-        print(f"Warning: TAESD previews enabled, but could not find {taesd_decoder_path}")
-        return None
+VAE_approx_model = None
 
-    taesd = TAESD(None, taesd_decoder_path).to(device)
 
+@torch.no_grad()
+@torch.inference_mode()
+def get_previewer(device, latent_format):
+    global VAE_approx_model
+
+    if VAE_approx_model is None:
+        from modules.path import vae_approx_path
+        vae_approx_filename = os.path.join(vae_approx_path, 'xlvaeapp.pth')
+        sd = torch.load(vae_approx_filename, map_location='cpu')
+        VAE_approx_model = VAEApprox()
+        VAE_approx_model.load_state_dict(sd)
+        del sd
+        VAE_approx_model.eval()
+
+        if comfy.model_management.should_use_fp16():
+            VAE_approx_model.half()
+            VAE_approx_model.current_type = torch.float16
+        else:
+            VAE_approx_model.float()
+            VAE_approx_model.current_type = torch.float32
+
+        VAE_approx_model.to(comfy.model_management.get_torch_device())
+
+    @torch.no_grad()
+    @torch.inference_mode()
     def preview_function(x0, step, total_steps):
-        global cv2_is_top
         with torch.no_grad():
-            x_sample = taesd.decoder(torch.nn.functional.avg_pool2d(x0, kernel_size=(2, 2))).detach() * 255.0
-            x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')
+            x_sample = x0.to(VAE_approx_model.current_type)
+            x_sample = VAE_approx_model(x_sample) * 127.5 + 127.5
+            x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')[0]
             x_sample = x_sample.cpu().numpy().clip(0, 255).astype(np.uint8)
-            return x_sample[0]
-
-    taesd.preview = preview_function
+            return x_sample
 
-    return taesd
+    return preview_function
 
 
 @torch.no_grad()
+@torch.inference_mode()
 def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
              scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
              force_full_denoise=False, callback_function=None):
@@ -124,8 +176,8 @@
     def callback(step, x0, x, total_steps):
         y = None
-        if previewer and step % 3 == 0:
-            y = previewer.preview(x0, step, total_steps)
+        if previewer is not None:
+            y = previewer(x0, step, total_steps)
         if callback_function is not None:
             callback_function(step, x0, x, total_steps, y)
         pbar.update_absolute(step + 1, total_steps, None)
@@ -166,6 +218,7 @@ def callback(step, x0, x, total_steps):
 
 
 @torch.no_grad()
+@torch.inference_mode()
 def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive, refiner_negative, latent,
                           seed=None, steps=30, refiner_switch_step=20, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
                           scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
@@ -196,8 +249,8 @@ def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive,
     def callback(step, x0, x, total_steps):
         y = None
-        if previewer and step % 3 == 0:
-            y = previewer.preview(x0, step, total_steps)
+        if previewer is not None:
+            y = previewer(x0, step, total_steps)
         if callback_function is not None:
             callback_function(step, x0, x, total_steps, y)
         pbar.update_absolute(step + 1, total_steps, None)
@@ -243,5 +296,16 @@ def callback(step, x0, x, total_steps):
 
 
 @torch.no_grad()
-def image_to_numpy(x):
+@torch.inference_mode()
+def pytorch_to_numpy(x):
     return [np.clip(255. * y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x]
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def numpy_to_pytorch(x):
+    y = x.astype(np.float32) / 255.0
+    y = y[None]
+    y = np.ascontiguousarray(y.copy())
+    y = torch.from_numpy(y).float()
+    return y
diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py
index a374c79ad..6ad7c3f11 100644
--- a/modules/default_pipeline.py
+++ b/modules/default_pipeline.py
@@ -3,7 +3,6 @@
 import torch
 import modules.path
 import modules.virtual_memory as virtual_memory
-import comfy.model_management as model_management
 
 from comfy.model_base import SDXL, SDXLRefiner
 from modules.patch import cfg_patched
@@ -20,6 +19,8 @@
 xl_base_patched_hash = ''
 
 
+@torch.no_grad()
+@torch.inference_mode()
 def refresh_base_model(name):
     global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash
@@ -51,6 +52,8 @@ def refresh_base_model(name):
     return
 
 
+@torch.no_grad()
+@torch.inference_mode()
 def refresh_refiner_model(name):
     global xl_refiner, xl_refiner_hash
@@ -86,6 +89,8 @@ def refresh_refiner_model(name):
     return
 
 
+@torch.no_grad()
+@torch.inference_mode()
 def refresh_loras(loras):
     global xl_base, xl_base_patched, xl_base_patched_hash
     if xl_base_patched_hash == str(loras):
@@ -106,6 +111,7 @@
 
 
 @torch.no_grad()
+@torch.inference_mode()
 def clip_encode_single(clip, text, verbose=False):
     cached = clip.fcs_cond_cache.get(text, None)
     if cached is not None:
@@ -121,6 +127,7 @@
 
 
 @torch.no_grad()
+@torch.inference_mode()
 def clip_encode(sd, texts, pool_top_k=1):
     if sd is None:
         return None
@@ -145,6 +152,7 @@
 
 
 @torch.no_grad()
+@torch.inference_mode()
 def clear_sd_cond_cache(sd):
     if sd is None:
         return None
@@ -155,11 +163,14 @@
 
 
 @torch.no_grad()
+@torch.inference_mode()
 def clear_all_caches():
     clear_sd_cond_cache(xl_base_patched)
     clear_sd_cond_cache(xl_refiner)
 
 
+@torch.no_grad()
+@torch.inference_mode()
 def refresh_everything(refiner_model_name, base_model_name, loras):
     refresh_refiner_model(refiner_model_name)
     if xl_refiner is not None:
@@ -184,6 +195,7 @@
 
 
 @torch.no_grad()
+@torch.inference_mode()
 def patch_all_models():
     assert xl_base is not None
     assert xl_base_patched is not None
@@ -198,14 +210,18 @@
 
 
 @torch.no_grad()
-def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback):
+@torch.inference_mode()
+def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback, latent=None, denoise=1.0, tiled=False):
     patch_all_models()
 
     if xl_refiner is not None:
         virtual_memory.try_move_to_virtual_memory(xl_refiner.unet.model)
     virtual_memory.load_from_virtual_memory(xl_base.unet.model)
 
-    empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
+    if latent is None:
+        empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)
+    else:
+        empty_latent = latent
 
     if xl_refiner is not None:
         sampled_latent = core.ksampler_with_refiner(
@@ -219,6 +235,7 @@
             latent=empty_latent,
             steps=steps, start_step=0, last_step=steps,
             disable_noise=False, force_full_denoise=True, seed=image_seed,
+            denoise=denoise,
             callback_function=callback
         )
     else:
@@ -229,9 +246,10 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
             latent=empty_latent,
             steps=steps, start_step=0, last_step=steps,
             disable_noise=False, force_full_denoise=True, seed=image_seed,
+            denoise=denoise,
             callback_function=callback
         )
 
-    decoded_latent = core.decode_vae(vae=xl_base_patched.vae, latent_image=sampled_latent)
-    images = core.image_to_numpy(decoded_latent)
+    decoded_latent = core.decode_vae(vae=xl_base_patched.vae, latent_image=sampled_latent, tiled=tiled)
+    images = core.pytorch_to_numpy(decoded_latent)
 
     return images
diff --git a/modules/flags.py b/modules/flags.py
new file mode 100644
index 000000000..3b8ca0255
--- /dev/null
+++ b/modules/flags.py
@@ -0,0 +1,10 @@
+disabled = 'Disabled'
+subtle_variation = 'Vary (Subtle)'
+strong_variation = 'Vary (Strong)'
+upscale_15 = 'Upscale (1.5x)'
+upscale_2 = 'Upscale (2x)'
+upscale_fast = 'Upscale (Fast 2x)'
+
+uov_list = [
+    disabled, subtle_variation, strong_variation, upscale_15, upscale_2, upscale_fast
+]
diff --git a/modules/html.py b/modules/html.py
index 0794df21d..b6c8e3ea2 100644
--- a/modules/html.py
+++ b/modules/html.py
@@ -83,6 +83,14 @@
     box-shadow: none !important;
 }
 
+.advanced_check_row{
+    width: 250px !important;
+}
+
+.min_check{
+    min-width: min(1px, 100%) !important;
+}
+
 '''
 progress_html = '''
diff --git a/modules/patch.py b/modules/patch.py
index f9539219b..66b02553c 100644
--- a/modules/patch.py
+++ b/modules/patch.py
@@ -70,6 +70,35 @@ def sdxl_encode_adm_patched(self, **kwargs):
     return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
 
 
+def sdxl_refiner_encode_adm_patched(self, **kwargs):
+    clip_pooled = kwargs["pooled_output"]
+    width = kwargs.get("width", 768)
+    height = kwargs.get("height", 768)
+    crop_w = kwargs.get("crop_w", 0)
+    crop_h = kwargs.get("crop_h", 0)
+
+    if kwargs.get("prompt_type", "") == "negative":
+        aesthetic_score = kwargs.get("aesthetic_score", 2.5)
+    else:
+        aesthetic_score = kwargs.get("aesthetic_score", 7.0)
+
+    if kwargs.get("prompt_type", "") == "negative":
+        width *= 0.8
+        height *= 0.8
+    elif kwargs.get("prompt_type", "") == "positive":
+        width *= 1.5
+        height *= 1.5
+
+    out = []
+    out.append(self.embedder(torch.Tensor([height])))
+    out.append(self.embedder(torch.Tensor([width])))
+    out.append(self.embedder(torch.Tensor([crop_h])))
+    out.append(self.embedder(torch.Tensor([crop_w])))
+    out.append(self.embedder(torch.Tensor([aesthetic_score])))
+    flat = torch.flatten(torch.cat(out))[None,]
+    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
+
+
 def text_encoder_device_patched():
     # Fooocus's style system uses text encoder much more times than comfy so this makes things much faster.
     return comfy.model_management.get_torch_device()
@@ -83,3 +112,4 @@ def patch_all():
 
     comfy.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward
     comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
+    # comfy.model_base.SDXLRefiner.encode_adm = sdxl_refiner_encode_adm_patched
diff --git a/modules/path.py b/modules/path.py
index 283078824..b2af0896c 100644
--- a/modules/path.py
+++ b/modules/path.py
@@ -3,6 +3,7 @@
 modelfile_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/checkpoints/'))
 lorafile_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/loras/'))
 vae_approx_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/vae_approx/'))
+upscale_models_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../models/upscale_models/'))
 temp_outputs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../outputs/'))
 fooocus_expansion_path = os.path.abspath(os.path.join(os.path.dirname(__file__),
diff --git a/modules/upscaler.py b/modules/upscaler.py
new file mode 100644
index 000000000..0d12d86df
--- /dev/null
+++ b/modules/upscaler.py
@@ -0,0 +1,25 @@
+import os
+import torch
+
+from comfy_extras.chainner_models.architecture.RRDB import RRDBNet as ESRGAN
+from comfy_extras.nodes_upscale_model import ImageUpscaleWithModel
+from collections import OrderedDict
+from modules.path import upscale_models_path
+
+model_filename = os.path.join(upscale_models_path, 'fooocus_upscaler_s409985e5.bin')
+opImageUpscaleWithModel = ImageUpscaleWithModel()
+model = None
+
+
+def perform_upscale(img):
+    global model
+    if model is None:
+        sd = torch.load(model_filename)
+        sdo = OrderedDict()
+        for k, v in sd.items():
+            sdo[k.replace('residual_block_', 'RDB')] = v
+        del sd
+        model = ESRGAN(sdo)
+        model.cpu()
+        model.eval()
+    return opImageUpscaleWithModel.upscale(model, img)[0]
diff --git a/modules/util.py b/modules/util.py
index a173cf21d..a98616203 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -1,7 +1,90 @@
+import numpy as np
 import datetime
 import random
 import os
 
+from PIL import Image
+
+
+LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+
+
+def resize_image(im, width, height, resize_mode=1):
+    """
+    Resizes an image with the specified resize_mode, width, and height.
+
+    Args:
+        resize_mode: The mode to use when resizing the image.
+            0: Resize the image to the specified width and height.
+            1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+            2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+        im: The image to resize.
+        width: The width to resize the image to.
+        height: The height to resize the image to.
+    """
+
+    im = Image.fromarray(im)
+
+    def resize(im, w, h):
+        return im.resize((w, h), resample=LANCZOS)
+
+    if resize_mode == 0:
+        res = resize(im, width, height)
+
+    elif resize_mode == 1:
+        ratio = width / height
+        src_ratio = im.width / im.height
+
+        src_w = width if ratio > src_ratio else im.width * height // im.height
+        src_h = height if ratio <= src_ratio else im.height * width // im.width
+
+        resized = resize(im, src_w, src_h)
+        res = Image.new("RGB", (width, height))
+        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+    else:
+        ratio = width / height
+        src_ratio = im.width / im.height
+
+        src_w = width if ratio < src_ratio else im.width * height // im.height
+        src_h = height if ratio >= src_ratio else im.height * width // im.width
+
+        resized = resize(im, src_w, src_h)
+        res = Image.new("RGB", (width, height))
+        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+        if ratio < src_ratio:
+            fill_height = height // 2 - src_h // 2
+            if fill_height > 0:
+                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+                res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+        elif ratio > src_ratio:
+            fill_width = width // 2 - src_w // 2
+            if fill_width > 0:
+                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+                res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+
+    return np.array(res)
+
+
+def HWC3(x):
+    assert x.dtype == np.uint8
+    if x.ndim == 2:
+        x = x[:, :, None]
+    assert x.ndim == 3
+    H, W, C = x.shape
+    assert C == 1 or C == 3 or C == 4
+    if C == 3:
+        return x
+    if C == 1:
+        return np.concatenate([x, x, x], axis=2)
+    if C == 4:
+        color = x[:, :, 0:3].astype(np.float32)
+        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+        y = color * alpha + 255.0 * (1.0 - alpha)
+        y = y.clip(0, 255).astype(np.uint8)
+        return y
+
 
 def remove_empty_str(items, default=None):
     items = [x for x in items if x != ""]
diff --git a/shared.py b/shared.py
index d61d61fc9..f940feee3 100644
--- a/shared.py
+++ b/shared.py
@@ -1,2 +1 @@
 gradio_root = None
-
diff --git a/webui.py b/webui.py
index a4eb0424a..8eec8181e 100644
--- a/webui.py
+++ b/webui.py
@@ -7,13 +7,14 @@
 import fooocus_version
 import modules.html
 import modules.async_worker as worker
+import modules.flags as flags
+import comfy.model_management as model_management
 
 from modules.sdxl_styles import style_keys, aspect_ratios, fooocus_expansion, default_styles
 
 
 def generate_clicked(*args):
-    yield gr.update(interactive=False), \
-        gr.update(visible=True, value=modules.html.make_progress_html(1, 'Initializing ...')), \
+    yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Initializing ...')), \
         gr.update(visible=True, value=None), \
         gr.update(visible=False)
@@ -26,13 +27,11 @@
             flag, product = worker.outputs.pop(0)
             if flag == 'preview':
                 percentage, title, image = product
-                yield gr.update(interactive=False), \
-                    gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \
+                yield gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \
                     gr.update(visible=True, value=image) if image is not None else gr.update(), \
                     gr.update(visible=False)
             if flag == 'results':
-                yield gr.update(interactive=True), \
-                    gr.update(visible=False), \
+                yield gr.update(visible=False), \
                     gr.update(visible=False), \
                     gr.update(visible=True, value=product)
                 finished = True
@@ -50,9 +49,28 @@
         with gr.Column(scale=0.85):
             prompt = gr.Textbox(show_label=False, placeholder="Type prompt here.", container=False, autofocus=True, elem_classes='type_row', lines=1024)
         with gr.Column(scale=0.15, min_width=0):
-            run_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row')
-    with gr.Row():
-        advanced_checkbox = gr.Checkbox(label='Advanced', value=False, container=False)
+            run_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row', visible=True)
+            stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row', visible=False)
+
+            def stop_clicked():
+                model_management.interrupt_current_processing()
+                return gr.update(interactive=False)
+
+            stop_button.click(stop_clicked, outputs=stop_button, queue=False)
+    with gr.Row(elem_classes='advanced_check_row'):
+        input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
+        advanced_checkbox = gr.Checkbox(label='Advanced', value=False, container=False, elem_classes='min_check')
+    with gr.Row(visible=False) as image_input_panel:
+        with gr.Column(scale=0.5):
+            with gr.Accordion(label='Upscale or Variation', open=True):
+                uov_input_image = gr.Image(label='Drag above image to here', source='upload', type='numpy')
+                uov_method = gr.Radio(label='Method', choices=flags.uov_list, value=flags.disabled, show_label=False, container=False)
+            gr.HTML('\U0001F4D4 Document')
+    input_image_checkbox.change(lambda x: gr.update(visible=x), inputs=input_image_checkbox, outputs=image_input_panel, queue=False)
+
+    # def get_select_index(g, evt: gr.SelectData):
+    #     return g[evt.index]['name']
+    # gallery.select(get_select_index, gallery, uov_input_image)
     with gr.Column(scale=0.5, visible=False) as right_col:
         with gr.Tab(label='Setting'):
             performance_selction = gr.Radio(label='Performance', choices=['Speed', 'Quality'], value='Speed')
@@ -73,7 +91,7 @@ def refresh_seed(r, s):
                 else:
                     return s
 
-            seed_random.change(random_checked, inputs=[seed_random], outputs=[image_seed])
+            seed_random.change(random_checked, inputs=[seed_random], outputs=[image_seed], queue=False)
 
         with gr.Tab(label='Style'):
             style_selections = gr.CheckboxGroup(show_label=False, container=False,
@@ -105,16 +123,21 @@ def model_refresh_clicked():
                 results += [gr.update(choices=['None'] + modules.path.lora_filenames), gr.update()]
                 return results
 
-            model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls)
+            model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls, queue=False)
 
-    advanced_checkbox.change(lambda x: gr.update(visible=x), advanced_checkbox, right_col)
+    advanced_checkbox.change(lambda x: gr.update(visible=x), advanced_checkbox, right_col, queue=False)
 
     ctrls = [
         prompt, negative_prompt, style_selections,
         performance_selction, aspect_ratios_selction, image_number, image_seed, sharpness
     ]
     ctrls += [base_model, refiner_model] + lora_ctrls
-    run_button.click(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed)\
-        .then(fn=generate_clicked, inputs=ctrls, outputs=[run_button, progress_html, progress_window, gallery])
+    ctrls += [input_image_checkbox]
+    ctrls += [uov_method, uov_input_image]
+
+    run_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, run_button, gallery])\
+        .then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed)\
+        .then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, gallery])\
+        .then(lambda: (gr.update(visible=True), gr.update(visible=False)), outputs=[run_button, stop_button])
 
 parser = argparse.ArgumentParser()