[Enhancement] Add sentence-transformers embedding support #334

Open · wants to merge 1 commit into main
4 changes: 2 additions & 2 deletions datastore/datastore.py
@@ -11,7 +11,7 @@
     QueryWithEmbedding,
 )
 from services.chunks import get_document_chunks
-from services.openai import get_embeddings
+from services.embedding import Embedding
 
 
 class DataStore(ABC):
@@ -56,7 +56,7 @@ async def query(self, queries: List[Query]) -> List[QueryResult]:
         """
         # get a list of of just the queries from the Query list
         query_texts = [query.query for query in queries]
-        query_embeddings = get_embeddings(query_texts)
+        query_embeddings = Embedding.instance().get_embeddings(query_texts)
         # hydrate the queries with embeddings
         queries_with_embeddings = [
             QueryWithEmbedding(**query.dict(), embedding=embedding)
9 changes: 6 additions & 3 deletions datastore/factory.py
@@ -1,11 +1,14 @@
 from datastore.datastore import DataStore
+from services.embedding import Embedding
 import os
 
 
 async def get_datastore() -> DataStore:
     datastore = os.environ.get("DATASTORE")
     assert datastore is not None
 
+    dimension = Embedding.instance().dimension
+
     match datastore:
         case "chroma":
             from datastore.providers.chroma_datastore import ChromaDataStore
@@ -19,7 +22,7 @@ async def get_datastore() -> DataStore:
         case "pinecone":
             from datastore.providers.pinecone_datastore import PineconeDataStore
 
-            return PineconeDataStore()
+            return PineconeDataStore(dimension=dimension)
         case "weaviate":
             from datastore.providers.weaviate_datastore import WeaviateDataStore
 
@@ -35,11 +38,11 @@ async def get_datastore() -> DataStore:
         case "redis":
             from datastore.providers.redis_datastore import RedisDataStore
 
-            return await RedisDataStore.init()
+            return await RedisDataStore.init(dim=dimension)
         case "qdrant":
             from datastore.providers.qdrant_datastore import QdrantDataStore
 
-            return QdrantDataStore()
+            return QdrantDataStore(vector_size=dimension)
         case "azuresearch":
             from datastore.providers.azuresearch_datastore import AzureSearchDataStore
 
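For context, a minimal end-to-end sketch of how the dimension plumbing above is meant to be used. It assumes the module layout from this PR, the sentence-transformers dependency installed, and a Qdrant instance reachable with the provider's defaults; the environment variable names come from the diff, everything else is illustrative.

import asyncio
import os

# Engine and datastore selection happen via environment variables; "sentence"
# selects the 384-dimension MiniLM model defined in services/embedding.py below.
os.environ["EMBEDDING_ENGINE"] = "sentence"
os.environ["DATASTORE"] = "qdrant"

from datastore.factory import get_datastore
from services.embedding import Embedding


async def main() -> None:
    # get_datastore() reads Embedding.instance().dimension and forwards it to the
    # provider constructor, e.g. QdrantDataStore(vector_size=384).
    store = await get_datastore()
    print(type(store).__name__, Embedding.instance().dimension)


asyncio.run(main())
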
6 changes: 3 additions & 3 deletions datastore/providers/analyticdb_datastore.py
@@ -30,17 +30,17 @@
     "host": os.environ.get("PG_HOST", "localhost"),
     "port": int(os.environ.get("PG_PORT", "5432")),
 }
-OUTPUT_DIM = 1536
 
 
 class AnalyticDBDataStore(DataStore):
-    def __init__(self, config: Dict[str, str] = PG_CONFIG):
+    def __init__(self, config: Dict[str, str] = PG_CONFIG, dimension=1536):
         self.collection_name = config["collection"]
         self.user = config["user"]
         self.password = config["password"]
         self.database = config["database"]
         self.host = config["host"]
         self.port = config["port"]
+        self.dimension = dimension
 
         self.connection_pool = SimpleConnectionPool(
             minconn=1,
@@ -99,7 +99,7 @@ def _create_embedding_index(self, cur: psycopg2.extensions.cursor):
                 USING ann(embedding)
                 WITH (
                     distancemeasure=L2,
-                    dim=OUTPUT_DIM,
+                    dim={self.dimension},
                     pq_segments=64,
                     hnsw_m=100,
                     pq_centers=2048
4 changes: 2 additions & 2 deletions datastore/providers/pinecone_datastore.py
@@ -33,7 +33,7 @@
 
 
 class PineconeDataStore(DataStore):
-    def __init__(self):
+    def __init__(self, dimension=1536):
         # Check if the index name is specified and exists in Pinecone
         if PINECONE_INDEX and PINECONE_INDEX not in pinecone.list_indexes():
 
@@ -47,7 +47,7 @@ def __init__(self):
             )
             pinecone.create_index(
                 PINECONE_INDEX,
-                dimension=1536, # dimensionality of OpenAI ada v2 embeddings
+                dimension=dimension,
                 metadata_config={"indexed": fields_to_index},
             )
             self.index = pinecone.Index(PINECONE_INDEX)
1 change: 1 addition & 0 deletions pyproject.toml
@@ -40,6 +40,7 @@ azure-search-documents = {version = "11.4.0a20230509004", source = "azure-sdk-de
 pgvector = "^0.1.7"
 psycopg2cffi = {version = "^2.9.0", optional = true}
 loguru = "^0.7.0"
+sentence-transformers = "^2.2.2"
 
 [tool.poetry.scripts]
 start = "server.main:start"
4 changes: 2 additions & 2 deletions services/chunks.py
@@ -5,7 +5,7 @@
 
 import tiktoken
 
-from services.openai import get_embeddings
+from services.embedding import Embedding
 
 # Global variables
 tokenizer = tiktoken.get_encoding(
@@ -190,7 +190,7 @@ def get_document_chunks(
         ]
 
         # Get the embeddings for the batch texts
-        batch_embeddings = get_embeddings(batch_texts)
+        batch_embeddings = Embedding.instance().get_embeddings(batch_texts)
 
         # Append the batch embeddings to the embeddings list
         embeddings.extend(batch_embeddings)
59 changes: 59 additions & 0 deletions services/embedding.py
@@ -0,0 +1,59 @@
+import os
+from typing import List
+from abc import ABC, abstractmethod
+
+import services.openai
+
+EMBEDDING_ENGINE = 'openai'
+_instance = None
+
+
+class Embedding(ABC):
+    @property
+    @abstractmethod
+    def dimension(self) -> int:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_embeddings(self, texts: List[str]) -> List[List[float]]:
+        raise NotImplementedError
+
+    @staticmethod
+    def instance() -> 'Embedding':
+        global _instance
+
+        if _instance is None:
+            datastore = os.getenv("EMBEDDING_ENGINE", EMBEDDING_ENGINE)
+            if datastore == 'openai':
+                _instance = OpenaiEmbedding()
+            elif datastore == 'sentence':
+                _instance = SentenceTransformerEmbedding()
+            else:
+                raise ValueError(
+                    f"Unsupported embedding engine: {datastore}. "
+                    f"Try one of the following: openai, sentence"
+                )
+
+        return _instance
+
+
+class OpenaiEmbedding(Embedding):
+    @property
+    def dimension(self) -> int:
+        return 1536
+
+    def get_embeddings(self, texts: List[str]) -> List[List[float]]:
+        return services.openai.get_embeddings(texts)
+
+
+class SentenceTransformerEmbedding(Embedding):
+    def __init__(self):
+        from sentence_transformers import SentenceTransformer
+        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+
+    @property
+    def dimension(self) -> int:
+        return 384
+
+    def get_embeddings(self, texts: List[str]) -> List[List[float]]:
+        return self.model.encode(texts).tolist()
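
A short usage sketch of the new abstraction (assuming the module layout above and the sentence-transformers package installed; the sample texts are illustrative):

import os

# "openai" is the default engine; "sentence" switches to all-MiniLM-L6-v2.
os.environ["EMBEDDING_ENGINE"] = "sentence"

from services.embedding import Embedding

embedder = Embedding.instance()  # lazily constructs the singleton on first use
vectors = embedder.get_embeddings(["hello world", "vector databases"])

assert len(vectors) == 2
assert len(vectors[0]) == embedder.dimension  # 384 here, 1536 for the OpenAI engine

Since the vector dimensionality follows the selected engine, indexes created under one engine would need to be rebuilt before switching to the other.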