fixed masked padding bug. (#130)
A-Jacobson authored Mar 26, 2024
1 parent 05b8197 commit 5f3c9aa
Showing 2 changed files with 45 additions and 27 deletions.
58 changes: 35 additions & 23 deletions diffusion/models/stable_diffusion.py
@@ -63,7 +63,7 @@ class StableDiffusion(ComposerModel):
             Default: `False`.
         encode_latents_in_fp16 (bool): whether to encode latents in fp16.
             Default: `False`.
-        mask_pad_tokens (bool): whether to mask pad tokens in cross attention.
+        mask_pad_tokens (bool): whether to mask pad tokens in unet cross attention.
             Default: `False`.
         sdxl (bool): Whether or not we're training SDXL. Default: `False`.
     """
@@ -148,14 +148,7 @@ def forward(self, batch):
         attention_mask = batch['attention_mask'] # mask for text encoders
         # text mask for U-Net
         if self.mask_pad_tokens:
-            if len(attention_mask.shape) == 2:
-                encoder_attention_mask = attention_mask
-            elif len(attention_mask.shape) == 3:
-                encoder_attention_mask = attention_mask[:, 0]
-                for i in range(1, attention_mask.shape[1]):
-                    encoder_attention_mask |= attention_mask[:, i]
-            else:
-                raise ValueError(f'attention_mask should have either 2 or 3 dimensions: {attention_mask.shape}')
+            encoder_attention_mask = _create_unet_attention_mask(attention_mask)

         # Use latents if specified and available. When specified, they might not exist during eval
         if self.precomputed_latents and self.image_latents_key in batch and self.text_latents_key in batch:
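The branching removed here now lives in the module-level _create_unet_attention_mask helper added at the bottom of this file. A minimal sketch of its union semantics, using made-up masks (1 = real token, 0 = pad token); it assumes the repo is importable as the diffusion package after this commit:

import torch

from diffusion.models.stable_diffusion import _create_unet_attention_mask

# Hypothetical batch of 2 prompts, 2 text encoders (e.g. SDXL), sequence length 6.
mask = torch.tensor([
    [[1, 1, 1, 0, 0, 0],    # encoder 1, prompt 1
     [1, 1, 1, 1, 0, 0]],   # encoder 2, prompt 1
    [[1, 1, 0, 0, 0, 0],    # encoder 1, prompt 2
     [1, 1, 1, 0, 0, 0]],   # encoder 2, prompt 2
])  # shape [batch, num_encoders, seq_len]

print(_create_unet_attention_mask(mask))
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0, 0]])

# A 2D mask (single text encoder) is passed through unchanged.
print(_create_unet_attention_mask(torch.tensor([[1, 1, 1, 0]])))
# tensor([[1, 1, 1, 0]])

The union is an elementwise OR over the encoder dimension, so a position is kept in the U-Net cross-attention mask if any text encoder treats it as a real token.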
@@ -469,22 +462,17 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, tokenized_pad_mask
             raise NotImplementedError('SDXL not yet supported with precomputed embeddings')

         # duplicate text embeddings for each generation per prompt
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)  # type: ignore
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        if self.mask_pad_tokens and tokenized_pad_mask is not None:
-            if len(tokenized_pad_mask.shape) == 3:
-                attention_mask = tokenized_pad_mask[:, 0]
-                for i in range(1, tokenized_pad_mask.shape[1]):
-                    attention_mask |= tokenized_pad_mask[:, i]
-                tokenized_pad_mask = attention_mask
-            tokenized_pad_mask = tokenized_pad_mask.repeat(1, num_images_per_prompt, 1)
-            tokenized_pad_mask = tokenized_pad_mask.view(bs_embed * num_images_per_prompt, seq_len)  # [B, 77]
+        prompt_embeds = _duplicate_tensor(prompt_embeds, num_images_per_prompt)
+
+        if not self.mask_pad_tokens:
+            tokenized_pad_mask = None
+
+        if tokenized_pad_mask is not None:
+            tokenized_pad_mask = _create_unet_attention_mask(tokenized_pad_mask)
+            tokenized_pad_mask = _duplicate_tensor(tokenized_pad_mask, num_images_per_prompt)

         if self.sdxl and pooled_text_embeddings is not None:
-            pooled_text_embeddings = pooled_text_embeddings.repeat(1, num_images_per_prompt)
-            pooled_text_embeddings = pooled_text_embeddings.view(bs_embed * num_images_per_prompt, -1)
+            pooled_text_embeddings = _duplicate_tensor(pooled_text_embeddings, num_images_per_prompt)
         return prompt_embeds, pooled_text_embeddings, tokenized_pad_mask


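In the reworked generation path above, the pad mask is reduced to 2D first and then duplicated with the same helper used for the embeddings, so the two stay aligned along the batch dimension when num_images_per_prompt > 1. A hypothetical shape walk-through (SDXL-style, two text encoders; shapes are illustrative, not taken from the repo's tests):

import torch

from diffusion.models.stable_diffusion import _create_unet_attention_mask, _duplicate_tensor

prompt_embeds = torch.randn(2, 77, 2048)              # [batch, seq_len, dim]
tokenized_pad_mask = torch.randint(0, 2, (2, 2, 77))  # [batch, num_encoders, seq_len]

prompt_embeds = _duplicate_tensor(prompt_embeds, 3)   # -> [6, 77, 2048]

tokenized_pad_mask = _create_unet_attention_mask(tokenized_pad_mask)  # union -> [2, 77]
tokenized_pad_mask = _duplicate_tensor(tokenized_pad_mask, 3)         # -> [6, 77]

assert tokenized_pad_mask.shape[0] == prompt_embeds.shape[0]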
@@ -502,3 +490,27 @@ def _check_prompt_lenths(prompt, negative_prompt):
 def _check_prompt_given(prompt, tokenized_prompts, prompt_embeds):
     if prompt is None and tokenized_prompts is None and prompt_embeds is None:
         raise ValueError('Must provide one of `prompt`, `tokenized_prompts`, or `prompt_embeds`')
+
+
+def _create_unet_attention_mask(attention_mask):
+    """Takes the union of multiple attention masks if given more than one mask."""
+    if len(attention_mask.shape) == 2:
+        return attention_mask
+    elif len(attention_mask.shape) == 3:
+        encoder_attention_mask = attention_mask[:, 0]
+        for i in range(1, attention_mask.shape[1]):
+            encoder_attention_mask |= attention_mask[:, i]
+        return encoder_attention_mask
+    else:
+        raise ValueError(f'attention_mask should have either 2 or 3 dimensions: {attention_mask.shape}')
+
+
+def _duplicate_tensor(tensor, num_images_per_prompt):
+    """Duplicate tensor for multiple generations from a single prompt."""
+    batch_size, seq_len = tensor.shape[:2]
+    tensor = tensor.repeat(1, num_images_per_prompt, *[
+        1,
+    ] * len(tensor.shape[2:]))
+    return tensor.view(batch_size * num_images_per_prompt, seq_len, *[
+        -1,
+    ] * len(tensor.shape[2:]))
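A quick property check for the new _duplicate_tensor helper (not part of the commit, shapes are made up): for both 2D masks and 3D embeddings, the repeat-then-view should match torch.repeat_interleave along the batch dimension, i.e. every prompt's copies stay adjacent:

import torch

from diffusion.models.stable_diffusion import _duplicate_tensor

prompt_embeds = torch.randn(2, 77, 1024)  # [batch, seq_len, dim]
pad_mask = torch.randint(0, 2, (2, 77))   # [batch, seq_len]

for x in (prompt_embeds, pad_mask):
    out = _duplicate_tensor(x, 3)
    assert torch.equal(out, torch.repeat_interleave(x, 3, dim=0))

print(_duplicate_tensor(prompt_embeds, 3).shape)  # torch.Size([6, 77, 1024])
print(_duplicate_tensor(pad_mask, 3).shape)       # torch.Size([6, 77])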
14 changes: 10 additions & 4 deletions tests/test_model.py
@@ -47,10 +47,15 @@ def test_sd2_generate(guidance_scale, negative_prompt):
     assert output.shape == (1, 3, 8, 8)


-def test_sdxl_forward():
+@pytest.mark.parametrize('mask_pad_tokens', [True, False])
+def test_sdxl_forward(mask_pad_tokens):
     # fp16 vae does not run on cpu
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-    model = stable_diffusion_xl(pretrained=False, fsdp=False, encode_latents_in_fp16=False, use_xformers=False)
+    model = stable_diffusion_xl(pretrained=False,
+                                mask_pad_tokens=mask_pad_tokens,
+                                fsdp=False,
+                                encode_latents_in_fp16=False,
+                                use_xformers=False)
     batch_size = 1
     H = 16
     W = 16
@@ -77,13 +82,14 @@ def test_sdxl_forward():

 @pytest.mark.parametrize('guidance_scale', [0.0, 3.0])
 @pytest.mark.parametrize('negative_prompt', [None, 'so cool'])
-def test_sdxl_generate(guidance_scale, negative_prompt):
+@pytest.mark.parametrize('mask_pad_tokens', [True, False])
+def test_sdxl_generate(guidance_scale, negative_prompt, mask_pad_tokens):
     # fp16 vae does not run on cpu
     model = stable_diffusion_xl(pretrained=False,
                                 fsdp=False,
                                 encode_latents_in_fp16=False,
                                 use_xformers=False,
-                                mask_pad_tokens=True)
+                                mask_pad_tokens=mask_pad_tokens)
     output = model.generate(
         prompt='a cool doge',
         negative_prompt=negative_prompt,
