Skip to content

Commit

Permalink
fix(schedulers): same device on DDPMScheduler
Browse files Browse the repository at this point in the history
make sure variance is on same device as model output in the  method
  • Loading branch information
mspronesti committed Feb 1, 2023
1 parent ab2ce6d commit 45f6cd1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/schedulers/ddpm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,10 @@ impl DDPMScheduler {

// https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L305
// 6. Add noise
let mut variance = Tensor::zeros(&pred_prev_sample.size(), kind::FLOAT_CPU);
let (device, dtype) = (model_output.device(), model_output.kind());
let mut variance = Tensor::zeros(&model_output.size(), (dtype, device));
if timestep > 0 {
let variance_noise = Tensor::randn_like(model_output);
let variance_noise = model_output.randn_like();
if self.config.variance_type == DDPMVarianceType::FixedSmallLog {
variance = self.get_variance(timestep) * variance_noise;
} else {
Expand Down

0 comments on commit 45f6cd1

Please sign in to comment.