diff --git a/src/schedulers/ddpm.rs b/src/schedulers/ddpm.rs index c3cf8aa..1866b85 100644 --- a/src/schedulers/ddpm.rs +++ b/src/schedulers/ddpm.rs @@ -167,9 +167,9 @@ impl DDPMScheduler { // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L305 // 6. Add noise - let mut variance = Tensor::zeros(&pred_prev_sample.size(), kind::FLOAT_CPU); + let mut variance = model_output.zeros_like(); if timestep > 0 { - let variance_noise = Tensor::randn_like(model_output); + let variance_noise = model_output.randn_like(); if self.config.variance_type == DDPMVarianceType::FixedSmallLog { variance = self.get_variance(timestep) * variance_noise; } else {