From d8e3ad119368ec9a090a619428fb35c608f7f0de Mon Sep 17 00:00:00 2001 From: Zhang JianAo Date: Thu, 13 Jul 2023 18:29:31 +0800 Subject: [PATCH] Add sentence-transformers embedding support --- datastore/datastore.py | 4 +- datastore/factory.py | 9 ++-- datastore/providers/analyticdb_datastore.py | 6 +-- datastore/providers/pinecone_datastore.py | 4 +- pyproject.toml | 1 + services/chunks.py | 4 +- services/embedding.py | 59 +++++++++++++++++++++ 7 files changed, 75 insertions(+), 12 deletions(-) create mode 100644 services/embedding.py diff --git a/datastore/datastore.py b/datastore/datastore.py index ff0c79dd8..e1510356a 100644 --- a/datastore/datastore.py +++ b/datastore/datastore.py @@ -11,7 +11,7 @@ QueryWithEmbedding, ) from services.chunks import get_document_chunks -from services.openai import get_embeddings +from services.embedding import Embedding class DataStore(ABC): @@ -56,7 +56,7 @@ async def query(self, queries: List[Query]) -> List[QueryResult]: """ # get a list of of just the queries from the Query list query_texts = [query.query for query in queries] - query_embeddings = get_embeddings(query_texts) + query_embeddings = Embedding.instance().get_embeddings(query_texts) # hydrate the queries with embeddings queries_with_embeddings = [ QueryWithEmbedding(**query.dict(), embedding=embedding) diff --git a/datastore/factory.py b/datastore/factory.py index adde49d76..e59865429 100644 --- a/datastore/factory.py +++ b/datastore/factory.py @@ -1,4 +1,5 @@ from datastore.datastore import DataStore +from services.embedding import Embedding import os @@ -6,6 +7,8 @@ async def get_datastore() -> DataStore: datastore = os.environ.get("DATASTORE") assert datastore is not None + dimension = Embedding.instance().dimension + match datastore: case "chroma": from datastore.providers.chroma_datastore import ChromaDataStore @@ -19,7 +22,7 @@ async def get_datastore() -> DataStore: case "pinecone": from datastore.providers.pinecone_datastore import PineconeDataStore - 
return PineconeDataStore() + return PineconeDataStore(dimension=dimension) case "weaviate": from datastore.providers.weaviate_datastore import WeaviateDataStore @@ -35,11 +38,11 @@ async def get_datastore() -> DataStore: case "redis": from datastore.providers.redis_datastore import RedisDataStore - return await RedisDataStore.init() + return await RedisDataStore.init(dim=dimension) case "qdrant": from datastore.providers.qdrant_datastore import QdrantDataStore - return QdrantDataStore() + return QdrantDataStore(vector_size=dimension) case "azuresearch": from datastore.providers.azuresearch_datastore import AzureSearchDataStore diff --git a/datastore/providers/analyticdb_datastore.py b/datastore/providers/analyticdb_datastore.py index ba206f2e1..2cc5c7848 100644 --- a/datastore/providers/analyticdb_datastore.py +++ b/datastore/providers/analyticdb_datastore.py @@ -30,17 +30,17 @@ "host": os.environ.get("PG_HOST", "localhost"), "port": int(os.environ.get("PG_PORT", "5432")), } -OUTPUT_DIM = 1536 class AnalyticDBDataStore(DataStore): - def __init__(self, config: Dict[str, str] = PG_CONFIG): + def __init__(self, config: Dict[str, str] = PG_CONFIG, dimension=1536): self.collection_name = config["collection"] self.user = config["user"] self.password = config["password"] self.database = config["database"] self.host = config["host"] self.port = config["port"] + self.dimension = dimension self.connection_pool = SimpleConnectionPool( minconn=1, @@ -99,7 +99,7 @@ def _create_embedding_index(self, cur: psycopg2.extensions.cursor): USING ann(embedding) WITH ( distancemeasure=L2, - dim=OUTPUT_DIM, + dim={self.dimension}, pq_segments=64, hnsw_m=100, pq_centers=2048 diff --git a/datastore/providers/pinecone_datastore.py b/datastore/providers/pinecone_datastore.py index c10ee2bea..03ef6f39d 100644 --- a/datastore/providers/pinecone_datastore.py +++ b/datastore/providers/pinecone_datastore.py @@ -33,7 +33,7 @@ class PineconeDataStore(DataStore): - def __init__(self): + def 
__init__(self, dimension=1536): # Check if the index name is specified and exists in Pinecone if PINECONE_INDEX and PINECONE_INDEX not in pinecone.list_indexes(): @@ -47,7 +47,7 @@ def __init__(self): ) pinecone.create_index( PINECONE_INDEX, - dimension=1536, # dimensionality of OpenAI ada v2 embeddings + dimension=dimension, metadata_config={"indexed": fields_to_index}, ) self.index = pinecone.Index(PINECONE_INDEX) diff --git a/pyproject.toml b/pyproject.toml index 0a8f588ee..bc02dae3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ azure-search-documents = {version = "11.4.0a20230509004", source = "azure-sdk-de pgvector = "^0.1.7" psycopg2cffi = {version = "^2.9.0", optional = true} loguru = "^0.7.0" +sentence-transformers = "^2.2.2" [tool.poetry.scripts] start = "server.main:start" diff --git a/services/chunks.py b/services/chunks.py index 7e6ac32ed..d49fe39b4 100644 --- a/services/chunks.py +++ b/services/chunks.py @@ -5,7 +5,7 @@ import tiktoken -from services.openai import get_embeddings +from services.embedding import Embedding # Global variables tokenizer = tiktoken.get_encoding( @@ -190,7 +190,7 @@ def get_document_chunks( ] # Get the embeddings for the batch texts - batch_embeddings = get_embeddings(batch_texts) + batch_embeddings = Embedding.instance().get_embeddings(batch_texts) # Append the batch embeddings to the embeddings list embeddings.extend(batch_embeddings) diff --git a/services/embedding.py b/services/embedding.py new file mode 100644 index 000000000..019aeda55 --- /dev/null +++ b/services/embedding.py @@ -0,0 +1,59 @@ +import os +from typing import List +from abc import ABC, abstractmethod + +import services.openai + +EMBEDDING_ENGINE = 'openai' +_instance = None + + +class Embedding(ABC): + @property + @abstractmethod + def dimension(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_embeddings(self, texts: List[str]) -> List[List[float]]: + raise NotImplementedError + + @staticmethod + def instance() 
-> 'Embedding': + global _instance + + if _instance is None: + engine = os.getenv("EMBEDDING_ENGINE", EMBEDDING_ENGINE) + if engine == 'openai': + _instance = OpenaiEmbedding() + elif engine == 'sentence': + _instance = SentenceTransformerEmbedding() + else: + raise ValueError( + f"Unsupported embedding engine: {engine}. " + f"Try one of the following: openai, sentence" + ) + + return _instance + + +class OpenaiEmbedding(Embedding): + @property + def dimension(self) -> int: + return 1536 + + def get_embeddings(self, texts: List[str]) -> List[List[float]]: + return services.openai.get_embeddings(texts) + + +class SentenceTransformerEmbedding(Embedding): + def __init__(self): + from sentence_transformers import SentenceTransformer + self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') + + @property + def dimension(self) -> int: + return 384 + + def get_embeddings(self, texts: List[str]) -> List[List[float]]: + return self.model.encode(texts).tolist()