sana: add complex human instruction to user prompts by default (untested) #1216

Merged · 2 commits · Dec 17, 2024
helpers/caching/text_embeds.py (12 changes: 10 additions & 2 deletions)
@@ -599,13 +599,18 @@ def compute_t5_prompt(self, prompt: str):
 
         return result, attn_mask
 
-    def compute_gemma_prompt(self, prompt: str):
+    def compute_gemma_prompt(self, prompt: str, is_negative_prompt: bool):
         prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
             prompt=prompt,
             do_classifier_free_guidance=False,
             device=self.accelerator.device,
             clean_caption=False,
             max_sequence_length=300,
+            complex_human_instruction=(
+                StateTracker.get_args().sana_complex_human_instruction
+                if not is_negative_prompt
+                else None
+            ),
         )
 
         return prompt_embeds, prompt_attention_mask
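
For context: complex_human_instruction is a list of instruction lines, and diffusers' SanaPipeline.encode_prompt is understood to join those lines and prepend them to the user prompt before Gemma tokenization, so passing None for negative prompts leaves them untouched. A minimal sketch of that behaviour, assuming a plain newline join (illustrative only, not the pipeline's actual code):

# Illustrative sketch only; the real behaviour lives inside diffusers' SanaPipeline.
def apply_complex_human_instruction(prompt: str, chi: list | None) -> str:
    if not chi:
        # negative prompts pass None, so the raw prompt goes to Gemma unchanged
        return prompt
    # instruction lines first, then the raw user prompt appended at the end
    return "\n".join(chi) + prompt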
@@ -616,6 +621,7 @@ def compute_embeddings_for_prompts(
         return_concat: bool = True,
         is_validation: bool = False,
         load_from_cache: bool = True,
+        is_negative_prompt: bool = False,
     ):
         logger.debug("Initialising text embed calculator...")
         if not self.batch_write_thread.is_alive():
@@ -694,6 +700,7 @@ def compute_embeddings_for_prompts(
                 raw_prompts,
                 return_concat=return_concat,
                 load_from_cache=load_from_cache,
+                is_negative_prompt=is_negative_prompt,
             )
         else:
             raise ValueError(
@@ -1073,6 +1080,7 @@ def compute_embeddings_for_sana_prompts(
         prompts: list = None,
         return_concat: bool = True,
         load_from_cache: bool = True,
+        is_negative_prompt: bool = False,
     ):
         logger.debug(
             f"compute_embeddings_for_sana_prompts arguments: prompts={prompts}, return_concat={return_concat}, load_from_cache={load_from_cache}"
@@ -1171,7 +1179,7 @@ def compute_embeddings_for_sana_prompts(
                     time.sleep(5)
                 # TODO: Batch this
                 prompt_embeds, attention_mask = self.compute_gemma_prompt(
-                    prompt=prompt,
+                    prompt=prompt, is_negative_prompt=is_negative_prompt
                 )
                 if "deepfloyd" not in StateTracker.get_args().model_type:
                     # we have to store the attn mask with the embed for pixart.
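
Taken together, both caching entry points now thread is_negative_prompt down to compute_gemma_prompt, so only positive prompts receive the instruction. A hedged usage sketch (the embed_cache instance and prompts are placeholders, not taken from this PR):

# Positive prompt: Gemma sees the complex human instruction plus the caption.
positive = embed_cache.compute_embeddings_for_prompts(
    ["a cat sleeping on a windowsill"], load_from_cache=False
)
# Negative prompt: the instruction is skipped and only the raw text is encoded.
negative = embed_cache.compute_embeddings_for_prompts(
    ["blurry, low quality"], load_from_cache=False, is_negative_prompt=True
)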
helpers/configuration/cmd_args.py (38 changes: 37 additions & 1 deletion)
@@ -1469,7 +1469,24 @@ def get_argument_parser():
             " the value given."
         ),
     )
-
+    parser.add_argument(
+        "--sana_complex_human_instruction",
+        type=str,
+        default=[
+            "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
+            "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
+            "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+            "Here are examples of how to transform or refine prompts:",
+            "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
+            "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
+            "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+            "User Prompt: ",
+        ],
+        help=(
+            "When generating embeds for Sana, a complex human instruction will be attached to your prompt by default."
+            " This is required for the Gemma model to produce meaningful image caption embeds."
+        ),
+    )
     parser.add_argument(
         "--allow_tf32",
         action="store_true",
@@ -2541,6 +2558,25 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False):
             logger.error(f"Could not load skip layers: {e}")
             raise
 
+    if (
+        args.sana_complex_human_instruction is not None
+        and type(args.sana_complex_human_instruction) is str
+        and args.sana_complex_human_instruction not in ["", "None"]
+    ):
+        try:
+            import json
+
+            args.sana_complex_human_instruction = json.loads(
+                args.sana_complex_human_instruction
+            )
+        except Exception as e:
+            logger.error(
+                f"Could not load complex human instruction ({args.sana_complex_human_instruction}): {e}"
+            )
+            raise
+    elif args.sana_complex_human_instruction == "None":
+        args.sana_complex_human_instruction = None
+
     if args.enable_xformers_memory_efficient_attention:
         if args.attention_mechanism != "xformers":
             warning_log(
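
Because the built-in default is already a Python list, the str/json.loads branch above only fires when the flag is supplied on the command line as a JSON-encoded list of strings, and the literal value None disables the instruction entirely. A standalone mirror of that coercion, for illustration only:

import json

def coerce_chi(value):
    # A non-empty string other than "None" is parsed as a JSON list of instruction lines.
    if isinstance(value, str) and value not in ("", "None"):
        return json.loads(value)
    # The literal string "None" switches the complex human instruction off.
    if value == "None":
        return None
    # The built-in default (already a list) passes through untouched.
    return value

# e.g. --sana_complex_human_instruction '["Rewrite the prompt with more visual detail:", "User Prompt: "]'
print(coerce_chi('["Rewrite the prompt with more visual detail:", "User Prompt: "]'))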
helpers/training/validation.py (14 changes: 9 additions & 5 deletions)
@@ -258,11 +258,10 @@ def prepare_validation_prompt_list(args, embed_cache):
         or model_type == "sana"
     ):
         # we use the legacy encoder but we return no pooled embeds.
-        validation_negative_prompt_embeds = (
-            embed_cache.compute_embeddings_for_prompts(
-                [StateTracker.get_args().validation_negative_prompt],
-                load_from_cache=False,
-            )
+        validation_negative_prompt_embeds = embed_cache.compute_embeddings_for_prompts(
+            [StateTracker.get_args().validation_negative_prompt],
+            load_from_cache=False,
+            is_negative_prompt=True,  # sana needs this to disable Complex Human Instruction on negative embed generation
         )
 
     return (
@@ -1388,6 +1387,11 @@ def validate_prompt(
         if StateTracker.get_model_family() == "flux":
             if "negative_prompt" in pipeline_kwargs:
                 del pipeline_kwargs["negative_prompt"]
+        if self.args.model_family == "sana":
+            pipeline_kwargs["complex_human_instruction"] = (
+                self.args.sana_complex_human_instruction
+            )
+
         if (
             StateTracker.get_model_family() == "pixart_sigma"
             or StateTracker.get_model_family() == "smoldit"
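
At validation time the same list (or None, when disabled) rides along in pipeline_kwargs to the Sana pipeline call. A rough sketch of the downstream call with placeholder names, not code from this PR:

# Illustrative only: the kwarg eventually lands on the SanaPipeline call.
images = pipeline(
    prompt=validation_prompt,
    complex_human_instruction=args.sana_complex_human_instruction,
    **other_pipeline_kwargs,  # negative embeds, guidance scale, etc.
).images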
poetry.lock (2 changes: 1 addition & 1 deletion)

Some generated files are not rendered by default.