diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index e58a53b5b7e8..e6e19efc6e4e 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -47,9 +47,10 @@ class DDPMPipeline(DiffusionPipeline): model_cpu_offload_seq = "unet" - def __init__(self, unet, scheduler): + def __init__(self, unet, scheduler, encoder_hidden_states=None): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) + self.encoder_hidden_states=encoder_hidden_states @torch.no_grad() def __call__( @@ -120,7 +121,9 @@ def __call__( for t in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output - model_output = self.unet(image, t).sample + if self.encoder_hidden_states: + model_output = self.unet(image, t, self.encoder_hidden_states).sample + else: model_output = self.unet(image, t).sample # 2. compute previous image: x_t -> x_t-1 image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample