From ac0bfd0563f61c0dcd43ca89e662301d604026af Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 20 Feb 2025 22:39:07 -0800 Subject: [PATCH] Add simple test for embedding function. --- candle-nn/tests/embedding.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 candle-nn/tests/embedding.rs diff --git a/candle-nn/tests/embedding.rs b/candle-nn/tests/embedding.rs new file mode 100644 index 0000000000..fff2b5f94c --- /dev/null +++ b/candle-nn/tests/embedding.rs @@ -0,0 +1,16 @@ +use candle::{DType, Result, Shape}; +use candle_nn::{VarBuilder, VarMap}; +use candle_nn::embedding; + +#[test] +fn test_embedding() -> Result<()> { + let device = candle::Device::Cpu; + + let varmap = VarMap::new(); + let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device); + let embed = embedding(10, 20, vb)?; + + assert_eq!(embed.embeddings().shape(), &Shape::from((10, 20))); + + Ok(()) +} \ No newline at end of file