Make stable diffusion pipelines compatible with compel (#581)
* poc

* clean

* clean comments

* add tests

* Update optimum/neuron/modeling_diffusion.py

Co-authored-by: Wenchen Li <[email protected]>

* Update optimum/neuron/modeling_diffusion.py

Co-authored-by: Wenchen Li <[email protected]>

* fix and apply comments

---------

Co-authored-by: Wenchen Li <[email protected]>
JingyaHuang and neo authored Apr 30, 2024
1 parent 43796c0 commit 5643f94
Showing 7 changed files with 111 additions and 9 deletions.
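For orientation before the per-file diff: the end result of this commit is that prompt embeddings produced by compel can be passed straight into the Neuron Stable Diffusion pipelines. A minimal usage sketch distilled from the new tests at the bottom of this diff; the checkpoint id and static shapes below are placeholders, not values prescribed by the commit:

from compel import Compel

from optimum.neuron import NeuronStableDiffusionPipeline

# Placeholder checkpoint and static shapes, for illustration only.
pipe = NeuronStableDiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch",
    export=True,
    output_hidden_states=True,  # new option: trace the text encoder with hidden states among its outputs
    batch_size=1,
    height=64,
    width=64,
    num_images_per_prompt=1,
)

compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
prompt_embeds = compel_proc("a red cat playing with a ball++")  # "++" upweights "ball"
image = pipe(prompt_embeds=prompt_embeds, num_inference_steps=2).images[0]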
9 changes: 5 additions & 4 deletions optimum/exporters/neuron/__main__.py
@@ -264,17 +264,16 @@ def get_submodels_and_neuron_configs(

if is_stable_diffusion:
# TODO: Enable optional outputs for Stable Diffusion
if output_attentions or output_hidden_states:
raise ValueError(
f"`output_attentions` and `output_hidden_states` are not supported by the {task} task yet."
)
if output_attentions:
raise ValueError(f"`output_attentions` is not supported by the {task} task yet.")
models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_stable_diffusion(
model=model,
input_shapes=input_shapes,
task=task,
output=output,
dynamic_batch_size=dynamic_batch_size,
submodels=submodels,
output_hidden_states=output_hidden_states,
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
@@ -334,6 +333,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
output: Path,
dynamic_batch_size: bool = False,
submodels: Optional[Dict[str, Union[Path, str]]] = None,
output_hidden_states: bool = False,
lora_model_ids: Optional[Union[str, List[str]]] = None,
lora_weight_names: Optional[Union[str, List[str]]] = None,
lora_adapter_names: Optional[Union[str, List[str]]] = None,
@@ -368,6 +368,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
vae_encoder_input_shapes=input_shapes["vae_encoder"],
vae_decoder_input_shapes=input_shapes["vae_decoder"],
dynamic_batch_size=dynamic_batch_size,
output_hidden_states=output_hidden_states,
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
8 changes: 6 additions & 2 deletions optimum/exporters/neuron/base.py
@@ -336,6 +336,7 @@ def patch_model_for_export(
Checks whether the input order of the model's forward pass corresponds to the generated dummy inputs, to ensure the dummy-inputs tuple used for
tracing is in the correct order.
"""
output_hidden_states = self.output_hidden_states

class ModelWrapper(torch.nn.Module):
def __init__(self, model: "PreTrainedModel", input_names: List[str]):
@@ -355,10 +356,13 @@ def forward(self, *input):
if forward_with_tuple is True:
outputs = self.model(*ordered_inputs.values())
else:
if output_hidden_states:
ordered_inputs["output_hidden_states"] = True
outputs = self.model(**ordered_inputs)

if isinstance(outputs, dict) and eligible_outputs is not None:
outputs = {name: outputs[name] for name in outputs.keys() & eligible_outputs}
if isinstance(outputs, dict):
if eligible_outputs is not None:
outputs = {name: outputs[name] for name in outputs.keys() & eligible_outputs}

if isinstance(outputs, tuple) and eligible_outputs is not None:
if not all(isinstance(x, int) for x in eligible_outputs):
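The wrapper change above is what makes the traced call actually request hidden states from the underlying model when `output_hidden_states` is enabled. A reduced, self-contained sketch of that pattern, using a public CLIP checkpoint purely as a stand-in for the pipeline's text encoder:

import torch
from transformers import CLIPTextModel

class HiddenStatesWrapper(torch.nn.Module):
    # Toy stand-in for the exporter's ModelWrapper: it pins output_hidden_states=True
    # so the traced graph carries the hidden states of every layer.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        outputs = self.model(input_ids=input_ids, output_hidden_states=True, return_dict=True)
        # Tracing needs a flat tuple of tensors rather than a ModelOutput.
        return (outputs.last_hidden_state, outputs.pooler_output, *outputs.hidden_states)

encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").eval()
dummy_input_ids = torch.ones((1, 77), dtype=torch.long)
traced = torch.jit.trace(HiddenStatesWrapper(encoder), (dummy_input_ids,))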
2 changes: 1 addition & 1 deletion optimum/exporters/neuron/model_configs.py
@@ -244,7 +244,7 @@ def outputs(self) -> List[str]:

@register_in_tasks_manager("clip-text-with-projection", *["feature-extraction"], library_name="diffusers")
class CLIPTextWithProjectionNeuronConfig(TextEncoderNeuronConfig):
MODEL_TYPE = "clip-text-model"
MODEL_TYPE = "clip-text-with-projection"
ATOL_FOR_VALIDATION = 1e-3

NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
8 changes: 7 additions & 1 deletion optimum/exporters/neuron/utils.py
@@ -121,6 +121,7 @@ def get_stable_diffusion_models_for_export(
vae_encoder_input_shapes: Dict[str, int],
vae_decoder_input_shapes: Dict[str, int],
dynamic_batch_size: Optional[bool] = False,
output_hidden_states: bool = False,
lora_model_ids: Optional[List[str]] = None,
lora_weight_names: Optional[List[str]] = None,
lora_adapter_names: Optional[List[str]] = None,
@@ -147,6 +148,8 @@ def get_stable_diffusion_models_for_export(
Static shapes used for compiling vae decoder.
dynamic_batch_size (`bool`, defaults to `False`):
Whether the Neuron compiled model supports dynamic batch size.
output_hidden_states (`bool`, defaults to `False`):
Whether or not the traced text encoders should return the hidden states of all layers.
lora_model_ids (`Optional[List[str]]`, defaults to `None`):
List of model ids (eg. `ostris/super-cereal-sdxl-lora`) of pretrained lora models hosted on the Hub or paths to local directories containing the lora weights.
lora_weight_names (`Optional[List[str]]`, defaults to `None`):
@@ -183,6 +186,7 @@ def get_stable_diffusion_models_for_export(
text_encoder.config,
task="feature-extraction",
dynamic_batch_size=dynamic_batch_size,
output_hidden_states=output_hidden_states,
**text_encoder_input_shapes,
)
models_for_export[DIFFUSION_MODEL_TEXT_ENCODER_NAME] = (text_encoder, text_encoder_neuron_config)
@@ -200,6 +204,7 @@ def get_stable_diffusion_models_for_export(
text_encoder_2.config,
task="feature-extraction",
dynamic_batch_size=dynamic_batch_size,
output_hidden_states=output_hidden_states,
**text_encoder_input_shapes,
)
models_for_export[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = (text_encoder_2, text_encoder_neuron_config_2)
@@ -306,6 +311,7 @@ def _load_lora_weights_to_pipeline(
def get_submodels_for_export_stable_diffusion(
pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLImg2ImgPipeline"],
task: str,
output_hidden_states: bool = False,
lora_model_ids: Optional[Union[str, List[str]]] = None,
lora_weight_names: Optional[Union[str, List[str]]] = None,
lora_adapter_names: Optional[Union[str, List[str]]] = None,
@@ -332,7 +338,7 @@ def get_submodels_for_export_stable_diffusion(

# Text encoders
if pipeline.text_encoder is not None:
if is_sdxl:
if is_sdxl or output_hidden_states:
pipeline.text_encoder.config.output_hidden_states = True
models_for_export.append((DIFFUSION_MODEL_TEXT_ENCODER_NAME, copy.deepcopy(pipeline.text_encoder)))

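Setting `pipeline.text_encoder.config.output_hidden_states = True` before export works because the transformers CLIP text encoder falls back to its config when the forward call does not pass the flag explicitly. A quick, Neuron-independent illustration of that fallback (the checkpoint is only an example):

from transformers import AutoTokenizer, CLIPTextModel

encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

encoder.config.output_hidden_states = True
inputs = tokenizer("a photo of a cat", return_tensors="pt")
outputs = encoder(**inputs)

# hidden_states now holds the embedding output plus one tensor per transformer layer.
print(len(outputs.hidden_states))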
30 changes: 29 additions & 1 deletion optimum/neuron/modeling_diffusion.py
@@ -27,6 +27,7 @@
import torch
from huggingface_hub import snapshot_download
from transformers import CLIPFeatureExtractor, CLIPTokenizer, PretrainedConfig
from transformers.modeling_outputs import ModelOutput

from ..exporters.neuron import (
load_models_and_neuron_configs,
@@ -585,6 +586,7 @@ def _export(
auto_cast: Optional[str] = "matmul",
auto_cast_type: Optional[str] = "bf16",
dynamic_batch_size: bool = False,
output_hidden_states: bool = False,
data_parallel_mode: Optional[str] = None,
lora_model_ids: Optional[Union[str, List[str]]] = None,
lora_weight_names: Optional[Union[str, List[str]]] = None,
@@ -646,6 +648,8 @@ def _export(
dynamic_batch_size (`bool`, defaults to `False`):
Whether to enable dynamic batch size for neuron compiled model. If this option is enabled, the input batch size can be a multiple of the
batch size during the compilation, but it comes with a potential tradeoff in terms of latency.
output_hidden_states (`bool`, defaults to `False`):
Whether or not the traced text encoders should return the hidden states of all layers.
data_parallel_mode (`Optional[str]`, defaults to `None`):
Mode to decide what components to load into both NeuronCores of a Neuron device. Can be "none"(no data parallel), "unet"(only
load unet into both cores of each device), "all"(load the whole pipeline into both cores).
@@ -708,6 +712,7 @@ def _export(
local_files_only=local_files_only,
use_auth_token=use_auth_token,
submodels=submodels,
output_hidden_states=output_hidden_states,
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
@@ -738,6 +743,7 @@ def _export(
optlevel=optlevel,
model_type=getattr(neuron_config, "MODEL_TYPE", None),
task=getattr(neuron_config, "task", None),
output_hidden_states=output_hidden_states,
)
compilation_configs[name] = compilation_config

@@ -780,6 +786,7 @@ def _export(
use_auth_token=use_auth_token,
do_validation=False,
submodels={"unet": unet_id},
output_hidden_states=output_hidden_states,
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
@@ -842,9 +849,30 @@ def __init__(
):
super().__init__(model, parent_model, config, neuron_config, DIFFUSION_MODEL_TEXT_ENCODER_NAME)

def forward(self, input_ids: torch.Tensor):
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
if attention_mask is not None:
assert torch.equal(
torch.ones_like(attention_mask), attention_mask
), "attention_mask is expected to be only all ones."
if output_hidden_states:
assert (
self.config.output_hidden_states or self.config.neuron.get("output_hidden_states")
) == output_hidden_states, "output_hidden_states is expected to be False since the model was compiled without hidden_states as output."

input_ids = input_ids.to(torch.long) # dummy generator uses long int for tracing

inputs = (input_ids,)
outputs = self.model(*inputs)

if return_dict:
outputs = ModelOutput(dict(zip(self.neuron_config.outputs, outputs)))

return outputs


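With the extended forward above, the compiled text encoder tolerates the extra keyword arguments compel passes (attention_mask, output_hidden_states, return_dict) and, with return_dict=True, remaps its flat traced outputs onto the neuron config's output names. A hedged sketch of calling it directly; `pipe` is assumed to be a pipeline exported with output_hidden_states=True as in the new tests further down:

# `pipe` is assumed to be an exported NeuronStableDiffusionPipeline (see the tests below).
token_ids = pipe.tokenizer(
    "a red cat playing with a ball",
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    return_tensors="pt",
).input_ids

# return_dict=True wraps the traced output tensors in a ModelOutput keyed by the
# neuron config's output names, so callers such as compel can read attributes like
# last_hidden_state instead of positional tuple entries.
encoder_outputs = pipe.text_encoder(token_ids, return_dict=True)
prompt_embeds = encoder_outputs.last_hidden_state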
1 change: 1 addition & 0 deletions setup.py
@@ -33,6 +33,7 @@
"safetensors",
"sentence-transformers >= 2.2.0",
"peft",
"compel",
]

QUALITY_REQUIRES = [
62 changes: 62 additions & 0 deletions tests/inference/test_stable_diffusion_pipeline.py
@@ -17,6 +17,7 @@
import unittest

import PIL
from compel import Compel, ReturnedEmbeddingsType
from parameterized import parameterized

from optimum.neuron import (
@@ -165,6 +166,28 @@ def test_export_and_inference_with_fused_lora(self, model_arch):
image = neuron_pipeline(prompts, num_images_per_prompt=num_images_per_prompt).images[0]
self.assertIsInstance(image, PIL.Image.Image)

@parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True)
def test_compatibility_with_compel(self, model_arch):
num_images_per_prompt = 1
input_shapes = copy.deepcopy(self.STATIC_INPUTS_SHAPES)
input_shapes.update({"num_images_per_prompt": num_images_per_prompt})
pipe = self.NEURON_MODEL_CLASS.from_pretrained(
MODEL_NAMES[model_arch],
export=True,
inline_weights_to_neff=True,
output_hidden_states=True,
**input_shapes,
**self.COMPILER_ARGS,
)

prompt = "a red cat playing with a ball++"
compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)

prompt_embeds = compel_proc(prompt)

image = pipe(prompt_embeds=prompt_embeds, num_inference_steps=2).images[0]
self.assertIsInstance(image, PIL.Image.Image)


@is_inferentia_test
@requires_neuronx
@@ -268,3 +291,42 @@ def test_inpaint_export_and_inference(self, model_arch):
prompt = "A deep sea diver floating"
image = neuron_pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
self.assertIsInstance(image, PIL.Image.Image)

@parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True)
def test_compatibility_with_compel(self, model_arch):
num_images_per_prompt = 1
input_shapes = copy.deepcopy(self.STATIC_INPUTS_SHAPES)
input_shapes.update({"num_images_per_prompt": num_images_per_prompt})
pipe = self.NEURON_MODEL_CLASS.from_pretrained(
MODEL_NAMES[model_arch],
export=True,
inline_weights_to_neff=True,
output_hidden_states=True,
**input_shapes,
**self.COMPILER_ARGS,
)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
negative_prompt = "low quality, low resolution"

compel = Compel(
tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
requires_pooled=[False, True],
)
prompt_embeds, pooled = compel(prompt)
neg_prompt_embeds, neg_pooled = compel(negative_prompt)
positive_prompt_embeds, negative_prompt_embeds = compel.pad_conditioning_tensors_to_same_length(
[prompt_embeds, neg_prompt_embeds]
)

image = pipe(
prompt_embeds=positive_prompt_embeds,
pooled_prompt_embeds=pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_pooled_prompt_embeds=neg_pooled,
output_type="pil",
num_inference_steps=1,
).images[0]
self.assertIsInstance(image, PIL.Image.Image)
