-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixed tokenizer strategy; added example; added comparison test w.r.t. sentence-transformers
- Loading branch information
1 parent
2d74134
commit 51ccb8e
Showing
11 changed files
with
202 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
/target | ||
Cargo.lock | ||
.idea | ||
.idea | ||
|
||
traces/ | ||
**/trace-*.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
use candle_core::Tensor; | ||
use serde::Deserialize; | ||
use std::process::ExitCode; | ||
|
||
use glowrs::model::utils::normalize_l2; | ||
use glowrs::{PoolingStrategy, Result}; | ||
|
||
/// One reference sample deserialized from the fixture JSON: a sentence and the embedding expected for it. | ||
#[derive(Deserialize)] | ||
struct EmbeddingsExample { | ||
/// Input text; the similarity test passes it to `encode_batch`. | ||
sentence: String, | ||
/// Reference vector; the test turns it into a `(1, len)` tensor and compares it with the encoder output. | ||
embedding: Vec<f32>, | ||
} | ||
|
||
/// All reference samples recorded for a single model. | ||
#[derive(Deserialize)] | ||
struct EmbeddingsFixture { | ||
/// Model identifier; the similarity test hands it to `SentenceTransformer::from_repo_string`. | ||
model: String, | ||
/// Sentence/embedding pairs checked against that model. | ||
examples: Vec<EmbeddingsExample>, | ||
} | ||
|
||
/// Top-level shape of the fixture JSON file loaded by the similarity test. | ||
#[derive(Deserialize)] | ||
struct Examples { | ||
/// One entry per model. | ||
fixtures: Vec<EmbeddingsFixture>, | ||
} | ||
|
||
/// Compares glowrs embeddings with the reference vectors stored in
/// `fixtures/embeddings/examples.json`.
///
/// For every fixture the named model is loaded on the CPU device, each
/// sentence is encoded with mean pooling, and both the produced and the
/// reference vector are L2-normalized. Their dot product (the cosine
/// similarity of the two vectors) must be within `1e-3` of `1.0`.
///
/// # Errors
///
/// Propagates any error from fixture parsing, model loading, encoding, or the
/// tensor operations.
#[test]
fn test_similarity_sentence_transformers() -> Result<ExitCode> {
    use approx::assert_relative_eq;

    let parsed: Examples =
        serde_json::from_str(include_str!("./fixtures/embeddings/examples.json"))?;
    let cpu = glowrs::Device::Cpu;

    for EmbeddingsFixture { model, examples } in parsed.fixtures {
        let encoder = glowrs::SentenceTransformer::from_repo_string(&model, &cpu)?;
        println!("Loaded model: {}", &model);

        for EmbeddingsExample {
            sentence,
            embedding: reference,
        } in examples
        {
            // Encode a single-sentence batch, then bring it to unit length.
            let actual = encoder.encode_batch(vec![sentence], false, PoolingStrategy::Mean)?;
            let actual = normalize_l2(&actual)?;

            // The reference vector becomes a one-row tensor of matching width.
            let width = reference.len();
            let expected = normalize_l2(&Tensor::from_vec(reference, (1, width), &cpu)?)?;

            assert_eq!(actual.dims(), expected.dims());

            // Row times transposed row yields a 1x1 product; squeeze it down
            // to a one-element vector and read the single score.
            let scores = actual
                .matmul(&expected.t()?)?
                .squeeze(1)?
                .to_vec1::<f32>()?;
            let sim = scores.first().expect("Expected a value");
            println!("Similarity: {}", sim);
            assert_relative_eq!(*sim, 1.0, epsilon = 1e-3);
        }

        println!("Passed all examples for model: {}", &model);
    }

    Ok(ExitCode::SUCCESS)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import json | ||
from sentence_transformers import SentenceTransformer | ||
|
||
# Sample sentences; generate_examples() encodes all of them with each model. | ||
SENTENCES = [ | ||
"The cat sits outside", | ||
"A man is playing guitar", | ||
"I love pasta", | ||
"The new movie is awesome", | ||
"The cat plays in the garden", | ||
"A woman watches TV", | ||
"The new movie is so great", | ||
"Do you like pizza?", | ||
"The cat sits", | ||
] | ||
|
||
# Model names passed to SentenceTransformer; the __main__ block writes one fixture entry per name. | ||
# NOTE(review): these look like Hugging Face Hub repository ids - confirm. | ||
MODELS = [ | ||
"jinaai/jina-embeddings-v2-small-en", | ||
"sentence-transformers/all-MiniLM-L6-v2", | ||
"sentence-transformers/multi-qa-distilbert-cos-v1", | ||
] | ||
|
||
|
||
def generate_examples(model: str) -> list:
    """Encode every sentence in ``SENTENCES`` with the named model.

    Args:
        model: Model name handed to ``SentenceTransformer``.

    Returns:
        One ``{"sentence": ..., "embedding": ...}`` dict per entry of
        ``SENTENCES``, in the same order. Each embedding is converted with
        ``.tolist()`` so the result can be passed to ``json.dump``.
    """
    encoder = SentenceTransformer(model, trust_remote_code=True)
    # All sentences go through in a single batch; normalization is left off.
    vectors = encoder.encode(SENTENCES, normalize_embeddings=False, batch_size=len(SENTENCES))

    examples = []
    for text, vector in zip(SENTENCES, vectors):
        examples.append({"sentence": text, "embedding": vector.tolist()})
    return examples
|
||
|
||
if __name__ == "__main__":
    # Build one fixture entry per model, in the order listed in MODELS.
    fixtures = []
    for model_name in MODELS:
        fixtures.append({"model": model_name, "examples": generate_examples(model_name)})

    # The output path is relative to the current working directory.
    with open("crates/glowrs/tests/fixtures/embeddings/examples.json", "w") as handle:
        json.dump({"fixtures": fixtures}, handle)