Skip to content

Commit

Permalink
feat: add support for gemini embedding model via langchain
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredpek committed Sep 8, 2024
1 parent 62ef8e8 commit 671a3d8
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 3 deletions.
10 changes: 10 additions & 0 deletions flowsettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@
"default": False,
}

KH_EMBEDDINGS["gemini"] = {
"spec": {
"__type__": "kotaemon.embeddings.LCGeminiEmbeddings",
"model": "models/text-embedding-004",
"google_api_key": "your-key",
"task_type": "retrieval_document",
},
"default": False,
}

# KH_EMBEDDINGS["huggingface"] = {
# "spec": {
# "__type__": "kotaemon.embeddings.LCHuggingFaceEmbeddings",
Expand Down
6 changes: 4 additions & 2 deletions libs/kotaemon/kotaemon/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@
from .langchain_based import (
LCAzureOpenAIEmbeddings,
LCCohereEmbeddings,
LCGeminiEmbeddings,
LCHuggingFaceEmbeddings,
LCOpenAIEmbeddings,
)
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings

__all__ = [
"AzureOpenAIEmbeddings",
"BaseEmbeddings",
"EndpointEmbeddings",
"FastEmbedEmbeddings",
"LCOpenAIEmbeddings",
"LCAzureOpenAIEmbeddings",
"LCCohereEmbeddings",
"LCGeminiEmbeddings",
"LCHuggingFaceEmbeddings",
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"FastEmbedEmbeddings",
]
30 changes: 30 additions & 0 deletions libs/kotaemon/kotaemon/embeddings/langchain_based.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

from kotaemon.base import Document, DocumentWithEmbedding
from kotaemon.platforms.gemini import BaseGeminiEmbeddingModel

from .base import BaseEmbeddings

Expand Down Expand Up @@ -187,6 +188,35 @@ def _get_lc_class(self):
return CohereEmbeddings


class LCGeminiEmbeddings(LCEmbeddingMixin, BaseEmbeddings, BaseGeminiEmbeddingModel):
"""Wrapper around Langchain's Gemini embedding, focusing on key parameters
https://python.langchain.com/v0.2/api_reference/google_genai/embeddings/langchain_google_genai.embeddings.GoogleGenerativeAIEmbeddings.html
"""

def __init__(
self,
model: Optional[str] = None,
google_api_key: Optional[str] = None,
task_type: Optional[str] = None,
**params,
):
super().__init__(
model=model or self.model,
google_api_key=google_api_key or self.google_api_key,
task_type=task_type or self.task_type,
**params,
)

def _get_lc_class(self):
try:
from langchain_google_genai import GoogleGenerativeAIEmbeddings
except ImportError:
raise ImportError("Please install langchain-google-genai")

return GoogleGenerativeAIEmbeddings


class LCHuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's HuggingFace embedding, focusing on key parameters"""

Expand Down
14 changes: 14 additions & 0 deletions libs/kotaemon/kotaemon/platforms/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@ class BaseGeminiModel:
google_api_key: str = Param("", help="The API key received from Gemini.")


class BaseGeminiEmbeddingModel(BaseGeminiModel):
model: str = get_model("embedding")

task_type: str = Param(
"retrieval_document",
help=(
"The valid task type for embedding, "
"choices are 'retrieval_document' (default), 'retrieval_query', "
"'semantic_similarity', 'classification' 'clustering' and "
"'task_type_unspecified'."
),
)


class BaseGeminiChatModel(BaseGeminiModel):
model: str = get_model("chat")

Expand Down
4 changes: 3 additions & 1 deletion libs/ktem/ktem/embeddings/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,18 @@ def load_vendors(self):
AzureOpenAIEmbeddings,
FastEmbedEmbeddings,
LCCohereEmbeddings,
LCGeminiEmbeddings,
LCHuggingFaceEmbeddings,
OpenAIEmbeddings,
)

self._vendors = [
AzureOpenAIEmbeddings,
OpenAIEmbeddings,
FastEmbedEmbeddings,
LCCohereEmbeddings,
LCGeminiEmbeddings,
LCHuggingFaceEmbeddings,
OpenAIEmbeddings,
]

def __getitem__(self, key: str) -> BaseEmbeddings:
Expand Down

0 comments on commit 671a3d8

Please sign in to comment.