diff --git a/README.md b/README.md index 4b3e882..f63741a 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,8 @@ ```bash conda env create -f environment.yml conda activate freecontrol +pip install -U diffusers +pip install -U gradio ``` **Sample Semantic Bases** diff --git a/config/base.yaml b/config/base.yaml index f5ebc26..868cc12 100755 --- a/config/base.yaml +++ b/config/base.yaml @@ -5,70 +5,69 @@ sd_config: steps: 200 # Diffusion scale guidance_scale: 7.5 # Classifier-free guidance scale grad_guidance_scale: 1 # Gradient guidance scale, it will be multiplied with the weight in each type of guidance - sd_version: '2.1' # choice from ['2,1'|'2.0'|'1.4'] + sd_version: "1.5" # choice from ["1.5", "2.1_base"] dreambooth: null # Path to dreambooth. Set to null to disable dreambooth - safetensors: True # whether to use safetenosr. For most ckpt from civitai.com, they used safetensor format - same_latent: False # whether to use the same latent from inversion + safetensors: True # whether to use safetenosr. For most ckpt from civitai.com, they used safetensor format + same_latent: False # whether to use the same latent from inversion appearnace_same_latent: False pca_paths: [] seed: 922 # Seed for random number generator - generated_sample: False # Whether to use generated sample as the reference pose - prompt : "" # Prompt for generated sample + generated_sample: False # Whether to use generated sample as the reference pose + prompt: "" # Prompt for generated sample negative_prompt: "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality" # Prompt for negative sample - target_obj: '' # Target object for inversion + target_obj: "" # Target object for inversion obj_pairs: "" data: inversion: - target_folder: 'dataset/latent' + target_folder: "dataset/latent" num_inference_steps: 999 - method: 'DDIM' # choice from ['DDIM'|'NTI'|'NPT'] - fixed_size: [512,512] # Set to null to disable fixed size, otherwise set to the fixed size (h,w) of the target image + method: "DDIM" # choice from ['DDIM'|'NTI'|'NPT'] + fixed_size: [512, 512] # Set to null to disable fixed size, otherwise set to the fixed size (h,w) of the target image prompt: "A photo of a person, holding his hands, sketch" select_objects: "person" policy: "share" # choice from ['share'|'separate'] - sd_model: '' + sd_model: "" # Guidance config guidance: pca_guidance: - start_step: 0 # Start step for PCA self-attention injection - end_step: 80 # End step for PCA injection - weight: 800 # Weight for PCA self-attention injection + start_step: 0 # Start step for PCA self-attention injection + end_step: 80 # End step for PCA injection + weight: 800 # Weight for PCA self-attention injection select_feature: "key" -# weight1: 0 - structure_guidance: # Parameters for PCA injection - apply: True # Whether apply PCA injection - n_components: 64 # Number of leading components for PCA injection - normalized: True # Whether normalize the PCA score - mask_type: 'cross_attn' # Mask type for PCA injection, choice from ['cross_attn'|'tr'] - penalty_type: 'max' # Penalty type for PCA injection, choice from ['max'|'mean'] - mask_tr: 0.3 # Threshold for PCA score, only applied when normalized is true - penalty_factor: 10 # Penalty factor for masked region, only applied when normalized is true - warm_up: # Parameters for Guidance weight warm up - apply: True # Whether apply warm up - end_step: 10 # End step for warm up - adaptive: # Parameters for adaptive self-attention injection - apply: 
False # Whether apply adaptive self-attention injection - adaptive_p: 1 # power of the adaptive threshold - blocks: # Blocks for self-attention injection - [ 'up_blocks.1'] - appearance_guidance: # Parameters for texture regulation - apply: True # Whether apply texture regulation - reg_factor: 0.05 # Regularization factor + structure_guidance: # Parameters for PCA injection + apply: True # Whether apply PCA injection + n_components: 64 # Number of leading components for PCA injection + normalized: True # Whether normalize the PCA score + mask_type: "cross_attn" # Mask type for PCA injection, choice from ["cross_attn"|"tr"] + penalty_type: "max" # Penalty type for PCA injection, choice from ["max"|"mean"] + mask_tr: 0.3 # Threshold for PCA score, only applied when normalized is true + penalty_factor: 10 # Penalty factor for masked region, only applied when normalized is true + warm_up: # Parameters for Guidance weight warm up + apply: True # Whether apply warm up + end_step: 10 # End step for warm up + adaptive: # Parameters for adaptive self-attention injection + apply: False # Whether apply adaptive self-attention injection + adaptive_p: 1 # power of the adaptive threshold + blocks: # Blocks for self-attention injection + ["up_blocks.1"] + appearance_guidance: # Parameters for texture regulation + apply: True # Whether apply texture regulation + reg_factor: 0.05 # Regularization factor tr: 0.5 cross_attn_mask_tr: 0.3 - app_n_components: 2 # Number of leading components used to extract mean appearance feature + app_n_components: 2 # Number of leading components used to extract mean appearance feature cross_attn: - start_step: 0 # Start step for cross-attention guidance - end_step: 80 # End step for cross-attention guidance - weight: 0 # Weight for cross-attention guidance - obj_only: True # Whether apply object only cross-attention guidance - soft_guidance: # Parameters for soft guidance - apply: True # Whether apply soft guidance - kernel_size: 5 # Kernel size for Gaussian blur - sigma: 2 # Sigma for Gaussian blur + start_step: 0 # Start step for cross-attention guidance + end_step: 80 # End step for cross-attention guidance + weight: 0 # Weight for cross-attention guidance + obj_only: True # Whether apply object only cross-attention guidance + soft_guidance: # Parameters for soft guidance + apply: True # Whether apply soft guidance + kernel_size: 5 # Kernel size for Gaussian blur + sigma: 2 # Sigma for Gaussian blur blocks: - [ 'up_blocks.1'] + ["up_blocks.1"] diff --git a/config/sdxl_base.yaml b/config/sdxl_base.yaml new file mode 100755 index 0000000..9761cf0 --- /dev/null +++ b/config/sdxl_base.yaml @@ -0,0 +1,73 @@ +# SD config +sd_config: + H: 1024 # Generated image weight + W: 1024 # Generated image height + steps: 200 # Diffusion scale + guidance_scale: 7.5 # Classifier-free guidance scale + grad_guidance_scale: 1 # Gradient guidance scale, it will be multiplied with the weight in each type of guidance + sd_version: "XL-1.0" # choice from ['sdxl-1.0','2,1'|'2.0'|'1.4'] + dreambooth: null # Path to dreambooth. Set to null to disable dreambooth + safetensors: True # whether to use safetenosr. 
For most ckpt from civitai.com, they used safetensor format + same_latent: False # whether to use the same latent from inversion + appearnace_same_latent: False + pca_paths: + [] + seed: 922 # Seed for random number generator + generated_sample: False # Whether to use generated sample as the reference pose + prompt : "" # Prompt for generated sample + negative_prompt: "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality" # Prompt for negative sample + target_obj: "" # Target object for inversion + obj_pairs: "" + +data: + inversion: + target_folder: "dataset/latent" + num_inference_steps: 999 + method: "DDIM" # choice from ['DDIM'|'NTI'|'NPT'] + fixed_size: [1024,1024] # Set to null to disable fixed size, otherwise set to the fixed size (h,w) of the target image + prompt: "A photo of a person, holding his hands, sketch" + select_objects: "person" + policy: "share" # choice from ['share'|'separate'] + sd_model: "" + +# Guidance config +guidance: + pca_guidance: + start_step: 0 # Start step for PCA self-attention injection + end_step: 80 # End step for PCA injection + weight: 800 # Weight for PCA self-attention injection + select_feature: "key" + structure_guidance: # Parameters for PCA injection + apply: True # Whether apply PCA injection + n_components: 64 # Number of leading components for PCA injection + normalized: True # Whether normalize the PCA score + mask_type: "cross_attn" # Mask type for PCA injection, choice from ['cross_attn'|'tr'] + penalty_type: "max" # Penalty type for PCA injection, choice from ['max'|'mean'] + mask_tr: 0.3 # Threshold for PCA score, only applied when normalized is true + penalty_factor: 10 # Penalty factor for masked region, only applied when normalized is true + warm_up: # Parameters for Guidance weight warm up + apply: True # Whether apply warm up + end_step: 10 # End step for warm up + adaptive: # Parameters for adaptive self-attention injection + apply: False # Whether apply adaptive self-attention injection + adaptive_p: 1 # power of the adaptive threshold + blocks: # Blocks for self-attention injection + ["up_blocks.0"] + appearance_guidance: # Parameters for texture regulation + apply: True # Whether apply texture regulation + reg_factor: 0.05 # Regularization factor + tr: 0.5 + cross_attn_mask_tr: 0.3 + app_n_components: 2 # Number of leading components used to extract mean appearance feature + + cross_attn: + start_step: 0 # Start step for cross-attention guidance + end_step: 80 # End step for cross-attention guidance + weight: 0 # Weight for cross-attention guidance + obj_only: True # Whether apply object only cross-attention guidance + soft_guidance: # Parameters for soft guidance + apply: True # Whether apply soft guidance + kernel_size: 5 # Kernel size for Gaussian blur + sigma: 2 # Sigma for Gaussian blur + blocks: + ["up_blocks.0"] diff --git a/gradio_app.py b/gradio_app.py index 37e5a83..8cc5cb4 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -69,13 +69,20 @@ def freecontrol_generate(condition_image, prompt, scale, ddim_steps, sd_version, input_config = gradio_update_parameter # Load base config - base_config = yaml.load(open("config/base.yaml", "r"), Loader=yaml.FullLoader) + if 'XL' in sd_version: + config_path = 'config/sdxl_base.yaml' + else: + config_path = 'config/base.yaml' + base_config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader) # Update the Default config by gradio config config = merge_sweep_config(base_config=base_config, update=input_config) 
config = OmegaConf.create(config) # set the correct pipeline - pipeline_name = "SDPipeline" + if 'XL' in sd_version: + pipeline_name = "SDXLPipeline" + else: + pipeline_name = "SDPipeline" pipeline = make_pipeline(pipeline_name, model_path, @@ -277,7 +284,7 @@ def main(): run_button.click(fn=freecontrol_generate, inputs=ips, outputs=[result_gallery]) - block.launch(server_name='0.0.0.0', share=False, server_port=9989) + block.launch(server_name='0.0.0.0', share=True, server_port=9989) if __name__ == '__main__': diff --git a/libs/model/__init__.py b/libs/model/__init__.py index 38a0544..9103940 100755 --- a/libs/model/__init__.py +++ b/libs/model/__init__.py @@ -1,2 +1,3 @@ from .pipelines import make_pipeline -from . import sd_pipeline \ No newline at end of file +from . import sd_pipeline +from . import sdxl_pipeline \ No newline at end of file diff --git a/libs/model/module/xformer_attention.py b/libs/model/module/xformer_attention.py index 7848368..38d0158 100755 --- a/libs/model/module/xformer_attention.py +++ b/libs/model/module/xformer_attention.py @@ -54,15 +54,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, scale=scale) - value = attn.to_v(encoder_hidden_states, scale=scale) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) # Record the Q,K,V for PCA guidance self.key = key diff --git a/libs/model/sd_pipeline.py b/libs/model/sd_pipeline.py index 2728051..d464bd0 100755 --- a/libs/model/sd_pipeline.py +++ b/libs/model/sd_pipeline.py @@ -304,8 +304,8 @@ def __call__( self.compute_cross_attn_mask(cond_control_ids, cond_example_ids, cond_appearance_ids) if _in_step(self.guidance_config.pca_guidance, i): - # Compute the PCA structure and appearance guidance - # Set the select feature to key by default + # Compute the PCA structure and appearance guidance + # Set the select feature to key by default try: select_feature = self.guidance_config.pca_guidance.select_feature except: diff --git a/libs/model/sdxl_pipeline.py b/libs/model/sdxl_pipeline.py index 5bc43a6..8cb2b51 100644 --- a/libs/model/sdxl_pipeline.py +++ b/libs/model/sdxl_pipeline.py @@ -11,13 +11,11 @@ from diffusers.utils import BaseOutput from numpy import deprecate -from libs.dataset.data_utils import * -from libs.utils.utils import compute_token_merge_indices, extract_data -from .module import prep_conv_layer, prep_unet, get_hidden_state, get_selt_attn_feat_info -from .pipeline_utils import prepare_unet, _in_step, _classify_blocks +from libs.utils.utils import * +from .module import prep_unet_conv, prep_unet_attention, get_self_attn_feat +from .pipeline_utils import _in_step, _classify_blocks from .pipelines import * - # Take from huggingface/diffusers class StableDiffusionPipelineOutput(BaseOutput): """ @@ -90,14 +88,13 @@ def __call__( # Pose2Pose parameters config: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None, - data_samples=None, + inverted_data=None, ): assert config is not None, "config is required for Pose2Pose pipeline" self.input_config = config - self.unet = prep_unet(self.unet) - self.unet = prep_conv_layer(self.unet) - + 
self.unet = prep_unet_attention(self.unet) + self.unet = prep_unet_conv(self.unet) self.load_pca_info() self.running_device = 'cuda' self.ref_mask_record = None @@ -217,6 +214,8 @@ def __call__( ) # Copy the latent for the appearance sample + if same_latent: + keep_latents = latents latents = torch.cat([latents] * 2, dim=0) ''' @@ -236,8 +235,9 @@ def __call__( [uncond-example, uncond-control, uncond-appearance, cond-example, cond-control, cond-appearance] ''' - num_example_sample: int = len(data_samples['examplar']) - num_appearance_sample: int = len(data_samples['appearance_input']) if data_samples['appearance_input'] is not None else 0 + num_example_sample: int = len(inverted_data['condition_input']) + num_appearance_sample: int = len(inverted_data['appearance_input']) if 'appearance_input' in inverted_data.keys() and inverted_data[ + 'appearance_input'] is not None else 0 num_control_samples: int = batch_size * num_images_per_prompt if num_appearance_sample == 0: num_appearance_sample = num_control_samples @@ -257,25 +257,24 @@ def __call__( example_ids = uncond_example_ids + cond_example_ids keep_ids: List[int] = [ids for ids in np.arange(total_samples).tolist() if ids not in example_ids] - # print("Num example sample", num_example_sample) - # print("Num appearance sample", num_appearance_sample) - # print("Num Control sample",num_control_samples) - # print("Example ids",example_ids) - # print("Keep ids",keep_ids) - # print("Control ids",cond_control_ids) - # print("Appearance ids",cond_appearance_ids) - # print("Total samples",total_samples) - # print(latents.shape) - # exit() - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # Prepare guidance configs + self.guidance_config = config.guidance # 7. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( @@ -283,6 +282,7 @@ def __call__( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids @@ -294,9 +294,7 @@ def __call__( prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - #print("Add time ids shape: ", add_time_ids.shape) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) - #print("Reapted add time ids shape: ", add_time_ids.shape) # 8. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -312,45 +310,34 @@ def __call__( num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): score = None - - assert do_classifier_free_guidance, "Currently only support classifier free guidance" # Process the latent step_timestep: int = t.detach().cpu().item() - # print(data_samples['examplar'][0]['all_latents'].keys()) - assert step_timestep in data_samples['examplar'][0][ + assert step_timestep in inverted_data['condition_input'][0][ 'all_latents'].keys(), f"timestep {step_timestep} not in inverse samples keys" - data_samples_latent: torch.Tensor = data_samples['examplar'][0]['all_latents'][step_timestep] + data_samples_latent: torch.Tensor = inverted_data['condition_input'][0]['all_latents'][step_timestep] data_samples_latent = data_samples_latent.to(device=self.running_device, dtype=prompt_embeds.dtype) if i == 0 and same_latent: latents = data_samples_latent.repeat(2, 1, 1, 1) - record_latent = latents.chunk(2)[0] - free_latent = latents.chunk(2)[1] - free_latent.requires_grad_(False) - record_latent.requires_grad_(True) - latents = torch.cat([record_latent, free_latent], dim=0) - copy_record_latent = record_latent.detach() - copy_record_latent.requires_grad_(False) - #print(latents.shape,prompt_embeds.shape) - - - if config.data.inversion.method == 'DDIM': - - latent_list: List[torch.Tensor] = [copy_record_latent,free_latent, data_samples_latent, record_latent,free_latent] + if i == 0 and same_latent and config.sd_config.appearnace_same_latent: + latents = data_samples_latent.repeat(2, 1, 1, 1) + elif i == 0 and same_latent and not config.sd_config.appearnace_same_latent: + latents = torch.cat([data_samples_latent, keep_latents], dim=0) + print("Latents shape", latents.shape) + latent_list: List[torch.Tensor] = [latents, data_samples_latent, latents] else: - latent_list: List[torch.Tensor] = [data_samples_latent, latents, data_samples_latent, latents] + raise NotImplementedError("Currently only support DDIM method") # check if appearance_input is in inverted_data - if 'appearance_input' in data_samples.keys() and data_samples['appearance_input'] is not None: - appearance_input = data_samples['appearance_input'][0]['all_latents'][step_timestep].to(device=self.running_device, dtype=prompt_embeds.dtype) + if 'appearance_input' in inverted_data.keys() and inverted_data['appearance_input'] is not None: + appearance_input = inverted_data['appearance_input'][0]['all_latents'][step_timestep].to( + device=self.running_device, dtype=prompt_embeds.dtype) # replace the second batch of the last latent in the latent list latent_list[1] = appearance_input latent_list[-1] = appearance_input @@ -358,22 +345,17 @@ def __call__( latent_model_input: torch.Tensor = torch.cat(latent_list, dim=0).to('cuda') latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - - # process the prompt embedding if config.data.inversion.method == 'DDIM': - - - ref_prompt_embeds = data_samples['examplar'][0]['prompt_embeds'].to('cuda') - ref_added_text_embeds = data_samples['examplar'][0]['add_text_embeds'].to('cuda') - ref_add_time_ids = data_samples['examplar'][0]['add_time_id'].to('cuda') + ref_prompt_embeds = inverted_data['condition_input'][0]['prompt_embeds'].to('cuda') + ref_added_text_embeds = 
inverted_data['condition_input'][0]['add_text_embeds'].to('cuda') + ref_add_time_ids = inverted_data['condition_input'][0]['add_time_id'].to('cuda') step_prompt_embeds_list: List[torch.Tensor] = [prompt_embeds.chunk(2)[0]] * 2 + [ ref_prompt_embeds] + [prompt_embeds.chunk(2)[1]] * 2 step_add_text_embeds_list: List[torch.Tensor] = [add_text_embeds.chunk(2)[0]] * 2 + [ add_text_embeds.chunk(2)[1]] + [add_text_embeds.chunk(2)[1]] * 2 - step_add_time_ids_list : List[torch.Tensor] = [add_time_ids.chunk(2)[0]] * 2 + [ + step_add_time_ids_list: List[torch.Tensor] = [add_time_ids.chunk(2)[0]] * 2 + [ add_time_ids.chunk(2)[1]] + [add_time_ids.chunk(2)[1]] * 2 @@ -381,25 +363,19 @@ def __call__( raise NotImplementedError("Currently only support DDIM method") step_prompt_embeds = torch.cat(step_prompt_embeds_list, dim=0).to('cuda') - step_add_text_embeds = torch.cat(step_add_text_embeds_list, dim=0).to('cuda') + step_add_text_embeds = torch.cat(step_add_text_embeds_list, dim=0).to('cuda') step_add_time_ids = torch.cat(step_add_time_ids_list, dim=0).to('cuda') # predict the noise residual added_cond_kwargs = {"text_embeds": step_add_text_embeds, "time_ids": step_add_time_ids} - # print(add_text_embeds.shape,add_text_embeds.shape) - # print(latent_model_input.shape,step_prompt_embeds.shape) - # exit() - require_grad_flag = False # Check if the current step is in the guidance step if _in_step(self.guidance_config.pca_guidance, i) or _in_step(self.guidance_config.cross_attn, i): require_grad_flag = True - # Only require grad when need to compute the gradient for guidance if require_grad_flag: latent_model_input.requires_grad_(True) - latent_model_input[cond_appearance_ids].requires_grad_(True) # predict the noise residual noise_pred = self.unet( latent_model_input, @@ -425,21 +401,29 @@ def __call__( self.cross_seg = None if _in_step(self.guidance_config.cross_attn, i): # Compute the Cross-Attention loss - cross_attn_loss = self.compute_cross_attn_loss(cond_control_ids, cond_example_ids, - cond_appearance_ids, i) - #loss += cross_attn_loss + self.compute_cross_attn_mask(cond_control_ids, cond_example_ids, cond_appearance_ids) if _in_step(self.guidance_config.pca_guidance, i): - # Compute the PCA Semantic loss - pca_loss = self.compute_pca_loss(cond_control_ids, cond_example_ids, cond_appearance_ids, i) - # pca_loss = self.compute_conv_loss(cond_control_ids,cond_example_ids, cond_appearance_ids,i) - loss += pca_loss + # Compute the PCA structure and appearance guidance + # Set the select feature to key by default + try: + select_feature = self.guidance_config.pca_guidance.select_feature + except: + select_feature = "key" + + if select_feature == 'query' or select_feature == 'key' or select_feature == 'value': + pca_loss = self.compute_attn_pca_loss(cond_control_ids, cond_example_ids, cond_appearance_ids, + i) + loss += pca_loss + elif select_feature == 'conv': + pca_loss = self.compute_conv_pca_loss(cond_control_ids, cond_example_ids, cond_appearance_ids, + i) + loss += pca_loss temp_control_ids = None if isinstance(loss, torch.Tensor): - gradient = torch.autograd.grad(loss, record_latent, allow_unused=True)[0] - #gradient = torch.autograd.grad(loss, latent_model_input, allow_unused=True)[0] - #gradient = gradient[cond_control_ids] + gradient = torch.autograd.grad(loss, latent_model_input, allow_unused=True)[0] + gradient = gradient[cond_control_ids] assert gradient is not None, f"Step {i}: grad is None" score = gradient.detach() temp_control_ids: List[int] = np.arange(num_control_samples).tolist() @@ 
-476,7 +460,7 @@ def __call__( latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) with torch.no_grad(): image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + # image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) @@ -491,8 +475,7 @@ def __call__( image = self.image_processor.postprocess(image, output_type=output_type) # # Offload all models - # self.maybe_free_model_hooks() - + self.maybe_free_model_hooks() if not return_dict: return (image,) @@ -503,40 +486,6 @@ def load_pca_info(self): path = self.input_config.sd_config.pca_paths[0] self.loaded_pca_info = torch.load(path) - def tensor_to_image(self, tensor): - """将tensor的前两个维度转换为指定形状的图像""" - h, w = int((tensor.size(1)) ** 0.5), int((tensor.size(1)) ** 0.5) # 计算要reshape的目标尺寸 - return tensor.view(tensor.size(0), h, w, -1) - - def visualize_and_save(self, feat_tensor, mask, output_address, step_i): - # 确保输出目录存在 - if not os.path.exists(output_address): - os.makedirs(output_address) - - feat_images = self.tensor_to_image(feat_tensor) - mask_images = self.tensor_to_image(mask) - - bs = feat_tensor.size(0) # 获取批次大小 - - for i in range(feat_tensor.shape[-1]): - imgs = [feat_images[:, :, :, i][j] * 255 for j in range(bs)] - masks = [mask_images[:, :, :, i][j] for j in range(bs)] - apply_masks = [imgs[j] * masks[j] for j in range(bs)] - - grid = torch.stack(imgs + masks + apply_masks, dim=0) - - # 打印所有的形状信息 - for j in range(bs): - print(f"Image {j + 1} shape:", imgs[j].shape) - print(f"Mask {j + 1} shape:", masks[j].shape) - print(f"Applied Mask {j + 1} shape:", apply_masks[j].shape) - print("Grid shape:", grid.shape) - grid = grid.unsqueeze(1) - grid = F.interpolate(grid, size=(256, 256), mode='nearest') - step_i = "{:03d}".format(int(step_i)) - - vutils.save_image(grid, os.path.join(output_address, f'Step_{step_i}_component_{i}.jpg'), normalize=True) - def _compute_feat_loss(self, feat, pca_info, cond_control_ids, cond_example_ids, cond_appearance_ids, step, reg_included=False, reg_feature=None, ): feat_copy = feat if reg_feature is None else reg_feature @@ -544,33 +493,23 @@ def _compute_feat_loss(self, feat, pca_info, cond_control_ids, cond_example_ids, # Feat in the shape [bs,h*w,channels] feat_mean: torch.Tensor = pca_info['mean'].to(self.running_device) feat_basis: torch.Tensor = pca_info['basis'].to(self.running_device) - n_components: int = self.guidance_config.pca_guidance.pca.n_components + n_components: int = self.guidance_config.pca_guidance.structure_guidance.n_components # print(feat.shape) centered_feat = feat - feat_mean # Compute the projection feat_proj = torch.matmul(centered_feat, feat_basis.T) - feat_proj_unnorm = feat_proj.clone()[:, :, :n_components] - if self.guidance_config.pca_guidance.pca.normalized: + if self.guidance_config.pca_guidance.structure_guidance.normalized: # Normalize the projection by the max and min value feat_proj = feat_proj.permute(0, 2, 1) feat_proj_max = feat_proj.max(dim=-1, keepdim=True)[0].detach() feat_proj_min = feat_proj.min(dim=-1, keepdim=True)[0].detach() feat_proj = (feat_proj - feat_proj_min) / (feat_proj_max - feat_proj_min + 1e-7) feat_proj = feat_proj.permute(0, 2, 1) - - # Normalize the projection by the mean and standard deviation - # feat_proj = feat_proj.permute(0, 2, 1) - # feat_proj_mean = feat_proj.mean(dim=-1, 
keepdim=True) - # feat_proj_std = feat_proj.std(dim=-1, keepdim=True) - # feat_proj = (feat_proj - feat_proj_mean) / (feat_proj_std + 1e-7) - # - # feat_proj = feat_proj.permute(0, 2, 1) - feat_proj = feat_proj[:, :, :n_components] - if self.guidance_config.pca_guidance.pca.mask_tr > 0: + if self.guidance_config.pca_guidance.structure_guidance.mask_tr > 0: # Get the activation mask for each component # Check the policy for pca guidance if self.input_config.data.inversion.policy == 'share': @@ -578,14 +517,14 @@ def _compute_feat_loss(self, feat, pca_info, cond_control_ids, cond_example_ids, ref_feat = feat_proj[cond_example_ids].mean(dim=0, keepdim=True) num_control_samples: int = len(cond_control_ids) ref_feat = ref_feat.repeat(num_control_samples, 1, 1) + res = int(math.sqrt(feat_proj.shape[1])) # Select the mask for the control samples - if self.guidance_config.pca_guidance.pca.mask_type == 'tr': - ref_mask = ref_feat > self.guidance_config.pca_guidance.pca.mask_tr - elif self.guidance_config.pca_guidance.pca.mask_type == 'cross_attn': + if self.guidance_config.pca_guidance.structure_guidance.mask_type == 'tr': + ref_mask = ref_feat > self.guidance_config.pca_guidance.structure_guidance.mask_tr + elif self.guidance_config.pca_guidance.structure_guidance.mask_type == 'cross_attn': # Currently, only take the first object pair obj_pair = self.record_obj_pairs[0] example_token_ids = obj_pair['ref'] - example_sample_ids = self.new_id_record[0] example_sample_probs = self.cross_attn_probs['probs'][example_sample_ids] example_token_probs = example_sample_probs[:, example_token_ids].sum(dim=1) @@ -593,59 +532,39 @@ def _compute_feat_loss(self, feat, pca_info, cond_control_ids, cond_example_ids, example_token_probs = (example_token_probs - example_token_probs.min(dim=-1, keepdim=True)[0]) / ( example_token_probs.max(dim=-1, keepdim=True)[0] - example_token_probs.min(dim=-1, keepdim=True)[0] + 1e-7) + mask_res = int(math.sqrt(example_token_probs.shape[1])) + if res != mask_res: + example_token_probs = example_token_probs.unsqueeze(0).reshape(1, 1, mask_res, mask_res) + example_token_probs = F.interpolate(example_token_probs, size=(res, res), + mode='bicubic').squeeze(1).reshape(1, -1) - # print(example_token_probs.shape) - # print(example_token_probs.max()) - # print(example_token_probs) - # exit() - print(example_token_probs) - ref_mask = example_token_probs > self.guidance_config.pca_guidance.pca.mask_tr - print(ref_mask) + ref_mask = example_token_probs > self.guidance_config.pca_guidance.structure_guidance.mask_tr ref_mask = ref_mask.to(self.running_device).unsqueeze(-1).repeat(num_control_samples, 1, ref_feat.shape[-1]) - # print(ref_mask.shape,ref_feat.shape) - # # print(ref_mask) - # print(feat_proj.shape,feat_proj[cond_control_ids].shape) - # exit() # Compute the loss temp_loss: torch.Tensor = F.mse_loss(ref_feat[ref_mask], feat_proj[cond_control_ids][ref_mask]) - mse_loss_value = temp_loss.detach().cpu().item() - # Compute l2 penalty loss - penalty_factor: float = float(self.guidance_config.pca_guidance.pca.penalty_factor) + penalty_factor: float = float(self.guidance_config.pca_guidance.structure_guidance.penalty_factor) fliped_mask = ~ref_mask - if self.guidance_config.pca_guidance.pca.penalty_type == 'max': + if self.guidance_config.pca_guidance.structure_guidance.penalty_type == 'max': # Compute the max value in the fliped_mask score1 = (feat_proj[cond_example_ids] * fliped_mask).max(dim=1, keepdim=True)[0] score2 = F.relu((feat_proj[cond_control_ids] * fliped_mask) - score1) - 
penalty_loss = penalty_factor * F.mse_loss(score2, torch.zeros_like(score2)) - penalty_loss_value = penalty_loss.detach().cpu().item() + temp_loss += penalty_loss + elif self.guidance_config.pca_guidance.structure_guidance.penalty_type == 'hard': + # Compute the max value in the fliped_mask + score1 = (feat_proj[cond_example_ids] * fliped_mask).max(dim=1, keepdim=True)[0] + # assign hard value to the score1 + score1 = torch.ones_like(score1) * self.guidance_config.pca_guidance.structure_guidance.hard_value + score2 = F.relu((feat_proj[cond_control_ids] * fliped_mask) - score1) + penalty_loss = penalty_factor * F.mse_loss(score2, torch.zeros_like(score2)) temp_loss += penalty_loss else: raise NotImplementedError("Only max penalty type has been implemented") - - # score = torch.relu((feat_proj[cond_control_ids]*fliped_mask).mean(dim=1,keepdim=True)-feat_proj[cond_example_ids][fliped_mask].mean(dim=1,keepdim=True)) - - # penalty_loss = penalty_factor * F.mse_loss(score,torch.zeros_like(score)) - # score1 = (feat_proj[cond_control_ids]*fliped_mask).mean(dim=1,keepdim=True) - # score2 = (feat_proj[cond_example_ids]*fliped_mask).mean(dim=1,keepdim=True) - - # temp_loss += F.(ref_feat[fliped_mask],feat_proj[cond_control_ids][fliped_mask]) * penalty_factor loss.append(temp_loss) - num_masked_points = fliped_mask.sum(dim=1).detach().cpu() / fliped_mask.shape[1] - print("Num masked points", num_masked_points.tolist()) - print("MSE Loss", mse_loss_value, "Penalty", penalty_loss_value) - - # save the scores - # self.visualize_and_save(feat_proj[cond_example_ids+cond_control_ids+cond_appearance_ids],torch.cat([ref_mask,ref_mask,ref_mask],dim=0), - # 'experiments/vis/',self.current_step) - - - elif self.input_config.data.inversion.policy == 'separate': - raise NotImplementedError("Separate policy not implemented yet") else: raise NotImplementedError("Only \'share\' policy has been implemented") @@ -671,172 +590,50 @@ def _compute_feat_loss(self, feat, pca_info, cond_control_ids, cond_example_ids, loss.append(temp_loss) # Compute the texture regularization loss - reg_factor = float(self.guidance_config.pca_guidance.texture_regulation.reg_factor) - reg_method: int = int(self.guidance_config.pca_guidance.texture_regulation.reg_method) + reg_factor = float(self.guidance_config.pca_guidance.appearance_guidance.reg_factor) if reg_included and reg_factor > 0: - if reg_method == 1: - # Compute the segmentation mask - - obj_pair = self.record_obj_pairs[0] - - example_token_ids = obj_pair['ref'] - - example_sample_ids = self.new_id_record[0] - # print(example_sample_ids, self.cross_attn_probs['probs'].shape, obj_pair) - - example_sample_probs = self.cross_attn_probs['probs'][example_sample_ids] - example_token_probs = example_sample_probs[:, example_token_ids].sum(dim=1) - # use max, min value to normalize the probs - example_token_probs = (example_token_probs - example_token_probs.min(dim=-1, keepdim=True)[0]) / ( - example_token_probs.max(dim=-1, keepdim=True)[0] - - example_token_probs.min(dim=-1, keepdim=True)[ - 0] + 1e-7) - control_mask = example_token_probs > self.guidance_config.pca_guidance.texture_regulation.cross_attn_mask_tr - - appearance_token_id = obj_pair['gen'] - appearance_sample_ids = self.new_id_record[2] - appearance_sample_probs = self.cross_attn_probs['probs'][appearance_sample_ids] - appearance_token_probs = appearance_sample_probs[:, appearance_token_id].sum(dim=1) - # use max, min value to normalize the probs - appearance_token_probs = (appearance_token_probs - 
appearance_token_probs.min(dim=-1, keepdim=True)[ - 0]) / ( - appearance_token_probs.max(dim=-1, keepdim=True)[0] - - appearance_token_probs.min(dim=-1, keepdim=True)[0] + 1e-7) - appearance_mask = appearance_token_probs > self.guidance_config.pca_guidance.texture_regulation.cross_attn_mask_tr - try: - app_n_components = self.guidance_config.pca_guidance.texture_regulation.app_n_components - except: - app_n_components = n_components - - def compute_app_loss(feature, weights, tr, control_ids, appearance_ids, control_mask, appearance_mask): - - weights = weights[:, :, :app_n_components] - B, C, W = feature.shape - _, _, K = weights.shape - mask = (weights > tr).float() - - mask = weights * mask - # mask[1] = mask[0] - expanded_mask = mask.unsqueeze(-2).expand(B, C, W, K) - # print(expanded_mask.shape, feature.shape) - - masked_feature = feature.unsqueeze(-1) * expanded_mask - count = mask.sum(dim=1, keepdim=True) - avg_feature = masked_feature.sum(dim=1) / (count + 1e-5) - return F.mse_loss(avg_feature[control_ids], avg_feature[appearance_ids].detach()) - - # B, C, W = feature.shape - # _, _, K = weights.shape - # mask = (weights > tr).float() - # #mask = mask * weights - # - # inside_mask = torch.cat([torch.zeros_like(control_mask),control_mask,appearance_mask],dim=0) - # outside_mask = torch.cat([torch.ones_like(control_mask),~control_mask,~appearance_mask],dim=0) - # - # inside_mask = inside_mask.unsqueeze(-1).expand(B,C,K) - # inside_mask = torch.ones_like(inside_mask) * mask - # outside_mask = outside_mask.unsqueeze(-1).expand(B,C,K) * mask - # - # #print(feature.shape,mask.shape, weights.shape,control_mask.shape,appearance_mask.shape) - # - # inside_mask = inside_mask.unsqueeze(-2).expand(B, C, W, K) - # outside_mask = outside_mask.unsqueeze(-2).expand(B, C, W, K) - # - # #print(inside_mask.shape, outside_mask.shape) - # - # - # inside_masked_feature = feature.unsqueeze(-1) * inside_mask - # outside_masked_feature = feature.unsqueeze(-1) * outside_mask - # - # #print(inside_masked_feature.shape, outside_masked_feature.shape) - # inside_count = inside_mask.sum(dim=1, keepdim=True) - # outside_count = outside_mask.sum(dim=1, keepdim=True) - # - # inside_avg_feature = inside_masked_feature.mean(dim=1) #/ (inside_count + 1e-5) - # outside_avg_feature = outside_masked_feature.mean(dim=1) #/ (outside_count + 1e-5) - # - # inside_non_zero_mask = (inside_avg_feature.detach() != 0).any(dim=1).float() - # outside_non_zero_mask = (outside_avg_feature.detach() != 0).any(dim=1).float() - # - # inside_non_zero_mask = inside_non_zero_mask.unsqueeze(1).expand(B,W,K) - # outside_non_zero_mask = outside_non_zero_mask.unsqueeze(1).expand(B,W,K) - # - # # inside_avg_feature = inside_avg_feature * inside_non_zero_mask - # # outside_avg_feature = outside_avg_feature * outside_non_zero_mask - # - # inside_loss = F.mse_loss(inside_avg_feature[control_ids], inside_avg_feature[appearance_ids].detach()) - # #outside_loss = F.mse_loss(outside_avg_feature[control_ids], outside_avg_feature[appearance_ids].detach()) - # - # return inside_loss #+ outside_loss - # return F.mse_loss(avg_feature[control_ids], avg_feature[appearance_ids].detach()) - - temp_loss_list = [] - for temp_feat in feat_copy: - # Compute the texture regularization loss - temp_loss: torch.Tensor = compute_app_loss(temp_feat, feat_proj, - self.guidance_config.pca_guidance.texture_regulation.tr, - cond_control_ids, cond_appearance_ids, control_mask, - appearance_mask) - temp_loss_list.append(temp_loss) - temp_loss = 
torch.stack(temp_loss_list).mean() - - elif reg_method == 2: - # compute the max value of feat_proj's dim 2, use that position to extract the feature in feat_copy - # print(feat_proj.shape) - max_id = feat_proj.max(dim=1)[1] - # print(max_id[0]) - # print(max_id.shape) - # print(feat_copy.shape) - temp_loss_list = [] - for temp_feat in feat_copy: - selected_feat = temp_feat[torch.arange(temp_feat.size(0)).unsqueeze(1), max_id] - # print(selected_feat.shape) - temp_loss = F.mse_loss(selected_feat[cond_control_ids], selected_feat[cond_appearance_ids].detach()) - temp_loss_list.append(temp_loss) - temp_loss = torch.stack(temp_loss_list).mean() - # exit() - - elif reg_method == 3: - - # Compute the mean value of un-nomalized feature - print(feat_proj_unnorm.shape) - feat_mean = feat_proj_unnorm.mean(dim=1) - print(feat_mean) - print(feat_mean.shape) - print("?????????????????????????????????????????????????????") - temp_loss = F.mse_loss(feat_mean[cond_control_ids], feat_mean[cond_appearance_ids].detach()) - elif reg_method == 4: - # Compute the mean and variance loss of the un-nomalized feature - temp_loss = None - feat_mean = feat_proj_unnorm.mean(dim=1) - feat_var = feat_proj_unnorm.var(dim=1) - mean_loss = F.mse_loss(feat_mean[cond_control_ids], feat_mean[cond_appearance_ids].detach()) - var_loss = F.mse_loss(feat_var[cond_control_ids], feat_var[cond_appearance_ids].detach()) - var_loss = var_loss / var_loss * mean_loss.detach() - temp_loss = mean_loss + var_loss - print("Mean loss", mean_loss.detach().cpu().item(), "Var loss", var_loss.detach().cpu().item()) - # print(feat_mean.shape, feat_var.shape) - # exit() - - elif reg_method == 5: - temp_loss_list = [] - for temp_feat in feat_copy: - feat_mean = temp_feat.mean(dim=1) - temp_loss = F.mse_loss(feat_mean[cond_control_ids], feat_mean[cond_appearance_ids].detach()) - temp_loss_list.append(temp_loss) - temp_loss = torch.stack(temp_loss_list).mean() - # - - else: - raise NotImplementedError("Only method 1 and 2 have been implemented") - # if step < 30: - # reg_factor = (step/30) * reg_factor + app_n_components = self.guidance_config.pca_guidance.appearance_guidance.app_n_components + + def compute_app_loss(feature, weights, tr, control_ids, appearance_ids): + """ + Compute the weighted average feature loss based on the given weights and feature + + :return: loss: MSE_loss between two weighted average feature + """ + weights = weights[:, :, :app_n_components] + B, C, W = feature.shape + _, _, K = weights.shape + mask = (weights > tr).float() + mask = weights * mask + expanded_mask = mask.unsqueeze(-2).expand(B, C, W, K) + masked_feature = feature.unsqueeze(-1) * expanded_mask + count = mask.sum(dim=1, keepdim=True) + avg_feature = masked_feature.sum(dim=1) / (count + 1e-5) + return F.mse_loss(avg_feature[control_ids], avg_feature[appearance_ids].detach()) + + temp_loss_list = [] + for temp_feat in feat_copy: + # Compute the appearance guidance loss + temp_loss: torch.Tensor = compute_app_loss(temp_feat, feat_proj, + self.guidance_config.pca_guidance.appearance_guidance.tr, + cond_control_ids, cond_appearance_ids) + temp_loss_list.append(temp_loss) + temp_loss = torch.stack(temp_loss_list).mean() loss.append(temp_loss * reg_factor) - print("Texture regularization loss", temp_loss.detach().cpu().item()) loss = torch.stack(loss).sum() return loss - def compute_conv_loss(self, cond_control_ids, cond_example_ids, cond_appearance_ids, i): + def compute_conv_pca_loss(self, cond_control_ids, cond_example_ids, cond_appearance_ids, step_i): + """ + 
Compute the PCA Conv loss based on the given condition control, example, and appearance IDs. + + This method is not used in FreeControl method, only for ablation study + :param cond_control_ids: + :param cond_example_ids: + :param cond_appearance_ids: + :param step_i: + :return: + """ # The Conv loss is not used in our method # The new tensor follows this order: example, control, appearance combined_list = cond_example_ids + cond_control_ids + cond_appearance_ids @@ -844,25 +641,34 @@ def compute_conv_loss(self, cond_control_ids, cond_example_ids, cond_appearance_ new_cond_control_ids = np.arange(len(cond_example_ids), len(cond_control_ids) + len(cond_example_ids)).tolist() new_cond_appearance_ids = np.arange(len(cond_control_ids) + len(cond_example_ids), len(combined_list)).tolist() - step_pca_info: dict = self.loaded_pca_info[i] - conv_feat = get_hidden_state(self.unet) - conv_feat = conv_feat[combined_list] - conv_feat = conv_feat.permute(0, 2, 3, 1).contiguous().reshape(len(combined_list), -1, conv_feat.shape[1]) - # print(conv_feat.shape) - conv_pca_info: dict = step_pca_info['conv'] - loss = self._compute_feat_loss(conv_feat, conv_pca_info, new_cond_control_ids, new_cond_example_ids, - new_cond_appearance_ids, reg_included=True) - + step_pca_info: dict = self.loaded_pca_info[step_i] + total_loss = [] + # get correct blocks + for block_name in self.guidance_config.pca_guidance.blocks: + if "up_blocks" in block_name: + block_id = int(block_name.split(".")[-1]) + else: + raise NotImplementedError("Only support up_blocks") + for j in range(len(self.unet.up_blocks[block_id].resnets)): + name = f"up_blocks.{block_id}.resnets.{j}" + conv_feat = self.unet.up_blocks[block_id].resnets[j].record_hidden_state + conv_feat = conv_feat[combined_list] + conv_feat = conv_feat.permute(0, 2, 3, 1).contiguous().reshape(len(combined_list), -1, + conv_feat.shape[1]) + conv_pca_info: dict = step_pca_info['conv'][name] + + loss = self._compute_feat_loss(conv_feat, conv_pca_info, new_cond_control_ids, new_cond_example_ids, + new_cond_appearance_ids, step_i, reg_included=True, + reg_feature=[conv_feat]) + total_loss.append(loss) weight = float(self.guidance_config.pca_guidance.weight) - if self.guidance_config.pca_guidance.warm_up.apply and i < self.guidance_config.pca_guidance.warm_up.end_step: - weight = weight * (i / self.guidance_config.pca_guidance.warm_up.end_step) - elif self.guidance_config.pca_guidance.adaptive.apply: - # TODO: Implement the adaptive weight - weight = weight * (i / self.guidance_config.pca_guidance.adaptive.end_step) + if self.guidance_config.pca_guidance.warm_up.apply and step_i < self.guidance_config.pca_guidance.warm_up.end_step: + weight = weight * (step_i / self.guidance_config.pca_guidance.warm_up.end_step) + loss = torch.stack(total_loss).mean() loss = loss * weight return loss - def compute_pca_loss(self, cond_control_ids, cond_example_ids, cond_appearance_ids, i): + def compute_attn_pca_loss(self, cond_control_ids, cond_example_ids, cond_appearance_ids, step_i): """ Compute the PCA Semantic loss based on the given condition control, example, and appearance IDs. @@ -876,72 +682,66 @@ def compute_pca_loss(self, cond_control_ids, cond_example_ids, cond_appearance_i - cond_control_ids (List[int]): List of control condition IDs. - cond_example_ids (List[int]): List of example condition IDs. - cond_appearance_ids (List[int]): List of appearance condition IDs. - - i (int): The current step or iteration count. + - step_i (int): The current step in the diffusion process. 
Returns: - torch.Tensor: The computed PCA loss. """ - + # Only save the cond sample, remove the uncond sample to compute loss # The new tensor follows this order: example, control, appearance combined_list = cond_example_ids + cond_control_ids + cond_appearance_ids new_cond_example_ids = np.arange(len(cond_example_ids)).tolist() new_cond_control_ids = np.arange(len(cond_example_ids), len(cond_control_ids) + len(cond_example_ids)).tolist() new_cond_appearance_ids = np.arange(len(cond_control_ids) + len(cond_example_ids), len(combined_list)).tolist() - temp_query_loss = [] - temp_key_loss = [] - step_pca_info: dict = self.loaded_pca_info[i] - # conv_feat = get_hidden_state(self.unet) - # conv_feat = conv_feat[combined_list] - # conv_feat = conv_feat.permute(0, 2, 3, 1).contiguous().reshape(len(combined_list), -1, conv_feat.shape[1]) + pca_loss = [] + step_pca_info: dict = self.loaded_pca_info[step_i] # 1. Loop though all layers to get the query, key, and Compute the PCA loss for name, module in self.unet.named_modules(): module_name = type(module).__name__ - # print(name) if module_name == "Attention" and 'attn1' in name and 'attentions' in name and \ _classify_blocks(self.guidance_config.pca_guidance.blocks, name): - key: torch.Tensor = module.processor.key[combined_list] - query: torch.Tenosr = module.processor.query[combined_list] - value: torch.Tensor = module.processor.value[combined_list] - #print( step_pca_info['attn_query'].keys()) - #query_pca_info: dict = step_pca_info['attn_query'][name] - key_pca_info: dict = step_pca_info['attn_key'][name] - - self.current_step = i - # Compute the PCA loss - # module_query_loss = self._compute_feat_loss(query,query_pca_info,new_cond_control_ids,new_cond_example_ids,new_cond_appearance_ids) - module_key_loss = self._compute_feat_loss(key, key_pca_info, new_cond_control_ids, - new_cond_example_ids, new_cond_appearance_ids, - i, - reg_included=True, reg_feature=[key]) - - # temp_query_loss.append(module_query_loss) - temp_key_loss.append(module_key_loss) - - # temp_key_loss.append(self._compute_feat_loss(key,key_pca_info,new_cond_control_ids,new_cond_example_ids,new_cond_appearance_ids)) - - # if self.config.data.inversion.policy == 'share': - - # print(name) + try: + select_feature = self.guidance_config.pca_guidance.select_feature + except: + select_feature = 'key' + + self.current_step = step_i + # Compute the PCA loss with selected feature + if select_feature == 'key': + key: torch.Tensor = module.processor.key[combined_list] + key_pca_info: dict = step_pca_info['attn_key'][name] + module_pca_loss = self._compute_feat_loss(key, key_pca_info, new_cond_control_ids, + new_cond_example_ids, new_cond_appearance_ids, + step_i, + reg_included=True, reg_feature=[key]) + elif select_feature == 'query': + query: torch.Tenosr = module.processor.query[combined_list] + query_pca_info: dict = step_pca_info['attn_query'][name] + module_pca_loss = self._compute_feat_loss(query, query_pca_info, new_cond_control_ids, + new_cond_example_ids, new_cond_appearance_ids, + step_i, + reg_included=True, reg_feature=[query]) + else: + value: torch.Tensor = module.processor.value[combined_list] + value_pca_info: dict = step_pca_info['attn_value'][name] + module_pca_loss = self._compute_feat_loss(value, value_pca_info, new_cond_control_ids, + new_cond_example_ids, new_cond_appearance_ids, + step_i, + reg_included=True, reg_feature=[value]) + + pca_loss.append(module_pca_loss) # 2. 
compute pca weight weight = float(self.guidance_config.pca_guidance.weight) - if self.guidance_config.pca_guidance.warm_up.apply and i < self.guidance_config.pca_guidance.warm_up.end_step: - weight = weight * (i / self.guidance_config.pca_guidance.warm_up.end_step) - elif self.guidance_config.pca_guidance.adaptive.apply: - # TODO: Implement the adaptive weight - weight = weight * (i / self.guidance_config.pca_guidance.adaptive.end_step) - query_loss = 0 - key_loss = 0 - # 3. compute the loss - # query_loss = torch.stack(temp_query_loss).mean() - key_loss = torch.stack(temp_key_loss).mean() - loss = query_loss + key_loss - loss = loss * weight - return loss + if self.guidance_config.pca_guidance.warm_up.apply and step_i < self.guidance_config.pca_guidance.warm_up.end_step: + weight = weight * (step_i / self.guidance_config.pca_guidance.warm_up.end_step) + + # 3. compute the PCA loss + pca_loss = torch.stack(pca_loss).mean() * weight + return pca_loss - def compute_cross_attn_loss(self, cond_control_ids, cond_example_ids, cond_appearance_ids, i): - cross_attn_loss = 0 + def compute_cross_attn_mask(self, cond_control_ids, cond_example_ids, cond_appearance_ids): for name, module in self.unet.named_modules(): module_name = type(module).__name__ @@ -958,32 +758,21 @@ def compute_cross_attn_loss(self, cond_control_ids, cond_example_ids, cond_appea if module_name == "Attention" and 'attn2' in name and 'attentions' in name and \ _classify_blocks(self.input_config.guidance.cross_attn.blocks, name): - print(name) # Combine the condition sample for [example, control, appearance], and compute cross-attention weight query = module.processor.query[combined_list] key = module.processor.key[combined_list] - # print(key) - # print(query) - # exit() + query = module.processor.attn.head_to_batch_dim(query).contiguous() key = module.processor.attn.head_to_batch_dim(key).contiguous() attention_mask = module.processor.attention_mask - attention_probs = module.processor.attn.get_attention_scores(query, key, - attention_mask) - + attention_probs = module.processor.attn.get_attention_scores(query, key, attention_mask) source_batch_size = int(attention_probs.shape[0] // len(combined_list)) - - # print(attention_probs.shape) # record the attention probs and update the averaged attention probs reshaped_attention_probs = attention_probs.detach().reshape(len(combined_list), source_batch_size, -1, 77).permute(1, 0, 3, 2) - print("Is NaN in ", name, torch.isnan(reshaped_attention_probs).any()) - assert torch.isnan(reshaped_attention_probs).any() == False, "NaN in attention probs" - # print("Combined list", len(combined_list)) channel_num = reshaped_attention_probs.shape[0] - # print(channel_num) reshaped_attention_probs = reshaped_attention_probs.mean(dim=0) - + # We followed the method in https://arxiv.org/pdf/2210.04885.pdf to compute the cross-attention mask if self.cross_attn_probs['probs'] is None: updated_probs = reshaped_attention_probs else: @@ -992,47 +781,7 @@ def compute_cross_attn_loss(self, cond_control_ids, cond_example_ids, cond_appea self.cross_attn_probs['channels'] + channel_num) self.cross_attn_probs['probs'] = updated_probs.detach() self.cross_attn_probs['channels'] += channel_num - print(self.cross_attn_probs['channels']) - print("__________________________________________________________________________") - print("Cross attn probs", self.cross_attn_probs['probs'].shape) - print("__________________________________________________________________________") - ref_attn_probs = 
attention_probs[:source_batch_size * len(cond_example_ids)] - control_attn_probs = attention_probs[source_batch_size * len(cond_example_ids): source_batch_size * len( - cond_example_ids + cond_control_ids)] - - # Log the attention mask - - res = int(math.sqrt(attention_probs.shape[1])) - - # soft_scale = self.guidance_cross_attn_config.soft - # if soft_scale > 0: - # ref_attn_probs = ref_attn_probs.permute(0, 2, 1) - # cond_attn_probs = cond_attn_probs.permute(0, 2, 1) - # - # ref_attn_probs = ref_attn_probs.reshape(source_batch_size, -1, res, res) - # cond_attn_probs = cond_attn_probs.reshape(source_batch_size, -1, res, res) - # # Soft attn Probs - # ref_attn_probs = apply_gaussian_filter(ref_attn_probs, - # kernel_size=self.guidance_self_attn_config.kernel_size, - # sigma=self.guidance_self_attn_config.sigma) - # cond_attn_probs = apply_gaussian_filter(cond_attn_probs, - # kernel_size=self.guidance_self_attn_config.kernel_size, - # sigma=self.guidance_self_attn_config.sigma) - # ref_attn_probs = ref_attn_probs.reshape(source_batch_size, -1, res * res) - # cond_attn_probs = cond_attn_probs.reshape(source_batch_size, -1, res * res) - # ref_attn_probs = ref_attn_probs.permute(0, 2, 1) - # cond_attn_probs = cond_attn_probs.permute(0, 2, 1) - - if self.input_config.guidance.cross_attn.obj_only: - # Compute th object level cross-attention loss - for obj_pair_ids in self.record_obj_pairs: - example_ids = obj_pair_ids['ref'] - control_ids = obj_pair_ids['gen'] - ref_attn_probs = ref_attn_probs[:, :, example_ids].max(-1)[0].mean(0) - control_attn_probs = control_attn_probs[:, :, control_ids].max(-1)[0].mean(0) - cross_attn_loss += F.mse_loss(control_attn_probs, ref_attn_probs).detach() - - return cross_attn_loss * self.input_config.guidance.cross_attn.weight + return @torch.no_grad() def invert(self, @@ -1094,14 +843,14 @@ def invert(self, if use_cache: img_data = get_data(out_folder, data_key) - if img_data is not None: return img_data - inv_latents, _, all_latent, prompt_embeds,add_text_embeds, add_time_id = self.ddim_inversion(prompt, image=img, - num_inference_steps=inversion_config.num_inference_steps, - num_reg_steps=0, - return_dict=False) + inv_latents, _, all_latent, prompt_embeds, add_text_embeds, add_time_id = self.ddim_inversion(prompt, + image=img, + num_inference_steps=inversion_config.num_inference_steps, + num_reg_steps=0, + return_dict=False) img_data: Dict = { 'prompt': prompt, 'all_latents': all_latent, @@ -1151,7 +900,7 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None if needs_upcasting: self.upcast_vae() image = image.to(torch.float32) - #image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + # image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) latents = self.vae.encode(image).latent_dist.sample(generator) # cast back to fp16 if needed @@ -1221,16 +970,8 @@ def ddim_inversion( negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, - - - lambda_auto_corr: float = 20.0, - lambda_kl: float = 20.0, - num_reg_steps: int = 5, - num_auto_corr_rolls: int = 5, ): - - # 1. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -1250,7 +991,7 @@ def ddim_inversion( height = 1024 width = 1024 original_size = original_size or (height, width) - target_size = target_size or (height, width) + target_size = target_size or (height, width) # 1. Check inputs. 
Raise error if not correct self.check_inputs( @@ -1272,11 +1013,10 @@ def ddim_inversion( # 4. Prepare latent variables latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator).to(torch.float16) - # print(latents) - # exit() + # 5. Encode input prompt num_images_per_prompt = 1 - # 3. Encode input prompt + # 6. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) @@ -1304,14 +1044,18 @@ def ddim_inversion( self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.inverse_scheduler.timesteps - # 6. Rejig the UNet so that we can obtain the cross-attenion maps and - # use them for guiding the subsequent image generation. - self.unet = prepare_unet(self.unet) - # 7. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( @@ -1319,6 +1063,7 @@ def ddim_inversion( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids @@ -1363,7 +1108,6 @@ def ddim_inversion( return_dict=False, )[0] - print(noise_pred) # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) @@ -1373,39 +1117,9 @@ def ddim_inversion( # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - # regularization of the noise prediction - with torch.enable_grad(): - for _ in range(num_reg_steps): - if lambda_auto_corr > 0: - for _ in range(num_auto_corr_rolls): - var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) - - # Derive epsilon from model output before regularizing to IID standard normal - var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) - - l_ac = self.auto_corr_loss(var_epsilon, generator=generator) - l_ac.backward() - - grad = var.grad.detach() / num_auto_corr_rolls - noise_pred = noise_pred - lambda_auto_corr * grad - - if lambda_kl > 0: - var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) - - # Derive epsilon from model output before regularizing to IID standard normal - var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) - - l_kld = self.kl_divergence(var_epsilon) - l_kld.backward() - - grad = var.grad.detach() - noise_pred = noise_pred - lambda_kl * grad - - noise_pred = noise_pred.detach() # compute the previous noisy sample x_t -> x_t-1 latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample - print(latents) assert not torch.isnan(latents).any(), "NaN in latents" all_latents[timestep_key] = latents.detach().cpu() @@ -1420,23 +1134,19 @@ def ddim_inversion( inverted_latents = latents.detach().clone() - # # 8. 
Post-processing - # image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - # image = self.image_processor.postprocess(image, output_type=output_type) - # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: - return (inverted_latents, None, all_latents, prompt_embeds.detach().cpu(),add_text_embeds.detach().cpu(),add_time_ids.detach().cpu()) + return (inverted_latents, None, all_latents, prompt_embeds.detach().cpu(), add_text_embeds.detach().cpu(), + add_time_ids.detach().cpu()) return None - @torch.no_grad() - def pca_visulization(self, + def sample_semantic_bases(self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, @@ -1467,504 +1177,28 @@ def pca_visulization(self, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, - - # PCA Sampling parameters + # FreeControl Stage1 parameters num_batch: int = 1, config: omegaconf.dictconfig = None, - mask_obj: str = "", - mask_tr: float = 0.5, - num_save_basis: int = 128, - num_save_steps: int = 300, + num_save_basis: int = 64, + num_save_steps: int = 120, ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - guidance_rescale (`float`, *optional*, defaults to 0.7): - Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when - using zero terminal SNR. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated images and the - second element is a list of `bool`s indicating whether the corresponding generated image contains - "not-safe-for-work" (nsfw) content. """ - # 0. Prepare the UNet - self.unet = prep_unet(self.unet) - self.unet = prep_conv_layer(self.unet) - self.sampling_config: omegaconf.dictconfig = config - self.pca_info: Dict = dict() - - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # Use the same height and width for original and target size - original_size = (height, width) - target_size = (height, width) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - height, - width, - callback_steps, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # compute the ids of the selected object - try: - ids, _ = compute_token_merge_indices(self.tokenizer, prompt, mask_obj) - except: - print(f"Selected object {mask_obj} not found in the prompt {prompt}") - ids = None - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. 
- do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=text_encoder_lora_scale, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - all_latents = self.prepare_latents( - batch_size * num_images_per_prompt * num_batch, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. Prepare added time ids & embeddings - add_text_embeds = pooled_prompt_embeds - add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype - ) - if negative_original_size is not None and negative_target_size is not None: - negative_add_time_ids = self._get_add_time_ids( - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype=prompt_embeds.dtype, - ) - else: - negative_add_time_ids = add_time_ids - - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) - - prompt_embeds = prompt_embeds.to(device) - add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) - - - latent_list = list(all_latents.chunk(num_batch, dim=0)) - latent_list_copy = latent_list[:] - - seg_maps = dict() - fixed_size = (int(128), int(128)) - - # 7. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - # 7.1 Apply denoising_end - if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: - discrete_timestep_cutoff = int( - round( - self.scheduler.config.num_train_timesteps - - (denoising_end * self.scheduler.config.num_train_timesteps) - ) - ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) - timesteps = timesteps[:num_inference_steps] - - latent_list = latent_list_copy[:] - # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps * num_batch) as progress_bar: - for i, t in enumerate(timesteps): - conv_hidden_state_list = [] - attn_hidden_state_dict = dict() - attn_query_dict = dict() - attn_key_dict = dict() - attn_value_dict = dict() - for latent_id, latents in enumerate(latent_list): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - if do_classifier_free_guidance and guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, - guidance_rescale=guidance_rescale) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - latent_list[latent_id] = latents - - # 8. 
Post-processing the pca features - # get hidden feature - conv_hidden_feature = get_hidden_state(self.unet).detach().cpu() - # print(conv_hidden_feature.shape) - conv_hidden_feature = conv_hidden_feature.chunk(2)[1] - conv_hidden_state_list.append(conv_hidden_feature) - - # Get self attention features - hidden_state_dict, query_dict, key_dict, value_dict = get_selt_attn_feat_info(self.unet, - self.sampling_config.guidance.pca_guidance) - for name in hidden_state_dict.keys(): - def log_to_dict(feat, selected_dict, name): - feat = feat.chunk(2)[1] - if name in selected_dict.keys(): - selected_dict[name].append(feat) - else: - selected_dict[name] = [feat] - - log_to_dict(hidden_state_dict[name], attn_hidden_state_dict, name) - log_to_dict(key_dict[name], attn_key_dict, name) - log_to_dict(query_dict[name], attn_query_dict, name) - log_to_dict(value_dict[name], attn_value_dict, name) - all_mask = None - def apply_pca(feat): - with torch.autocast(device_type='cuda', dtype=torch.float32): - feat = feat.contiguous().to(torch.float32) - # feat shape in [bs,channels,16,16] - bs, channels, h, w = feat.shape - if ids is not None and all_mask is not None: - temp_mask = F.interpolate(all_mask.unsqueeze(1).float(), size=(h, w), - mode='nearest').repeat(1, - channels, - 1, - 1).bool().to( - feat.device).permute(0, 2, 3, 1).reshape(-1, channels) - - feat = feat.permute(0, 2, 3, 1).reshape(-1, channels)[temp_mask].reshape(-1, channels).to( - 'cuda') - X = feat - else: - # No mask will be applied - X = feat.permute(0, 2, 3, 1).reshape(-1, channels).to('cuda') - # print(feat.shape) - mean = X.mean(dim=0) - tensor_centered = X - mean - U, S, V = torch.svd(tensor_centered) - n_egv = V.shape[-1] - - if n_egv > num_save_basis: - V = V[:, :num_save_basis] - basis = V.T - X_pca = torch.mm(tensor_centered, basis.T).contiguous() - score = X_pca.view(bs, h, w, -1).permute(0, 3, 1, 2) - score = score[:,:3] - - # if not mean.shape[-1] == basis.shape[-1]: - # print(mean.shape, basis.shape, X.shape, V.shape, score.shape, tensor_centered.shape, ) - - assert mean.shape[-1] == basis.shape[-1] - - return { - "score" : score.detach().cpu(), - } - - def process_feat_dict(feat_dict): - for name in feat_dict.keys(): - feat_dict[name] = torch.cat(feat_dict[name], dim=0) - feat_dict[name] = apply_pca(feat_dict[name]) - # print(feat_dict[name].shape) - - # Only process for the first num_save_steps - if i < num_save_steps: - #process_feat_dict(attn_hidden_state_dict) - process_feat_dict(attn_query_dict) - process_feat_dict(attn_key_dict) - #process_feat_dict(attn_value_dict) - - # conv_hidden_state_list = torch.cat(conv_hidden_state_list, dim=0) - # conv_hidden_state_info = apply_pca(conv_hidden_state_list) - - self.pca_info[i] = { - # 'conv': conv_hidden_state_info, - # 'attn_hidden_state': attn_hidden_state_dict, - 'attn_key': attn_key_dict, - 'attn_query': attn_query_dict, - # 'attn_value': attn_value_dict, - } - #print(self.pca_info.keys()) - - if not output_type == "latent": - # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast - - if needs_upcasting: - self.upcast_vae() - img_list = [] - for latent_id, latents in enumerate(latent_list): - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - img_list.append(image) - image = torch.cat(img_list, dim=0) - - # cast back to fp16 if needed - if needs_upcasting: - 
self.vae.to(dtype=torch.float16) - - else: - image = latents - - if not output_type == "latent": - # apply watermark if available - if self.watermark is not None: - image = self.watermark.apply_watermark(image) - - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return StableDiffusionPipelineOutput(images=image) - - - - - @torch.no_grad() - def sample_pca_components(self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - denoising_end: Optional[float] = None, - guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guidance_rescale: float = 0.0, - original_size: Optional[Tuple[int, int]] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Optional[Tuple[int, int]] = None, - negative_original_size: Optional[Tuple[int, int]] = None, - negative_crops_coords_top_left: Tuple[int, int] = (0, 0), - negative_target_size: Optional[Tuple[int, int]] = None, - - - # PCA Sampling parameters - num_batch: int = 1, - config: omegaconf.dictconfig = None, - - mask_obj: str = "", - mask_tr: float = 0.5, - num_save_basis: int = 128, - num_save_steps: int = 300, - save_img=False, - ): - r""" - The call function to the pipeline for generation. + The sample_pca_components function to the pipeline to generate semantic bases. Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. 
- eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - guidance_rescale (`float`, *optional*, defaults to 0.7): - Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when - using zero terminal SNR. - - Examples: + # Default parameters please check call method for more details + # FreeControl Stage1 parameters + num_batch: int = 1, The number of batches to generate. + The number of seed images generated will be num_batch * num_save_basis + config: omegaconf.dictconfig = None, The config file for the pipeline + num_save_basis : int = 64, The number of leading PC to save + num_save_steps: int = 120, The number of steps to save the semantic bases - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated images and the - second element is a list of `bool`s indicating whether the corresponding generated image contains - "not-safe-for-work" (nsfw) content. """ # 0. 
Prepare the UNet - self.unet = prep_unet(self.unet) - self.unet = prep_conv_layer(self.unet) + self.unet = prep_unet_attention(self.unet) + self.unet = prep_unet_conv(self.unet) self.sampling_config: omegaconf.dictconfig = config self.pca_info: Dict = dict() @@ -1999,16 +1233,6 @@ def sample_pca_components(self, else: batch_size = prompt_embeds.shape[0] - # compute the ids of the selected object - try: - ids, _ = compute_token_merge_indices(self.tokenizer, prompt, mask_obj) - except: - print(f"Selected object {mask_obj} not found in the prompt {prompt}") - ids = None - - if mask_tr == 0: - ids = None - device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -2064,8 +1288,16 @@ def sample_pca_components(self, # 7. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( @@ -2073,6 +1305,7 @@ def sample_pca_components(self, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids @@ -2086,7 +1319,6 @@ def sample_pca_components(self, add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) - latent_list = list(all_latents.chunk(num_batch, dim=0)) latent_list_copy = latent_list[:] @@ -2106,120 +1338,20 @@ def sample_pca_components(self, num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] - if ids is not None: - with (self.progress_bar(total=num_inference_steps * num_batch) as progress_bar): - for i, t in enumerate(timesteps): - for latent_id, latents in enumerate(latent_list): - - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - if do_classifier_free_guidance and guidance_rescale > 0.0: - # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, - guidance_rescale=guidance_rescale) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - for name, module in self.unet.named_modules(): - module_name = type(module).__name__ - if module_name == "Attention" and 'attn2' in name and 'attentions' in name: - # loop though all the seg maps - - key = module.processor.key - key = key[int(key.size(0) / 2):] - query = module.processor.query - query = query[int(query.size(0) / 2):] - - num_samples = key.size(0) - query2 = module.processor.attn.head_to_batch_dim(query).contiguous() - key2 = module.processor.attn.head_to_batch_dim(key).contiguous() - attention_mask = module.processor.attention_mask - attention_probs = module.processor.attn.get_attention_scores(query2, key2, - attention_mask) - source_batch_size = int(attention_probs.shape[0] / num_samples) - res = int(np.sqrt(attention_probs.shape[1])) - attention_probs = attention_probs.permute(0, 2, 1).reshape( - source_batch_size * num_samples, - -1, res, res, ) - attention_probs = attention_probs.to(torch.float32) - reshaped_attn_probs = F.interpolate(attention_probs, size=fixed_size, - mode='bicubic').clamp_(min=0).reshape(num_samples, - source_batch_size, - -1, - fixed_size[0], - fixed_size[1]) - # reshaped_attn_probs = reshaped_attn_probs[:, :, 1:-1] - # print(reshaped_attn_probs.shape) - # print(reshaped_attn_probs.shape) - if latent_id in seg_maps.keys(): - seg_maps[latent_id]['latent'] = (seg_maps[latent_id]['latent'] * - seg_maps[latent_id][ - 'num_channels'] + reshaped_attn_probs.mean( - dim=1) * source_batch_size) / ( - seg_maps[latent_id]['num_channels'] + - reshaped_attn_probs.shape[1]) - seg_maps[latent_id]['num_channels'] += reshaped_attn_probs.shape[1] - else: - - seg_maps[latent_id] = { - 'latent': reshaped_attn_probs.mean(dim=1), - 'num_channels': reshaped_attn_probs.shape[1], - } - latent_list[latent_id] = latents - - # process masks - for key in seg_maps.keys(): - feat = seg_maps[key]['latent'][:, ids].clone() - feat = feat.max(dim=1)[0] - feat = (feat - feat.min()) / (feat.max() - feat.min() + 1e-7) - seg_maps[key] = feat > mask_tr - - all_mask = torch.cat([seg_maps[key] for key in seg_maps.keys()], dim=0) - self.seg_maps = seg_maps - latent_list = latent_list_copy[:] - # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps * num_batch) as progress_bar: + pbar_steps = min(num_save_steps, num_inference_steps) + with self.progress_bar(total=pbar_steps * num_batch) as progress_bar: for i, t in enumerate(timesteps): - if i >= num_save_steps and not save_img: + if i >= num_save_steps: break - conv_hidden_state_list = [] - attn_hidden_state_dict = dict() - attn_query_dict = dict() + # create dict to store the hidden features attn_key_dict = dict() - attn_value_dict = dict() + for latent_id, latents in enumerate(latent_list): - # mask = seg_maps[latent_id] # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} noise_pred = self.unet( latent_model_input, @@ -2237,8 +1369,7 @@ def sample_pca_components(self, if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, - guidance_rescale=guidance_rescale) + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] @@ -2251,15 +1382,8 @@ def sample_pca_components(self, latent_list[latent_id] = latents # 8. Post-processing the pca features - # get hidden feature - # conv_hidden_feature = get_hidden_state(self.unet).detach().cpu() - # # print(conv_hidden_feature.shape) - # conv_hidden_feature = conv_hidden_feature.chunk(2)[1] - # conv_hidden_state_list.append(conv_hidden_feature) - - # Get self attention features - hidden_state_dict, query_dict, key_dict, value_dict = get_selt_attn_feat_info(self.unet, - self.sampling_config.guidance.pca_guidance) + hidden_state_dict, query_dict, key_dict, value_dict = get_self_attn_feat(self.unet, + self.sampling_config.guidance.pca_guidance) for name in hidden_state_dict.keys(): def log_to_dict(feat, selected_dict, name): feat = feat.chunk(2)[1] @@ -2268,48 +1392,26 @@ def log_to_dict(feat, selected_dict, name): else: selected_dict[name] = [feat] - #log_to_dict(hidden_state_dict[name], attn_hidden_state_dict, name) log_to_dict(key_dict[name], attn_key_dict, name) - log_to_dict(query_dict[name], attn_query_dict, name) - #log_to_dict(value_dict[name], attn_value_dict, name) def apply_pca(feat): with torch.autocast(device_type='cuda', dtype=torch.float32): feat = feat.contiguous().to(torch.float32) # feat shape in [bs,channels,16,16] bs, channels, h, w = feat.shape - if ids is not None: - temp_mask = F.interpolate(all_mask.unsqueeze(1).float(), size=(h, w), - mode='nearest').repeat(1, - channels, - 1, - 1).bool().to( - feat.device).permute(0, 2, 3, 1).reshape(-1, channels) - - feat = feat.permute(0, 2, 3, 1).reshape(-1, channels)[temp_mask].reshape(-1, channels).to( - 'cuda') - X = feat + if feat.ndim == 4: + X = feat.permute(0, 2, 3, 1).reshape(-1, channels).to('cuda') else: - if feat.ndim == 4: - X = feat.permute(0, 2, 3, 1).reshape(-1, channels).to('cuda') - else: - # No mask will be applied - X = feat.permute(0, 2, 1).reshape(-1, channels).to('cuda') - # print(feat.shape) + X = feat.permute(0, 2, 1).reshape(-1, 
channels).to('cuda') + # Computing PCA mean = X.mean(dim=0) tensor_centered = X - mean U, S, V = torch.svd(tensor_centered) n_egv = V.shape[-1] - if n_egv > num_save_basis: + if n_egv > num_save_basis and num_save_basis > 0: V = V[:, :num_save_basis] basis = V.T - # X_pca = torch.mm(tensor_centered, basis.T).contiguous() - # score = X_pca.view(bs, h, w, -1).permute(0, 3, 1, 2) - - # if not mean.shape[-1] == basis.shape[-1]: - # print(mean.shape, basis.shape, X.shape, V.shape, score.shape, tensor_centered.shape, ) - assert mean.shape[-1] == basis.shape[-1] return { @@ -2321,63 +1423,13 @@ def process_feat_dict(feat_dict): for name in feat_dict.keys(): feat_dict[name] = torch.cat(feat_dict[name], dim=0) feat_dict[name] = apply_pca(feat_dict[name]) - # print(feat_dict[name].shape) # Only process for the first num_save_steps if i < num_save_steps: - # process_feat_dict(attn_hidden_state_dict) - process_feat_dict(attn_query_dict) process_feat_dict(attn_key_dict) - # process_feat_dict(attn_value_dict) - - # conv_hidden_state_list = torch.cat(conv_hidden_state_list, dim=0) - # conv_hidden_state_info = apply_pca(conv_hidden_state_list) self.pca_info[i] = { - # 'conv': conv_hidden_state_info, - # 'attn_hidden_state': attn_hidden_state_dict, 'attn_key': attn_key_dict, - #'attn_query': attn_query_dict, - # 'attn_value': attn_value_dict, } - else: - break - print(self.pca_info.keys()) - - if not output_type == "latent": - # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast - - if needs_upcasting: - self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - - - img_list = [] - for latent_id, latents in enumerate(latent_list): - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - img_list.append(image) - image = torch.cat(img_list, dim=0) - - # cast back to fp16 if needed - if needs_upcasting: - self.vae.to(dtype=torch.float16) - - else: - image = latents - - if not output_type == "latent": - # apply watermark if available - if self.watermark is not None: - image = self.watermark.apply_watermark(image) - - image = self.image_processor.postprocess(image, output_type=output_type) - - # # Offload all models - # self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return StableDiffusionPipelineOutput(images=image) + return self.pca_info \ No newline at end of file diff --git a/pca_viz.py b/pca_viz.py deleted file mode 100644 index b296db0..0000000 --- a/pca_viz.py +++ /dev/null @@ -1,140 +0,0 @@ -import argparse -import os -import time -import yaml - -import numpy as np -import torch -from omegaconf import OmegaConf -from PIL import Image - -from libs.model import make_pipeline -from libs.model.module.scheduler import CustomDDIMScheduler -from viz.processor import make_processor -import torchvision.transforms.functional as F - - -def concat_images_and_tensors(images, tensors): - # validate inputs - if not (isinstance(images, list) and all(isinstance(img, Image.Image) for img in images)): - raise TypeError("images must be a list of PIL.Image.Image objects") - if not (isinstance(tensors, torch.Tensor) and tensors.dim() == 4): - raise TypeError("tensors must be a 4-dimensional torch.Tensor") - if len(images) != tensors.size(0): - raise ValueError("The length of images and the first dimension of tensors must be the same") - - # normalize and resize - tensors = (tensors - tensors.min()) / (tensors.max() - 
tensors.min()) - tensors = F.resize(tensors, (512, 512), interpolation=Image.NEAREST) - images = [img.resize((512, 512), resample=Image.BILINEAR) for img in images] - - # combine all images and visualizations - image_row = Image.new("RGB", (512 * len(images), 512)) - for i, img in enumerate(images): - image_row.paste(img, (512 * i, 0)) - - tensor_row = Image.new("RGB", (512 * len(images), 512)) - for i, tensor in enumerate(tensors): - tensor_row.paste(F.to_pil_image(tensor), (512 * i, 0)) - - final_image = Image.new("RGB", (512 * len(images), 512 * 2)) - final_image.paste(image_row, (0, 0)) - final_image.paste(tensor_row, (0, 512)) - - return final_image - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default="1.5", help="Diffusion model name") - parser.add_argument("--pca-path", type=str, help="path to semantic bases from PCA") - parser.add_argument("--img-path", type=str, help="Image path") - parser.add_argument("--output-dir", type=str, help="Output directory") - parser.add_argument("--img-type", type=str, default="rgb", help="Image type") - parser.add_argument("--inv-prompt", type=str, help="Text prompt for inversion") - parser.add_argument("--gen-prompt", type=str, help="Text prompt for generation") - parser.add_argument("--object", type=str, help="Object type") - args = parser.parse_args() - - # currently only support 1.5 and 2.1 base - if args.model == "1.5": - model_name = "sd-legacy/stable-diffusion-v1-5" - elif args.model == "2.1_base": - model_name = "stabilityai/stable-diffusion-2-1-base" - else: - raise ValueError(f"Model {args.model} currently not supported.") - - # load configs - config = yaml.load(open("config/base.yaml", "r"), Loader=yaml.FullLoader) - assert os.path.exists(args.pca_path) - config["sd_config"]["pca_paths"] = [args.pca_path] - config["data"]["inversion"] = { - "target_folder": "dataset/latent", - "num_inference_steps": 999, - "method": "DDIM", - "fixed_size": [512, 512], - "prompt": args.inv_prompt, - "select_objects": args.object, - "policy": "share", - "sd_model": f"{args.model}_naive", - } - config = OmegaConf.create(config) - - # load pipeline - pipeline = make_pipeline( - "SDPipeline", - model_name, - safetensors=False, - safety_checker=None, - torch_dtype=torch.float16 - ).to("cuda") - pipeline.enable_xformers_memory_efficient_attention() - pipeline.scheduler = CustomDDIMScheduler.from_pretrained( - model_name, - subfolder="scheduler", - ) - - # load image - if not args.img_type != "rgb": - processor = make_processor(args.img_type) - else: - processor = lambda x: Image.open(x).convert("RGB") - img_name = ".".join(os.path.basename(args.img_path).split(".")[:-1]) - img = processor(args.img_path) - if args.img_type in ("scribble", "canny"): - img = Image.fromarray(255 - np.array(img)) - - # run inversion to generate features - start_time = time.time() - data_samples_pose = pipeline.invert(img=img, inversion_config=config.data.inversion) - print(f"Time elapsed: {(time.time() - start_time):.2f} seconds") - - # project onto semantic bases from PCA - data_samples = { - "examplar": [data_samples_pose], - "appearance": None, - } - g = torch.Generator() - g.manual_seed(2094) - pca_dict = pipeline.compute_score( - prompt=args.gen_prompt, - num_inference_steps=50, - generator=g, - config=config, - data_samples=data_samples, - ) - - # save visualization - image_list = [data_samples_pose["pil_img"]] - root_dir = os.path.join(args.output_dir, f"{img_name}_{args.img_type}") - for key, value in 
pca_dict.items(): - step = key - for feat_name in value.keys(): - for block_name in value[feat_name].keys(): - folder_name = os.path.join(root_dir, feat_name, block_name) - if not os.path.exists(folder_name): - os.makedirs(folder_name,exist_ok=True) - score = value[feat_name][block_name]["score"] - final_img = concat_images_and_tensors(image_list, score) - final_img.save(os.path.join(folder_name, f"{str(step)}.png")) - - print("Success") diff --git a/sample_semantic_bases.py b/sample_semantic_bases.py index 6519537..1212ad6 100755 --- a/sample_semantic_bases.py +++ b/sample_semantic_bases.py @@ -2,6 +2,8 @@ import os import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True import yaml from omegaconf import OmegaConf @@ -23,12 +25,16 @@ def main(args): config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader) config = OmegaConf.create(config) - pipeline_name = "SDPipeline" + if 'XL' in args.sd_version: + pipeline_name = "SDXLPipeline" + else: + pipeline_name = "SDPipeline" pipeline = make_pipeline(pipeline_name, model_path, torch_dtype=torch.float16 ).to('cuda') pipeline.enable_xformers_memory_efficient_attention() + pipeline.enable_sequential_cpu_offload() pipeline.scheduler = CustomDDIMScheduler.from_pretrained(model_path, subfolder="scheduler") g = torch.Generator() g.manual_seed(args.seed) diff --git a/scripts/example_sample_car_xl_1.0.sh b/scripts/example_sample_car_xl_1.0.sh index 42549c9..78331cd 100644 --- a/scripts/example_sample_car_xl_1.0.sh +++ b/scripts/example_sample_car_xl_1.0.sh @@ -8,7 +8,9 @@ SEED=28988 NUM_STEPS=200 OUTPUT_CLASS="car" -python sample_semantic_basis.py --prompt "${PROMPT}" \ +python sample_semantic_bases.py \ +--config_path "config/sdxl_base.yaml" \ +--prompt "${PROMPT}" \ --negative_prompt "${NEGATIVE_PROMPT}" \ --sd_version ${SD_VERSION} \ --model_name ${MODEL_NAME} \ @@ -17,7 +19,6 @@ python sample_semantic_basis.py --prompt "${PROMPT}" \ --seed ${SEED} \ --num_steps ${NUM_STEPS} \ --output_class ${OUTPUT_CLASS} \ ---mask_obj ${OUTPUT_CLASS} \ --num_images 5 \ ---num_batch 1 \ +--num_batch 4 \ --log \ \ No newline at end of file diff --git a/viz/__init__.py b/viz/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/viz/processor.py b/viz/processor.py deleted file mode 100644 index 9ab4fcc..0000000 --- a/viz/processor.py +++ /dev/null @@ -1,357 +0,0 @@ -import copy -from abc import ABC, abstractmethod - -import cv2 -import numpy as np -import torch -from diffusers.utils import load_image -from PIL import Image -from transformers import pipeline - -# pip install -U controlnet-aux -from controlnet_aux import OpenposeDetector -from controlnet_aux import HEDdetector -from controlnet_aux import MLSDdetector - - -palette = np.asarray([ - [0, 0, 0], - [120, 120, 120], - [180, 120, 120], - [6, 230, 230], - [80, 50, 50], - [4, 200, 3], - [120, 120, 80], - [140, 140, 140], - [204, 5, 255], - [230, 230, 230], - [4, 250, 7], - [224, 5, 255], - [235, 255, 7], - [150, 5, 61], - [120, 120, 70], - [8, 255, 51], - [255, 6, 82], - [143, 255, 140], - [204, 255, 4], - [255, 51, 7], - [204, 70, 3], - [0, 102, 200], - [61, 230, 250], - [255, 6, 51], - [11, 102, 255], - [255, 7, 71], - [255, 9, 224], - [9, 7, 230], - [220, 220, 220], - [255, 9, 92], - [112, 9, 255], - [8, 255, 214], - [7, 255, 224], - [255, 184, 6], - [10, 255, 71], - [255, 41, 10], - [7, 255, 255], - [224, 255, 8], - [102, 8, 255], - [255, 61, 6], - [255, 194, 7], - [255, 122, 8], - [0, 255, 20], - [255, 8, 41], - [255, 5, 153], - [6, 
51, 255], - [235, 12, 255], - [160, 150, 20], - [0, 163, 255], - [140, 140, 140], - [250, 10, 15], - [20, 255, 0], - [31, 255, 0], - [255, 31, 0], - [255, 224, 0], - [153, 255, 0], - [0, 0, 255], - [255, 71, 0], - [0, 235, 255], - [0, 173, 255], - [31, 0, 255], - [11, 200, 200], - [255, 82, 0], - [0, 255, 245], - [0, 61, 255], - [0, 255, 112], - [0, 255, 133], - [255, 0, 0], - [255, 163, 0], - [255, 102, 0], - [194, 255, 0], - [0, 143, 255], - [51, 255, 0], - [0, 82, 255], - [0, 255, 41], - [0, 255, 173], - [10, 0, 255], - [173, 255, 0], - [0, 255, 153], - [255, 92, 0], - [255, 0, 255], - [255, 0, 245], - [255, 0, 102], - [255, 173, 0], - [255, 0, 20], - [255, 184, 184], - [0, 31, 255], - [0, 255, 61], - [0, 71, 255], - [255, 0, 204], - [0, 255, 194], - [0, 255, 82], - [0, 10, 255], - [0, 112, 255], - [51, 0, 255], - [0, 194, 255], - [0, 122, 255], - [0, 255, 163], - [255, 153, 0], - [0, 255, 10], - [255, 112, 0], - [143, 255, 0], - [82, 0, 255], - [163, 255, 0], - [255, 235, 0], - [8, 184, 170], - [133, 0, 255], - [0, 255, 92], - [184, 0, 255], - [255, 0, 31], - [0, 184, 255], - [0, 214, 255], - [255, 0, 112], - [92, 255, 0], - [0, 224, 255], - [112, 224, 255], - [70, 184, 160], - [163, 0, 255], - [153, 0, 255], - [71, 255, 0], - [255, 0, 163], - [255, 204, 0], - [255, 0, 143], - [0, 255, 235], - [133, 255, 0], - [255, 0, 235], - [245, 0, 255], - [255, 0, 122], - [255, 245, 0], - [10, 190, 212], - [214, 255, 0], - [0, 204, 255], - [20, 0, 255], - [255, 255, 0], - [0, 153, 255], - [0, 41, 255], - [0, 255, 204], - [41, 0, 255], - [41, 255, 0], - [173, 0, 255], - [0, 245, 255], - [71, 0, 255], - [122, 0, 255], - [0, 255, 184], - [0, 92, 255], - [184, 255, 0], - [0, 133, 255], - [255, 214, 0], - [25, 194, 194], - [102, 255, 0], - [92, 0, 255], -]) - - -processor = {} -def register_processor(name): - def decorator(fn): - processor[name] = fn - return fn - return decorator - - -def make_processor(name): - return processor[name]() - - -class Processor(ABC): - @abstractmethod - def __call__(self, img_path): - pass - @abstractmethod - def get_controlnet_id(self): - pass - - -@register_processor("openpose") -class Openposedetector(Processor): - - def get_controlnet_id(self): - return "lllyasviel/sd-controlnet-openpose" - - def __call__(self, img_path): - openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet") - image = load_image(img_path) - image = openpose(image) - condition_image = image - return condition_image - - -@register_processor("depth") -class DepthMapDetector(Processor): - - def get_controlnet_id(self): - return "lllyasviel/sd-controlnet-depth" - - def __call__(self, img_path): - depth_estimator = pipeline('depth-estimation') - image = load_image(img_path) - image = depth_estimator(image)['depth'] - image = np.array(image) - image = image[:, :, None] - image = np.concatenate([image, image, image], axis=2) - image = Image.fromarray(image) - - condition_image = copy.deepcopy(image) - return condition_image - - -@register_processor("hed") -class HEDDetector(Processor): - - def get_controlnet_id(self): - return "lllyasviel/sd-controlnet-hed" - - def __call__(self, img_path): - hed = HEDdetector.from_pretrained("lllyasviel/ControlNet") - image = load_image(img_path) - image = hed(image) - condition_image = image - return condition_image - - -@register_processor("mlsd") -class MLSDDetector(Processor): - - def get_controlnet_id(self): - return "lllyasviel/sd-controlnet-mlsd" - - def __call__(self, img_path): - mlsd = MLSDdetector.from_pretrained("lllyasviel/ControlNet") - image = 
load_image(img_path) - image = mlsd(image) - condition_image = image - return condition_image - - -@register_processor("canny") -class CannyEdgeDetector(Processor): - - def get_controlnet_id(self): - return "lllyasviel/sd-controlnet-canny" - - def __call__(self,img_path): - image = load_image(img_path) - image = np.array(image) - - low_threshold = 100 - high_threshold = 200 - - image = cv2.Canny(image, low_threshold, high_threshold) - image = image[:, :, None] - image = np.concatenate([image, image, image], axis=2) - image = Image.fromarray(image) - - condition_image = copy.deepcopy(image) - return condition_image - - -@register_processor("normal") -class NormalMapDetector(Processor): - - def get_controlnet_id(self): - return "fusing/stable-diffusion-v1-5-controlnet-normal" - - def __call__(self, img_path): - image = load_image(img_path) - depth_estimator = pipeline("depth-estimation", model="Intel/dpt-hybrid-midas") - - image = depth_estimator(image)['predicted_depth'][0] - - image = image.numpy() - - image_depth = image.copy() - image_depth -= np.min(image_depth) - image_depth /= np.max(image_depth) - - bg_threhold = 0.4 - - x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3) - x[image_depth < bg_threhold] = 0 - - y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3) - y[image_depth < bg_threhold] = 0 - - z = np.ones_like(x) * np.pi * 2.0 - - image = np.stack([x, y, z], axis=2) - image /= np.sum(image ** 2.0, axis=2, keepdims=True) ** 0.5 - image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8) - image = Image.fromarray(image) - condition_image = copy.deepcopy(image) - return condition_image - - -@register_processor("scribble") -class ScribbleDetector(Processor): - - def get_controlnet_id(self): - return "lllyasviel/sd-controlnet-scribble" - - def __call__(self, img_path): - hed = HEDdetector.from_pretrained('lllyasviel/ControlNet') - image = load_image(img_path) - image = hed(image, scribble=True) - fliped_image = Image.fromarray(255 - np.array(image)) - condition_image = copy.deepcopy(image) - return condition_image - - -@register_processor("seg") -class SegMapDetector(Processor): - - def get_controlnet_id(self): - return "lllyasviel/sd-controlnet-seg" - - def __call__(self, img_path): - from transformers import AutoImageProcessor, UperNetForSemanticSegmentation - image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small") - image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small") - - image = load_image(img_path) - pixel_values = image_processor(image, return_tensors="pt").pixel_values - with torch.no_grad(): - outputs = image_segmentor(pixel_values) - - seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0] - - color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3 - - for label, color in enumerate(palette): - color_seg[seg == label, :] = color - - for label, color in enumerate(palette): - color_seg[seg == label, :] = color - - color_seg = color_seg.astype(np.uint8) - - image = Image.fromarray(color_seg) - condition_image = copy.deepcopy(image) - return condition_image
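For reference, the Stage-1 computation that the new `sample_semantic_bases` hunks perform on the gathered self-attention key features reduces to a centered SVD per UNet block and timestep, keeping the leading `num_save_basis` right-singular vectors as semantic bases. The snippet below is a minimal, standalone sketch of that `apply_pca` step, not the pipeline code itself; the returned field names are illustrative, since the hunk truncates the actual return dict.

```python
import torch

def compute_semantic_bases(feat: torch.Tensor, num_save_basis: int = 64) -> dict:
    """Sketch of the PCA step: flatten spatial features, center them,
    and keep the leading right-singular vectors as semantic bases."""
    # feat: [batch, channels, h, w] self-attention keys from one UNet block
    bs, channels, h, w = feat.shape
    X = feat.permute(0, 2, 3, 1).reshape(-1, channels).to(torch.float32)

    mean = X.mean(dim=0)              # per-channel mean
    centered = X - mean               # center before SVD, as in the patch
    U, S, V = torch.svd(centered)     # V: [channels, min(N, channels)]

    if V.shape[-1] > num_save_basis and num_save_basis > 0:
        V = V[:, :num_save_basis]     # keep the leading components only
    basis = V.T                       # [num_save_basis, channels]

    assert mean.shape[-1] == basis.shape[-1]
    return {"mean": mean.cpu(), "basis": basis.cpu()}

# toy usage: 2 seed images, 320 channels, 16x16 spatial resolution
bases = compute_semantic_bases(torch.randn(2, 320, 16, 16), num_save_basis=64)
print(bases["basis"].shape)  # torch.Size([64, 320])
```

In the pipeline this runs under `torch.autocast` on CUDA and is repeated for each saved step up to `num_save_steps`, with the per-step results collected into `self.pca_info`.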
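Several hunks above add a `text_encoder_projection_dim` argument to each `_get_add_time_ids` call. In recent diffusers releases that value is used to verify that the Fourier-embedded micro-conditioning scalars plus the pooled text embedding match the width of the UNet's `add_embedding` input. The helper below is a hedged sketch of that bookkeeping outside the pipeline, assuming a recent diffusers SDXL UNet; the function name and signature are illustrative.

```python
import torch

def build_add_time_ids(unet, pooled_prompt_embeds, original_size, crops_coords_top_left,
                       target_size, text_encoder_2=None, dtype=torch.float16):
    # Resolve the pooled-embedding width exactly as in the patched hunks.
    if text_encoder_2 is None:
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = text_encoder_2.config.projection_dim

    # SDXL micro-conditioning: (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
    add_time_ids = list(original_size + crops_coords_top_left + target_size)

    # Each scalar is embedded to addition_time_embed_dim (256 for SDXL) and concatenated
    # with the pooled text embedding (1280), giving 6 * 256 + 1280 = 2816 for SDXL base.
    passed = unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
    expected = unet.add_embedding.linear_1.in_features
    assert passed == expected, f"add_embedding expects {expected}, got {passed}"

    return torch.tensor([add_time_ids], dtype=dtype)
```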
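The denoising loops in both the inversion and sampling hunks repeat the same classifier-free guidance pattern, including the optional rescaling from Sec. 3.4 of https://arxiv.org/pdf/2305.08891.pdf. A condensed sketch of that behaviour is shown below, assuming `noise_pred` was produced from a doubled `[uncond | text]` batch; the pipeline itself calls diffusers' `rescale_noise_cfg` rather than this helper.

```python
import torch

def cfg_step(noise_pred: torch.Tensor, guidance_scale: float, guidance_rescale: float = 0.0):
    """Sketch of the guidance logic repeated in the denoising loops above."""
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    if guidance_rescale > 0.0:
        # Pull the guided prediction's std back toward the text branch to
        # counter over-exposure with zero-terminal-SNR schedules.
        dims = list(range(1, noise_pred_text.ndim))
        std_text = noise_pred_text.std(dim=dims, keepdim=True)
        std_cfg = noise_cfg.std(dim=dims, keepdim=True)
        rescaled = noise_cfg * (std_text / std_cfg)
        noise_cfg = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg

# toy usage with a doubled batch of 4-channel latents
out = cfg_step(torch.randn(2, 4, 64, 64), guidance_scale=7.5, guidance_rescale=0.7)
```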