Commit

Fixed tokenizer strategy; added example; added comparison test w.r.t. sentence-transformers
wdoppenberg committed Apr 26, 2024
1 parent 2d74134 commit 51ccb8e
Showing 11 changed files with 202 additions and 20 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,3 +1,6 @@
/target
Cargo.lock
.idea
.idea

traces/
**/trace-*.json
3 changes: 3 additions & 0 deletions crates/glowrs/Cargo.toml
@@ -24,6 +24,7 @@ tokenizers = "0.19.1"
hf-hub = { version = "0.3.2", features = ["tokio"] }
thiserror = "1.0.56"
once_cell = "1.19.0"
clap = { version = "4.5.4", features = ["derive"] }

[features]
default = []
@@ -35,4 +36,6 @@ cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
dirs = "5.0.1"
tempfile = "3.10.1"
approx = "0.5.1"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing-chrome = "0.7.2"

34 changes: 29 additions & 5 deletions crates/glowrs/examples/simple.rs
@@ -1,8 +1,21 @@
use clap::Parser;
use glowrs::{Device, Error, PoolingStrategy, SentenceTransformer};
use std::process::ExitCode;
use tracing_subscriber::prelude::*;

#[derive(Debug, Parser)]
pub struct App {
#[clap(short, long, default_value = "jinaai/jina-embeddings-v2-small-en")]
pub model_repo: String,

#[clap(short, long, default_value = "debug")]
pub log_level: String,
}

fn main() -> Result<ExitCode, Error> {
let sentences = vec![
let app = App::parse();

let sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
@@ -11,14 +24,25 @@ fn main() -> Result<ExitCode, Error> {
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
"The cat sits",
];

tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
eprintln!("No environment variables found that can initialize tracing_subscriber::EnvFilter. Using defaults.");
// axum logs rejections from built-in extractors with the `axum::rejection`
// target, at `TRACE` level. `axum::rejection=trace` enables showing those events
format!("glowrs={},tower_http=debug,axum::rejection=trace", app.log_level).into()
}),
)
.with(tracing_subscriber::fmt::layer()).init();

println!("Using model {}", app.model_repo);
let device = Device::Cpu;
let encoder =
SentenceTransformer::from_repo_string("Snowflake/snowflake-arctic-embed-xs", &device)?;
let encoder = SentenceTransformer::from_repo_string(&app.model_repo, &device)?;

let pooling_strategy = PoolingStrategy::Mean;
let embeddings = encoder.encode_batch(sentences.clone(), false, pooling_strategy)?;
let embeddings = encoder.encode_batch(sentences.into(), false, pooling_strategy)?;
println!("Embeddings: {:?}", embeddings);

let (n_sentences, _) = embeddings.dims2()?;
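
With clap's derive attributes, the example now accepts `-m`/`--model-repo` and `-l`/`--log-level` (defaulting to `jinaai/jina-embeddings-v2-small-en` and `debug`); the `EnvFilter` fallback shown above only applies when `RUST_LOG` is unset.
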
4 changes: 3 additions & 1 deletion crates/glowrs/src/error.rs
@@ -8,6 +8,8 @@ pub enum Error {
ModelLoad(&'static str),
#[error("Invalid model architecture: {0}")]
InvalidModelConfig(&'static str),
#[error("Inference error: {0}")]
InferenceError(&'static str),
#[error("Candle error: {0}")]
Candle(#[from] candle_core::Error),
#[error("Tokenization error: {0}")]
@@ -20,7 +22,7 @@
HFHub(#[from] hf_hub::api::sync::ApiError),
}

pub(crate) type Result<T> = std::result::Result<T, Error>;
pub type Result<T> = std::result::Result<T, Error>;

#[cfg(test)]
mod test {
3 changes: 1 addition & 2 deletions crates/glowrs/src/lib.rs
@@ -6,8 +6,7 @@ pub mod model;

pub use exports::*;

pub use crate::error::Error;
pub(crate) use error::Result;
pub use crate::error::{Error, Result};

pub use model::pooling::PoolingStrategy;
pub use model::sentence_transformer::SentenceTransformer;
25 changes: 21 additions & 4 deletions crates/glowrs/src/model/embedder.rs
@@ -4,6 +4,7 @@ use candle_transformers::models::{
bert::Config as BertConfig, distilbert::Config as DistilBertConfig,
jina_bert::Config as JinaBertConfig,
};
use serde::Deserialize;
use std::ops::Deref;
use std::path::Path;
use tokenizers::{EncodeInput, Tokenizer};
@@ -12,7 +13,6 @@ use tokenizers::{EncodeInput, Tokenizer};
pub use candle_transformers::models::{
bert::BertModel, distilbert::DistilBertModel, jina_bert::BertModel as JinaBertModel,
};
use serde::Deserialize;

use crate::model::pooling::{pool_embeddings, PoolingStrategy};
use crate::model::utils::normalize_l2;
@@ -129,8 +129,15 @@ impl EmbedderModel for JinaBertModel {
impl EmbedderModel for DistilBertModel {
#[inline]
fn encode(&self, token_ids: &Tensor) -> Result<Tensor> {
let attention_mask = token_ids.ones_like()?;
Ok(self.forward(token_ids, &attention_mask)?)
let size = token_ids.dim(0)?;

let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();

let mask = Tensor::from_slice(&mask, (size, size), token_ids.device())?;

Ok(self.forward(token_ids, &mask)?)
}

fn get_device(&self) -> &Device {
@@ -179,18 +186,28 @@ where
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();

Tensor::new(tokens.as_slice(), model.get_device())
})
.collect::<candle_core::Result<Vec<_>>>()?;

let token_ids = Tensor::stack(&token_ids, 0)?;

let pad_id: u32;
if let Some(pp) = tokenizer.get_padding() {
pad_id = pp.pad_id;
} else {
pad_id = 0;
}

let pad_mask = token_ids.ne(pad_id)?;

tracing::trace!("running inference on batch {:?}", token_ids.shape());
let embeddings = model.encode(&token_ids)?;
tracing::trace!("generated embeddings {:?}", embeddings.shape());

// Apply pooling
let pooled_embeddings = pool_embeddings(&embeddings, pooling_strategy)?;
let pooled_embeddings = pool_embeddings(&embeddings, &pad_mask, pooling_strategy)?;

// Normalize embeddings (if required)
let embeddings = if normalize {
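
The DistilBERT path above builds its attention mask with `u8::from(j > i)`: entry (i, j) is 1 exactly when j > i, a strictly upper-triangular pattern. A standalone sketch of the same rule for three positions:

```rust
fn main() {
    let size = 3;
    // Same rule as in DistilBertModel::encode: row i allows positions j > i.
    let mask: Vec<u8> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    // Rows of the 3x3 mask: [0, 1, 1], [0, 0, 1], [0, 0, 0].
    assert_eq!(mask, vec![0, 1, 1, 0, 0, 1, 0, 0, 0]);
}
```
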
17 changes: 11 additions & 6 deletions crates/glowrs/src/model/pooling.rs
@@ -10,18 +10,22 @@ pub enum PoolingStrategy {
Sum,
}

pub fn pool_embeddings(embeddings: &Tensor, strategy: PoolingStrategy) -> Result<Tensor> {
pub fn pool_embeddings(
embeddings: &Tensor,
pad_mask: &Tensor,
strategy: PoolingStrategy,
) -> Result<Tensor> {
match strategy {
PoolingStrategy::Mean => mean_pooling(embeddings),
PoolingStrategy::Mean => mean_pooling(embeddings, pad_mask),
PoolingStrategy::Max => max_pooling(embeddings),
PoolingStrategy::Sum => sum_pooling(embeddings),
}
}

pub fn mean_pooling(embeddings: &Tensor) -> Result<Tensor> {
let (_, out_tokens, _) = embeddings.dims3()?;
pub fn mean_pooling(embeddings: &Tensor, pad_mask: &Tensor) -> Result<Tensor> {
let out_tokens = pad_mask.sum(1)?.to_vec1::<u8>()?.iter().sum::<u8>() as f64;

Ok((embeddings.sum(1)? / (out_tokens as f64))?)
Ok((embeddings.sum(1)? / (out_tokens))?)
}

pub fn max_pooling(embeddings: &Tensor) -> Result<Tensor> {
@@ -43,7 +47,8 @@ mod test {
) -> Result<()> {
// 1 sentence, 20 tokens, 32 dimensions
let v = Tensor::ones(&[1, 20, 32], DType::F32, &Device::Cpu)?;
let v_pool = pool_embeddings(&v, strategy)?;
let pad_mask = Tensor::ones(&[1, 20], DType::F32, &Device::Cpu)?;
let v_pool = pool_embeddings(&v, &pad_mask, strategy)?;
let (sent, dim) = v_pool.dims2()?;
assert_eq!(sent, 1);
assert_eq!(dim, 32);
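
The committed `mean_pooling` divides the summed token embeddings by a single count taken over the whole pad mask. For reference, a per-sentence masked mean (the formulation sentence-transformers documents for mean pooling) can be sketched with candle's broadcasting ops; `masked_mean` is an illustrative helper, not part of this crate:

```rust
use candle_core::{DType, Device, Result, Tensor};

// Illustrative helper (not part of glowrs): mean over real tokens only,
// computed per sentence via broadcasting.
fn masked_mean(embeddings: &Tensor, pad_mask: &Tensor) -> Result<Tensor> {
    // embeddings: (batch, tokens, dim); pad_mask: (batch, tokens), 1 on real tokens.
    let mask = pad_mask.to_dtype(DType::F32)?.unsqueeze(2)?; // (batch, tokens, 1)
    let summed = embeddings.broadcast_mul(&mask)?.sum(1)?; // (batch, dim)
    let counts = mask.sum(1)?; // (batch, 1)
    summed.broadcast_div(&counts)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let embeddings = Tensor::ones(&[2, 4, 8], DType::F32, &device)?;
    let pad_mask = Tensor::new(&[[1u8, 1, 1, 0], [1, 1, 0, 0]], &device)?;
    assert_eq!(masked_mean(&embeddings, &pad_mask)?.dims(), &[2, 8]);
    Ok(())
}
```
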
32 changes: 31 additions & 1 deletion crates/glowrs/src/model/sentence_transformer.rs
@@ -61,11 +61,15 @@ impl SentenceTransformer {
/// # }
/// ```
pub fn from_repo_string(repo_string: &str, device: &Device) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "st-from-repo-string");
let _enter = span.enter();
let (model_repo, default_revision) = utils::parse_repo_string(repo_string)?;
Self::from_repo(model_repo, default_revision, device)
}

pub fn from_repo(repo_name: &str, revision: &str, device: &Device) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "st-from-repo");
let _enter = span.enter();
let api = Api::new()?.repo(Repo::with_revision(
repo_name.into(),
RepoType::Model,
@@ -76,6 +80,8 @@
}

pub fn from_api(api: ApiRepo, device: &Device) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "st-from-api");
let _enter = span.enter();
let model_path = api.get("model.safetensors")?;

let config_path = api.get("config.json")?;
@@ -91,7 +97,19 @@
tokenizer_path: &Path,
device: &Device,
) -> Result<Self> {
let tokenizer = Tokenizer::from_file(tokenizer_path)?;
let span = tracing::span!(tracing::Level::TRACE, "st-from-path");
let _enter = span.enter();
let mut tokenizer = Tokenizer::from_file(tokenizer_path)?;

if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = tokenizers::PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}

let model = load_pretrained_model(model_path, config_path, device)?;

@@ -119,6 +137,8 @@ impl SentenceTransformer {
/// # Ok(())
/// # }
pub fn from_folder(folder_path: &Path, device: &Device) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "st-from-folder");
let _enter = span.enter();
// Construct PathBuf objects for model, config, and tokenizer json files
let model_path = folder_path.join("model.safetensors");
let config_path = folder_path.join("config.json");
@@ -177,6 +197,9 @@ impl SentenceTransformer {
where
E: Into<EncodeInput<'s>> + Send,
{
let span = tracing::span!(tracing::Level::TRACE, "st-encode-batch");
let _enter = span.enter();

let (embeddings, usage) = encode_batch_with_usage(
self.model.as_ref(),
&self.tokenizer,
@@ -196,6 +219,9 @@
where
E: Into<EncodeInput<'s>> + Send,
{
let span = tracing::span!(tracing::Level::TRACE, "st-encode-batch");
let _enter = span.enter();

encode_batch(
self.model.as_ref(),
&self.tokenizer,
Expand All @@ -204,6 +230,10 @@ impl SentenceTransformer {
normalize,
)
}

pub fn get_tokenizer_mut(&mut self) -> &mut Tokenizer {
&mut self.tokenizer
}
}

#[cfg(test)]
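
The new `get_tokenizer_mut` accessor lets callers override the `BatchLongest` padding default set during loading. A minimal sketch (the repo string is illustrative; any supported embedding model works):

```rust
use glowrs::{Device, SentenceTransformer};
use tokenizers::{PaddingParams, PaddingStrategy};

fn main() -> glowrs::Result<()> {
    let mut encoder = SentenceTransformer::from_repo_string(
        "sentence-transformers/all-MiniLM-L6-v2",
        &Device::Cpu,
    )?;

    // Replace the BatchLongest default with fixed-length padding.
    encoder.get_tokenizer_mut().with_padding(Some(PaddingParams {
        strategy: PaddingStrategy::Fixed(128),
        ..Default::default()
    }));

    Ok(())
}
```
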
1 change: 1 addition & 0 deletions crates/glowrs/tests/fixtures/embeddings/examples.json

Large diffs are not rendered by default.

56 changes: 56 additions & 0 deletions crates/glowrs/tests/test_embeddings.rs
@@ -0,0 +1,56 @@
use candle_core::Tensor;
use serde::Deserialize;
use std::process::ExitCode;

use glowrs::model::utils::normalize_l2;
use glowrs::{PoolingStrategy, Result};

#[derive(Deserialize)]
struct EmbeddingsExample {
sentence: String,
embedding: Vec<f32>,
}

#[derive(Deserialize)]
struct EmbeddingsFixture {
model: String,
examples: Vec<EmbeddingsExample>,
}

#[derive(Deserialize)]
struct Examples {
fixtures: Vec<EmbeddingsFixture>,
}

#[test]
fn test_similarity_sentence_transformers() -> Result<ExitCode> {
use approx::assert_relative_eq;
let examples: Examples =
serde_json::from_str(include_str!("./fixtures/embeddings/examples.json"))?;
let device = glowrs::Device::Cpu;
for fixture in examples.fixtures {
let encoder = glowrs::SentenceTransformer::from_repo_string(&fixture.model, &device)?;
println!("Loaded model: {}", &fixture.model);
for example in fixture.examples {
let embedding =
encoder.encode_batch(vec![example.sentence], false, PoolingStrategy::Mean)?;
let embedding = normalize_l2(&embedding)?;

let expected_dim = example.embedding.len();
let expected = Tensor::from_vec(example.embedding, (1, expected_dim), &device)?;
let expected = normalize_l2(&expected)?;

assert_eq!(embedding.dims(), expected.dims());

let sim = embedding.matmul(&expected.t()?)?.squeeze(1)?;

let sim = sim.to_vec1::<f32>()?;
let sim = sim.first().expect("Expected a value");
println!("Similarity: {}", sim);
assert_relative_eq!(*sim, 1.0, epsilon = 1e-3);
}
println!("Passed all examples for model: {}", &fixture.model)
}

Ok(ExitCode::SUCCESS)
}
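
Since both the computed and the expected embeddings are L2-normalized before the matmul, the product is exactly the cosine similarity: cos(a, b) = (a · b) / (‖a‖ ‖b‖), which reduces to a · b when ‖a‖ = ‖b‖ = 1. A value within 1e-3 of 1.0 therefore means glowrs reproduces the sentence-transformers embedding direction.
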
42 changes: 42 additions & 0 deletions tests/generate-fixtures.py
@@ -0,0 +1,42 @@
import json
from sentence_transformers import SentenceTransformer

SENTENCES = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
"The cat sits",
]

MODELS = [
"jinaai/jina-embeddings-v2-small-en",
"sentence-transformers/all-MiniLM-L6-v2",
"sentence-transformers/multi-qa-distilbert-cos-v1",
]


def generate_examples(model: str) -> list:
model = SentenceTransformer(model, trust_remote_code=True)
embeddings = model.encode(SENTENCES, normalize_embeddings=False, batch_size=len(SENTENCES))
return [
{"sentence": sentence, "embedding": embedding.tolist()} for sentence, embedding in zip(SENTENCES, embeddings)
]


if __name__ == "__main__":
out = {
"fixtures": [
{
"model": m,
"examples": generate_examples(m)

} for m in MODELS]
}

with open("crates/glowrs/tests/fixtures/embeddings/examples.json", "w") as f:
json.dump(out, f)
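
Note that the fixtures store unnormalized embeddings (`normalize_embeddings=False`); the Rust test L2-normalizes both the computed and the expected vectors before comparing, so the two implementations only need to agree up to scale.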
