diff --git a/candle-nn/tests/embedding.rs b/candle-nn/tests/embedding.rs new file mode 100644 index 0000000000..fff2b5f94c --- /dev/null +++ b/candle-nn/tests/embedding.rs @@ -0,0 +1,16 @@ +use candle::{DType, Result, Shape}; +use candle_nn::{VarBuilder, VarMap}; +use candle_nn::embedding; + +#[test] +fn test_embedding() -> Result<()> { + let device = candle::Device::Cpu; + + let varmap = VarMap::new(); + let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device); + let embed = embedding(10, 20, vb)?; + + assert_eq!(embed.embeddings().shape(), &Shape::from((10, 20))); + + Ok(()) +} \ No newline at end of file