Skip to content

Commit

Permalink
Merge pull request #1216 from bghira/feature/sana-complex-human-instruction
Browse files Browse the repository at this point in the history

sana: add complex human instruction to user prompts by default (untested)
  • Loading branch information
bghira authored Dec 17, 2024
2 parents 2d52cda + d3500aa commit fb58cef
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 9 deletions.
12 changes: 10 additions & 2 deletions helpers/caching/text_embeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,13 +599,18 @@ def compute_t5_prompt(self, prompt: str):

return result, attn_mask

def compute_gemma_prompt(self, prompt: str):
def compute_gemma_prompt(self, prompt: str, is_negative_prompt: bool):
prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
prompt=prompt,
do_classifier_free_guidance=False,
device=self.accelerator.device,
clean_caption=False,
max_sequence_length=300,
complex_human_instruction=(
StateTracker.get_args().sana_complex_human_instruction
if not is_negative_prompt
else None
),
)

return prompt_embeds, prompt_attention_mask
Expand All @@ -616,6 +621,7 @@ def compute_embeddings_for_prompts(
return_concat: bool = True,
is_validation: bool = False,
load_from_cache: bool = True,
is_negative_prompt: bool = False,
):
logger.debug("Initialising text embed calculator...")
if not self.batch_write_thread.is_alive():
Expand Down Expand Up @@ -694,6 +700,7 @@ def compute_embeddings_for_prompts(
raw_prompts,
return_concat=return_concat,
load_from_cache=load_from_cache,
is_negative_prompt=is_negative_prompt,
)
else:
raise ValueError(
Expand Down Expand Up @@ -1073,6 +1080,7 @@ def compute_embeddings_for_sana_prompts(
prompts: list = None,
return_concat: bool = True,
load_from_cache: bool = True,
is_negative_prompt: bool = False,
):
logger.debug(
f"compute_embeddings_for_sana_prompts arguments: prompts={prompts}, return_concat={return_concat}, load_from_cache={load_from_cache}"
Expand Down Expand Up @@ -1171,7 +1179,7 @@ def compute_embeddings_for_sana_prompts(
time.sleep(5)
# TODO: Batch this
prompt_embeds, attention_mask = self.compute_gemma_prompt(
prompt=prompt,
prompt=prompt, is_negative_prompt=is_negative_prompt
)
if "deepfloyd" not in StateTracker.get_args().model_type:
# we have to store the attn mask with the embed for pixart.
Expand Down
38 changes: 37 additions & 1 deletion helpers/configuration/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1469,7 +1469,24 @@ def get_argument_parser():
" the value given."
),
)

parser.add_argument(
"--sana_complex_human_instruction",
type=str,
default=[
"Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
"- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
"Here are examples of how to transform or refine prompts:",
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
"Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
"User Prompt: ",
],
help=(
"When generating embeds for Sana, a complex human instruction will be attached to your prompt by default."
" This is required for the Gemma model to produce meaningful image caption embeds."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
Expand Down Expand Up @@ -2541,6 +2558,25 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False):
logger.error(f"Could not load skip layers: {e}")
raise

if (
args.sana_complex_human_instruction is not None
and type(args.sana_complex_human_instruction) is str
and args.sana_complex_human_instruction not in ["", "None"]
):
try:
import json

args.sana_complex_human_instruction = json.loads(
args.sana_complex_human_instruction
)
except Exception as e:
logger.error(
f"Could not load complex human instruction ({args.sana_complex_human_instruction}): {e}"
)
raise
elif args.sana_complex_human_instruction == "None":
args.sana_complex_human_instruction = None

if args.enable_xformers_memory_efficient_attention:
if args.attention_mechanism != "xformers":
warning_log(
Expand Down
14 changes: 9 additions & 5 deletions helpers/training/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,10 @@ def prepare_validation_prompt_list(args, embed_cache):
or model_type == "sana"
):
# we use the legacy encoder but we return no pooled embeds.
validation_negative_prompt_embeds = (
embed_cache.compute_embeddings_for_prompts(
[StateTracker.get_args().validation_negative_prompt],
load_from_cache=False,
)
validation_negative_prompt_embeds = embed_cache.compute_embeddings_for_prompts(
[StateTracker.get_args().validation_negative_prompt],
load_from_cache=False,
is_negative_prompt=True, # sana needs this to disable Complex Human Instruction on negative embed generation
)

return (
Expand Down Expand Up @@ -1388,6 +1387,11 @@ def validate_prompt(
if StateTracker.get_model_family() == "flux":
if "negative_prompt" in pipeline_kwargs:
del pipeline_kwargs["negative_prompt"]
if self.args.model_family == "sana":
pipeline_kwargs["complex_human_instruction"] = (
self.args.sana_complex_human_instruction
)

if (
StateTracker.get_model_family() == "pixart_sigma"
or StateTracker.get_model_family() == "smoldit"
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit fb58cef

Please sign in to comment.