diff --git a/helpers/caching/text_embeds.py b/helpers/caching/text_embeds.py index 5fae6cb5..f35e9350 100644 --- a/helpers/caching/text_embeds.py +++ b/helpers/caching/text_embeds.py @@ -599,13 +599,18 @@ def compute_t5_prompt(self, prompt: str): return result, attn_mask - def compute_gemma_prompt(self, prompt: str): + def compute_gemma_prompt(self, prompt: str, is_negative_prompt: bool): prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt( prompt=prompt, do_classifier_free_guidance=False, device=self.accelerator.device, clean_caption=False, max_sequence_length=300, + complex_human_instruction=( + StateTracker.get_args().sana_complex_human_instruction + if not is_negative_prompt + else None + ), ) return prompt_embeds, prompt_attention_mask @@ -616,6 +621,7 @@ def compute_embeddings_for_prompts( return_concat: bool = True, is_validation: bool = False, load_from_cache: bool = True, + is_negative_prompt: bool = False, ): logger.debug("Initialising text embed calculator...") if not self.batch_write_thread.is_alive(): @@ -694,6 +700,7 @@ def compute_embeddings_for_prompts( raw_prompts, return_concat=return_concat, load_from_cache=load_from_cache, + is_negative_prompt=is_negative_prompt, ) else: raise ValueError( @@ -1073,6 +1080,7 @@ def compute_embeddings_for_sana_prompts( prompts: list = None, return_concat: bool = True, load_from_cache: bool = True, + is_negative_prompt: bool = False, ): logger.debug( f"compute_embeddings_for_sana_prompts arguments: prompts={prompts}, return_concat={return_concat}, load_from_cache={load_from_cache}" @@ -1171,7 +1179,7 @@ def compute_embeddings_for_sana_prompts( time.sleep(5) # TODO: Batch this prompt_embeds, attention_mask = self.compute_gemma_prompt( - prompt=prompt, + prompt=prompt, is_negative_prompt=is_negative_prompt ) if "deepfloyd" not in StateTracker.get_args().model_type: # we have to store the attn mask with the embed for pixart. diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 97ea5589..89ff22f0 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1469,7 +1469,24 @@ def get_argument_parser(): " the value given." ), ) - + parser.add_argument( + "--sana_complex_human_instruction", + type=str, + default=[ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + help=( + "When generating embeds for Sana, a complex human instruction will be attached to your prompt by default." + " This is required for the Gemma model to produce meaningful image caption embeds." + ), + ) parser.add_argument( "--allow_tf32", action="store_true", @@ -2541,6 +2558,25 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False): logger.error(f"Could not load skip layers: {e}") raise + if ( + args.sana_complex_human_instruction is not None + and type(args.sana_complex_human_instruction) is str + and args.sana_complex_human_instruction not in ["", "None"] + ): + try: + import json + + args.sana_complex_human_instruction = json.loads( + args.sana_complex_human_instruction + ) + except Exception as e: + logger.error( + f"Could not load complex human instruction ({args.sana_complex_human_instruction}): {e}" + ) + raise + elif args.sana_complex_human_instruction == "None": + args.sana_complex_human_instruction = None + if args.enable_xformers_memory_efficient_attention: if args.attention_mechanism != "xformers": warning_log( diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 8055e452..de8cf7cf 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -258,11 +258,10 @@ def prepare_validation_prompt_list(args, embed_cache): or model_type == "sana" ): # we use the legacy encoder but we return no pooled embeds. - validation_negative_prompt_embeds = ( - embed_cache.compute_embeddings_for_prompts( - [StateTracker.get_args().validation_negative_prompt], - load_from_cache=False, - ) + validation_negative_prompt_embeds = embed_cache.compute_embeddings_for_prompts( + [StateTracker.get_args().validation_negative_prompt], + load_from_cache=False, + is_negative_prompt=True, # sana needs this to disable Complex Human Instruction on negative embed generation ) return ( @@ -1388,6 +1387,11 @@ def validate_prompt( if StateTracker.get_model_family() == "flux": if "negative_prompt" in pipeline_kwargs: del pipeline_kwargs["negative_prompt"] + if self.args.model_family == "sana": + pipeline_kwargs["complex_human_instruction"] = ( + self.args.sana_complex_human_instruction + ) + if ( StateTracker.get_model_family() == "pixart_sigma" or StateTracker.get_model_family() == "smoldit" diff --git a/poetry.lock b/poetry.lock index 7a266c75..9ef17355 100644 --- a/poetry.lock +++ b/poetry.lock @@ -720,7 +720,7 @@ training = ["Jinja2", "accelerate (>=0.31.0)", "datasets", "peft (>=0.6.0)", "pr type = "git" url = "https://github.com/lawrence-cj/diffusers" reference = "Sana" -resolved_reference = "d3312ccec73ff753338792a3e8dc8fc39168ce49" +resolved_reference = "b4af50d67f83a893420496ad1382186df8e91688" [[package]] name = "dill"