Skip to content

Commit

Permalink
Added simple example (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
wdoppenberg authored Apr 23, 2024
1 parent c7d7fe5 commit 385b78c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
45 changes: 45 additions & 0 deletions crates/glowrs/examples/simple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use glowrs::{Device, Error, PoolingStrategy, SentenceTransformer};
use std::process::ExitCode;

fn main() -> Result<ExitCode, Error> {
let sentences = vec![
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
"The cat sits",
];
let device = Device::Cpu;
let encoder =
SentenceTransformer::from_repo_string("Snowflake/snowflake-arctic-embed-xs", &device)?;

let pooling_strategy = PoolingStrategy::Mean;
let embeddings = encoder.encode_batch(sentences.clone(), false, pooling_strategy)?;
println!("Embeddings: {:?}", embeddings);

let (n_sentences, _) = embeddings.dims2()?;
let mut similarities = Vec::with_capacity(n_sentences * (n_sentences - 1) / 2);

for i in 0..n_sentences {
let e_i = embeddings.get(i)?;
for j in (i + 1)..n_sentences {
let e_j = embeddings.get(j)?;
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
}

similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
for &(score, i, j) in similarities[..5].iter() {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}

Ok(ExitCode::SUCCESS)
}
1 change: 1 addition & 0 deletions crates/glowrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
mod error;
mod exports;
pub mod model;

pub use exports::*;

pub use crate::error::Error;
Expand Down

0 comments on commit 385b78c

Please sign in to comment.