Skip to content

Commit

Permalink
Use randn_tensor to replace torch.randn
Browse files Browse the repository at this point in the history
`torch.randn` requires `generator` and `latents` on the same device, while the wrapped function `randn_tensor` does not have this issue.
  • Loading branch information
lmxyy authored Jan 11, 2025
1 parent 36acdd7 commit f233868
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/ltx/pipeline_ltx.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ def __call__(
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
Expand Down

0 comments on commit f233868

Please sign in to comment.