diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py
index 2b197d815..2459b4254 100644
--- a/optimum/exporters/neuron/__main__.py
+++ b/optimum/exporters/neuron/__main__.py
@@ -264,10 +264,8 @@ def get_submodels_and_neuron_configs(
     if is_stable_diffusion:
         # TODO: Enable optional outputs for Stable Diffusion
-        if output_attentions or output_hidden_states:
-            raise ValueError(
-                f"`output_attentions` and `output_hidden_states` are not supported by the {task} task yet."
-            )
+        if output_attentions:
+            raise ValueError(f"`output_attentions` is not supported by the {task} task yet.")
         models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_stable_diffusion(
             model=model,
             input_shapes=input_shapes,
@@ -275,6 +273,7 @@ def get_submodels_and_neuron_configs(
             output=output,
             dynamic_batch_size=dynamic_batch_size,
             submodels=submodels,
+            output_hidden_states=output_hidden_states,
             lora_model_ids=lora_model_ids,
             lora_weight_names=lora_weight_names,
             lora_adapter_names=lora_adapter_names,
@@ -334,6 +333,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
     output: Path,
     dynamic_batch_size: bool = False,
     submodels: Optional[Dict[str, Union[Path, str]]] = None,
+    output_hidden_states: bool = False,
     lora_model_ids: Optional[Union[str, List[str]]] = None,
     lora_weight_names: Optional[Union[str, List[str]]] = None,
     lora_adapter_names: Optional[Union[str, List[str]]] = None,
@@ -368,6 +368,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
         vae_encoder_input_shapes=input_shapes["vae_encoder"],
         vae_decoder_input_shapes=input_shapes["vae_decoder"],
         dynamic_batch_size=dynamic_batch_size,
+        output_hidden_states=output_hidden_states,
         lora_model_ids=lora_model_ids,
         lora_weight_names=lora_weight_names,
         lora_adapter_names=lora_adapter_names,
diff --git a/optimum/exporters/neuron/base.py b/optimum/exporters/neuron/base.py
index b3b5a7783..a9d8c1dba 100644
--- a/optimum/exporters/neuron/base.py
+++ b/optimum/exporters/neuron/base.py
@@ -336,6 +336,7 @@ def patch_model_for_export(
         Checks if inputs order of the model's forward pass correspond to the generated dummy inputs to ensure the
         dummy inputs tuple used for tracing are under the correct order.
         """
+        output_hidden_states = self.output_hidden_states

         class ModelWrapper(torch.nn.Module):
             def __init__(self, model: "PreTrainedModel", input_names: List[str]):
@@ -355,10 +356,13 @@ def forward(self, *input):
                 if forward_with_tuple is True:
                     outputs = self.model(*ordered_inputs.values())
                 else:
+                    if output_hidden_states:
+                        ordered_inputs["output_hidden_states"] = True
                     outputs = self.model(**ordered_inputs)

-                if isinstance(outputs, dict) and eligible_outputs is not None:
-                    outputs = {name: outputs[name] for name in outputs.keys() & eligible_outputs}
+                if isinstance(outputs, dict):
+                    if eligible_outputs is not None:
+                        outputs = {name: outputs[name] for name in outputs.keys() & eligible_outputs}

                 if isinstance(outputs, tuple) and eligible_outputs is not None:
                     if not all(isinstance(x, int) for x in eligible_outputs):
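For reviewers: the `ModelWrapper` change above is what lets hidden states survive tracing. Below is a minimal standalone sketch of the same idea, assuming a stock `transformers` CLIP text encoder; the wrapper is illustrative, not the library's actual `ModelWrapper`.

```python
import torch
from transformers import CLIPTextModel


class HiddenStatesWrapper(torch.nn.Module):
    """Illustrative stand-in for the patched ModelWrapper."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor):
        # Force hidden states on at call time, mirroring
        # `ordered_inputs["output_hidden_states"] = True` in the patch.
        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        # Tracing needs flat tensor outputs, so unpack the ModelOutput.
        return (outputs.last_hidden_state, outputs.pooler_output, *outputs.hidden_states)


encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").eval()
dummy_ids = torch.ones((1, 77), dtype=torch.long)  # CLIP's static sequence length
traced = torch.jit.trace(HiddenStatesWrapper(encoder), (dummy_ids,))
```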
""" + output_hidden_states = self.output_hidden_states class ModelWrapper(torch.nn.Module): def __init__(self, model: "PreTrainedModel", input_names: List[str]): @@ -355,10 +356,13 @@ def forward(self, *input): if forward_with_tuple is True: outputs = self.model(*ordered_inputs.values()) else: + if output_hidden_states: + ordered_inputs["output_hidden_states"] = True outputs = self.model(**ordered_inputs) - if isinstance(outputs, dict) and eligible_outputs is not None: - outputs = {name: outputs[name] for name in outputs.keys() & eligible_outputs} + if isinstance(outputs, dict): + if eligible_outputs is not None: + outputs = {name: outputs[name] for name in outputs.keys() & eligible_outputs} if isinstance(outputs, tuple) and eligible_outputs is not None: if not all(isinstance(x, int) for x in eligible_outputs): diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index 1c3a47d1b..d528a7d22 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -244,7 +244,7 @@ def outputs(self) -> List[str]: @register_in_tasks_manager("clip-text-with-projection", *["feature-extraction"], library_name="diffusers") class CLIPTextWithProjectionNeuronConfig(TextEncoderNeuronConfig): - MODEL_TYPE = "clip-text-model" + MODEL_TYPE = "clip-text-with-projection" ATOL_FOR_VALIDATION = 1e-3 NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py index f0718d18c..76063d487 100644 --- a/optimum/exporters/neuron/utils.py +++ b/optimum/exporters/neuron/utils.py @@ -121,6 +121,7 @@ def get_stable_diffusion_models_for_export( vae_encoder_input_shapes: Dict[str, int], vae_decoder_input_shapes: Dict[str, int], dynamic_batch_size: Optional[bool] = False, + output_hidden_states: bool = False, lora_model_ids: Optional[List[str]] = None, lora_weight_names: Optional[List[str]] = None, lora_adapter_names: Optional[List[str]] = None, @@ -147,6 +148,8 @@ def get_stable_diffusion_models_for_export( Static shapes used for compiling vae decoder. dynamic_batch_size (`bool`, defaults to `False`): Whether the Neuron compiled model supports dynamic batch size. + output_hidden_states (`bool`, defaults to `False`): + Whether or not for the traced text encoders to return the hidden states of all layers. lora_model_ids (`Optional[List[str]]`, defaults to `None`): List of model ids (eg. `ostris/super-cereal-sdxl-lora`) of pretrained lora models hosted on the Hub or paths to local directories containing the lora weights. 
diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py
index 8681afd6c..e86420c39 100644
--- a/optimum/neuron/modeling_diffusion.py
+++ b/optimum/neuron/modeling_diffusion.py
@@ -27,6 +27,7 @@
 import torch
 from huggingface_hub import snapshot_download
 from transformers import CLIPFeatureExtractor, CLIPTokenizer, PretrainedConfig
+from transformers.modeling_outputs import ModelOutput

 from ..exporters.neuron import (
     load_models_and_neuron_configs,
@@ -585,6 +586,7 @@ def _export(
         auto_cast: Optional[str] = "matmul",
         auto_cast_type: Optional[str] = "bf16",
         dynamic_batch_size: bool = False,
+        output_hidden_states: bool = False,
         data_parallel_mode: Optional[str] = None,
         lora_model_ids: Optional[Union[str, List[str]]] = None,
         lora_weight_names: Optional[Union[str, List[str]]] = None,
@@ -646,6 +648,8 @@ def _export(
             dynamic_batch_size (`bool`, defaults to `False`):
                 Whether to enable dynamic batch size for neuron compiled model. If this option is enabled, the input
                 batch size can be a multiple of the batch size during the compilation, but it comes with a potential tradeoff in terms of latency.
+            output_hidden_states (`bool`, defaults to `False`):
+                Whether the traced text encoders should return the hidden states of all layers.
             data_parallel_mode (`Optional[str]`, defaults to `None`):
                 Mode to decide what components to load into both NeuronCores of a Neuron device. Can be "none"(no
                 data parallel), "unet"(only load unet into both cores of each device), "all"(load the whole pipeline into both cores).
@@ -708,6 +712,7 @@ def _export(
             local_files_only=local_files_only,
             use_auth_token=use_auth_token,
             submodels=submodels,
+            output_hidden_states=output_hidden_states,
             lora_model_ids=lora_model_ids,
             lora_weight_names=lora_weight_names,
             lora_adapter_names=lora_adapter_names,
@@ -738,6 +743,7 @@ def _export(
                 optlevel=optlevel,
                 model_type=getattr(neuron_config, "MODEL_TYPE", None),
                 task=getattr(neuron_config, "task", None),
+                output_hidden_states=output_hidden_states,
             )
             compilation_configs[name] = compilation_config
@@ -780,6 +786,7 @@ def _export(
             use_auth_token=use_auth_token,
             do_validation=False,
             submodels={"unet": unet_id},
+            output_hidden_states=output_hidden_states,
             lora_model_ids=lora_model_ids,
             lora_weight_names=lora_weight_names,
             lora_adapter_names=lora_adapter_names,
@@ -842,9 +849,30 @@ def __init__(
     ):
         super().__init__(model, parent_model, config, neuron_config, DIFFUSION_MODEL_TEXT_ENCODER_NAME)

-    def forward(self, input_ids: torch.Tensor):
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        if attention_mask is not None:
+            assert torch.equal(
+                torch.ones_like(attention_mask), attention_mask
+            ), "attention_mask is expected to be all ones."
+        if output_hidden_states:
+            assert (
+                self.config.output_hidden_states or self.config.neuron.get("output_hidden_states")
+            ) == output_hidden_states, "output_hidden_states is expected to be False since the model was compiled without hidden states as output."
+
+        input_ids = input_ids.to(torch.long)  # the dummy input generator uses long ints for tracing
         inputs = (input_ids,)
         outputs = self.model(*inputs)
+
+        if return_dict:
+            outputs = ModelOutput(dict(zip(self.neuron_config.outputs, outputs)))
+
         return outputs
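A hedged sketch of the new `NeuronTextEncoder.forward` contract as compel exercises it. Here `pipe` stands for an already-exported pipeline and is illustrative; `return_dict=True` is what makes the flat Neuron outputs addressable by name.

```python
import torch

# Tokenize to the static sequence length the encoder was compiled with.
text_inputs = pipe.tokenizer(
    "a red cat playing with a ball",
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    return_tensors="pt",
)

# An attention_mask, if given, must be all ones: the graph was traced without
# a mask input, so real padding masks cannot be honored.
outputs = pipe.text_encoder(
    input_ids=text_inputs.input_ids,
    attention_mask=torch.ones_like(text_inputs.input_ids),
    return_dict=True,  # wraps the output tuple into a ModelOutput keyed by neuron_config.outputs
)
```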
diff --git a/setup.py b/setup.py
index 04c5dc8bc..2bb2e2ddc 100644
--- a/setup.py
+++ b/setup.py
@@ -33,6 +33,7 @@
     "safetensors",
    "sentence-transformers >= 2.2.0",
     "peft",
+    "compel",
 ]

 QUALITY_REQUIRES = [
diff --git a/tests/inference/test_stable_diffusion_pipeline.py b/tests/inference/test_stable_diffusion_pipeline.py
index 24490a347..174eac1f1 100644
--- a/tests/inference/test_stable_diffusion_pipeline.py
+++ b/tests/inference/test_stable_diffusion_pipeline.py
@@ -17,6 +17,7 @@
 import unittest

 import PIL
+from compel import Compel, ReturnedEmbeddingsType
 from parameterized import parameterized

 from optimum.neuron import (
@@ -165,6 +166,28 @@ def test_export_and_inference_with_fused_lora(self, model_arch):
         image = neuron_pipeline(prompts, num_images_per_prompt=num_images_per_prompt).images[0]
         self.assertIsInstance(image, PIL.Image.Image)

+    @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True)
+    def test_compatibility_with_compel(self, model_arch):
+        num_images_per_prompt = 1
+        input_shapes = copy.deepcopy(self.STATIC_INPUTS_SHAPES)
+        input_shapes.update({"num_images_per_prompt": num_images_per_prompt})
+        pipe = self.NEURON_MODEL_CLASS.from_pretrained(
+            MODEL_NAMES[model_arch],
+            export=True,
+            inline_weights_to_neff=True,
+            output_hidden_states=True,
+            **input_shapes,
+            **self.COMPILER_ARGS,
+        )
+
+        prompt = "a red cat playing with a ball++"
+        compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
+
+        prompt_embeds = compel_proc(prompt)
+
+        image = pipe(prompt_embeds=prompt_embeds, num_inference_steps=2).images[0]
+        self.assertIsInstance(image, PIL.Image.Image)
+

 @is_inferentia_test
 @requires_neuronx
@@ -268,3 +291,42 @@ def test_inpaint_export_and_inference(self, model_arch):
         prompt = "A deep sea diver floating"
         image = neuron_pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
         self.assertIsInstance(image, PIL.Image.Image)
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True)
+    def test_compatibility_with_compel(self, model_arch):
+        num_images_per_prompt = 1
+        input_shapes = copy.deepcopy(self.STATIC_INPUTS_SHAPES)
+        input_shapes.update({"num_images_per_prompt": num_images_per_prompt})
+        pipe = self.NEURON_MODEL_CLASS.from_pretrained(
+            MODEL_NAMES[model_arch],
+            export=True,
+            inline_weights_to_neff=True,
+            output_hidden_states=True,
+            **input_shapes,
+            **self.COMPILER_ARGS,
+        )
+
+        prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+        negative_prompt = "low quality, low resolution"
+
+        compel = Compel(
+            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
+            text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
+            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+            requires_pooled=[False, True],
+        )
+        prompt_embeds, pooled = compel(prompt)
+        neg_prompt_embeds, neg_pooled = compel(negative_prompt)
+        positive_prompt_embeds, negative_prompt_embeds = compel.pad_conditioning_tensors_to_same_length(
+            [prompt_embeds, neg_prompt_embeds]
+        )
+
+        image = pipe(
+            prompt_embeds=positive_prompt_embeds,
+            pooled_prompt_embeds=pooled,
+            negative_prompt_embeds=negative_prompt_embeds,
+            negative_pooled_prompt_embeds=neg_pooled,
+            output_type="pil",
+            num_inference_steps=1,
+        ).images[0]
+        self.assertIsInstance(image, PIL.Image.Image)
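One note on the prompts used in the tests: compel's `++` suffix up-weights the preceding token (each `+` multiplies its weight by 1.1, per compel's documentation), which is why these tests drive the pipeline through `prompt_embeds` instead of a plain string prompt. A hedged equivalent using compel's explicit parenthesized-weight syntax, reusing `compel_proc` from the first test:

```python
# "a red cat playing with a ball++" is roughly equivalent to an explicit
# weight of 1.1 ** 2 ≈ 1.21 on "ball" (illustrative; syntax per compel's README).
prompt_embeds = compel_proc("a red cat playing with a (ball)1.21")
```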