From 385b78cef96dfc1f06682a04803e3abfb87552a1 Mon Sep 17 00:00:00 2001 From: Wouter Doppenberg Date: Tue, 23 Apr 2024 10:42:12 +0200 Subject: [PATCH] Added simple example (#31) --- crates/glowrs/examples/simple.rs | 45 ++++++++++++++++++++++++++++++++ crates/glowrs/src/lib.rs | 1 + 2 files changed, 46 insertions(+) create mode 100644 crates/glowrs/examples/simple.rs diff --git a/crates/glowrs/examples/simple.rs b/crates/glowrs/examples/simple.rs new file mode 100644 index 0000000..e1a5a2b --- /dev/null +++ b/crates/glowrs/examples/simple.rs @@ -0,0 +1,45 @@ +use glowrs::{Device, Error, PoolingStrategy, SentenceTransformer}; +use std::process::ExitCode; + +fn main() -> Result { + let sentences = vec![ + "The cat sits outside", + "A man is playing guitar", + "I love pasta", + "The new movie is awesome", + "The cat plays in the garden", + "A woman watches TV", + "The new movie is so great", + "Do you like pizza?", + "The cat sits", + ]; + let device = Device::Cpu; + let encoder = + SentenceTransformer::from_repo_string("Snowflake/snowflake-arctic-embed-xs", &device)?; + + let pooling_strategy = PoolingStrategy::Mean; + let embeddings = encoder.encode_batch(sentences.clone(), false, pooling_strategy)?; + println!("Embeddings: {:?}", embeddings); + + let (n_sentences, _) = embeddings.dims2()?; + let mut similarities = Vec::with_capacity(n_sentences * (n_sentences - 1) / 2); + + for i in 0..n_sentences { + let e_i = embeddings.get(i)?; + for j in (i + 1)..n_sentences { + let e_j = embeddings.get(j)?; + let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::()?; + let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::()?; + let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::()?; + let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt(); + similarities.push((cosine_similarity, i, j)) + } + } + + similarities.sort_by(|u, v| v.0.total_cmp(&u.0)); + for &(score, i, j) in similarities[..5].iter() { + println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j]) + } + + Ok(ExitCode::SUCCESS) +} diff --git a/crates/glowrs/src/lib.rs b/crates/glowrs/src/lib.rs index 306c38c..4ccdd83 100644 --- a/crates/glowrs/src/lib.rs +++ b/crates/glowrs/src/lib.rs @@ -3,6 +3,7 @@ mod error; mod exports; pub mod model; + pub use exports::*; pub use crate::error::Error;