Skip to content

Commit

Permalink
test_style_aligned: switch to CLIP text batch API
Browse files Browse the repository at this point in the history
  • Loading branch information
deltheil authored and rodSiry committed Feb 28, 2024
1 parent 2c91aab commit 3c231b4
Showing 1 changed file with 3 additions and 23 deletions.
26 changes: 3 additions & 23 deletions tests/e2e/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2283,29 +2283,9 @@ def test_style_aligned(
]

# create (context) embeddings from prompts
# TODO: replace this logic with https://github.com/finegrain-ai/refiners/pull/263 when it gets merged
unconds: list[torch.Tensor] = []
conds: list[torch.Tensor] = []
pooled_unconds: list[torch.Tensor] = []
pooled_conds: list[torch.Tensor] = []
for prompt in set_of_prompts:
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(text=prompt)

uncond, cond = clip_text_embedding.chunk(2)
pooled_uncond, pooled_cond = pooled_text_embedding.chunk(2)

unconds.append(uncond)
conds.append(cond)
pooled_unconds.append(pooled_uncond)
pooled_conds.append(pooled_cond)

uncond = torch.cat(unconds, dim=0)
cond = torch.cat(conds, dim=0)
pooled_uncond = torch.cat(pooled_unconds, dim=0)
pooled_cond = torch.cat(pooled_conds, dim=0)

clip_text_embedding = torch.cat((uncond, cond), dim=0)
pooled_text_embedding = torch.cat((pooled_uncond, pooled_cond), dim=0)
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=set_of_prompts, negative_text=[""] * len(set_of_prompts)
)

time_ids = sdxl.default_time_ids.repeat(len(set_of_prompts), 1)

Expand Down

0 comments on commit 3c231b4

Please sign in to comment.