diff --git a/src/ai/fou/fou_vae.rs b/src/ai/fou/fou_vae.rs index 59e26e6..1906ef4 100644 --- a/src/ai/fou/fou_vae.rs +++ b/src/ai/fou/fou_vae.rs @@ -469,7 +469,7 @@ mod tests { )?; let batch_size = 32; - let xs = Tensor::randn(0.0, 1.0, &[batch_size, seq_len, input_dim], device)?; + let xs = Tensor::randn(0.0, 1.0, &[batch_size, seq_len, input_dim], &Device::Cpu)?; let (x_reconstructed, sigma_estimated, mu, log_var, z) = model.forward(&xs)?;