feat: add support for gemini llm via langchain
jaredpek committed Sep 8, 2024
1 parent 643bd78 commit 62ef8e8
Showing 6 changed files with 91 additions and 20 deletions.
20 changes: 12 additions & 8 deletions flowsettings.py
@@ -184,14 +184,17 @@
     },
     "default": False,
 }
-# KH_LLMS["gemini"] = {
-#     "spec": {
-#         "__type__": "kotaemon.llms.chats.LCGeminiChat",
-#         "model_name": "gemini-1.5-pro",
-#         "api_key": "your-key",
-#     },
-#     "default": False,
-# }
+
+KH_LLMS["gemini"] = {
+    "spec": {
+        "__type__": "kotaemon.llms.chats.LCChatGemini",
+        "model": "models/gemini-1.5-flash",
+        "google_api_key": "your-key",
+        "temperature": 0.7,
+    },
+    "default": False,
+}
+
 KH_LLMS["groq"] = {
     "spec": {
         "__type__": "kotaemon.llms.ChatOpenAI",
@@ -211,6 +214,7 @@
     },
     "default": False,
 }
+
 # KH_EMBEDDINGS["huggingface"] = {
 #     "spec": {
 #         "__type__": "kotaemon.embeddings.LCHuggingFaceEmbeddings",
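For reference, the new spec maps directly onto LangChain's Gemini chat model. A minimal sketch of the equivalent direct call, assuming the langchain-google-genai package is installed (the API key is a placeholder, exactly as in the spec):

from langchain_google_genai import ChatGoogleGenerativeAI

# Roughly what kotaemon builds from the "spec" entry above
llm = ChatGoogleGenerativeAI(
    model="models/gemini-1.5-flash",
    google_api_key="your-key",  # placeholder, as in flowsettings.py
    temperature=0.7,
)
print(llm.invoke("Hello, Gemini").content)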
4 changes: 2 additions & 2 deletions libs/kotaemon/kotaemon/llms/__init__.py
@@ -9,8 +9,8 @@
     EndpointChatLLM,
     LCAnthropicChat,
     LCAzureChatOpenAI,
+    LCChatGemini,
     LCChatOpenAI,
-    LCGeminiChat,
     LlamaCppChat,
 )
 from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
@@ -30,7 +30,7 @@
     "AzureChatOpenAI",
     "ChatOpenAI",
     "LCAnthropicChat",
-    "LCGeminiChat",
+    "LCChatGemini",
     "LCAzureChatOpenAI",
     "LCChatOpenAI",
     "LlamaCppChat",
4 changes: 2 additions & 2 deletions libs/kotaemon/kotaemon/llms/chats/__init__.py
@@ -3,9 +3,9 @@
 from .langchain_based import (
     LCAnthropicChat,
     LCAzureChatOpenAI,
+    LCChatGemini,
     LCChatMixin,
     LCChatOpenAI,
-    LCGeminiChat,
 )
 from .llamacpp import LlamaCppChat
 from .openai import AzureChatOpenAI, ChatOpenAI
@@ -17,7 +17,7 @@
     "EndpointChatLLM",
     "ChatOpenAI",
     "LCAnthropicChat",
-    "LCGeminiChat",
+    "LCChatGemini",
     "LCChatOpenAI",
     "LCAzureChatOpenAI",
     "LCChatMixin",
18 changes: 12 additions & 6 deletions libs/kotaemon/kotaemon/llms/chats/langchain_based.py
@@ -4,6 +4,7 @@
 from typing import AsyncGenerator, Iterator
 
 from kotaemon.base import BaseMessage, HumanMessage, LLMInterface
+from kotaemon.platforms.gemini import BaseGeminiChatModel
 
 from .base import ChatLLM
 
@@ -247,18 +248,23 @@ def _get_lc_class(self):
         return ChatAnthropic
 
 
-class LCGeminiChat(LCChatMixin, ChatLLM):  # type: ignore
+class LCChatGemini(LCChatMixin, ChatLLM, BaseGeminiChatModel):  # type: ignore
+    """Gemini Chat Model
+    https://python.langchain.com/v0.2/api_reference/google_genai/chat_models/langchain_google_genai.chat_models.ChatGoogleGenerativeAI.html
+    """
+
     def __init__(
         self,
-        api_key: str | None = None,
-        model_name: str | None = None,
+        google_api_key: str | None = None,
+        model: str | None = None,
         temperature: float = 0.7,
         **params,
     ):
         super().__init__(
-            google_api_key=api_key,
-            model=model_name,
-            temperature=temperature,
+            google_api_key=google_api_key or self.google_api_key,
+            model=model or self.model,
+            temperature=temperature or self.temperature,
             **params,
         )
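After the rename, the wrapper takes the same keyword names as the underlying LangChain class and falls back to the Param defaults declared on BaseGeminiChatModel for anything omitted. A usage sketch (the key is a placeholder):

from kotaemon.llms import LCChatGemini

# Omitted arguments fall back to the inherited Param defaults:
# model="models/gemini-1.5-flash", temperature=0.7
llm = LCChatGemini(google_api_key="your-key")

One caveat: "x or self.x" treats every falsy value as missing, so an explicit temperature=0.0 would be silently replaced by the 0.7 default.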
61 changes: 61 additions & 0 deletions libs/kotaemon/kotaemon/platforms/gemini.py
@@ -0,0 +1,61 @@
+from typing import Optional
+
+from kotaemon.base import Param
+
+
+def get_model(type):
+    types = {"embedding": "text-embedding-004", "chat": "gemini-1.5-flash"}
+    return Param(
+        f"models/{types.get(type)}",
+        help=(
+            f"The name of the required {type} model "
+            "(https://ai.google.dev/gemini-api/docs/models/gemini). "
+            "Must be in the format 'models/...'."
+        ),
+    )
+
+
+class BaseGeminiModel:
+    google_api_key: str = Param("", help="The API key received from Gemini.")
+
+
+class BaseGeminiChatModel(BaseGeminiModel):
+    model: str = get_model("chat")
+
+    temperature: Optional[float] = Param(
+        0.7,
+        help="Run inference with this temperature. Must be between 0.0 and 1.0.",
+    )
+
+    top_k: Optional[float] = Param(
+        None,
+        help=(
+            "Decode using top-k sampling by considering the set of "
+            "top_k most probable tokens. Must be larger than 0."
+        ),
+    )
+
+    top_p: Optional[float] = Param(
+        None,
+        help=(
+            "Decode using nucleus sampling by considering the smallest set of "
+            "tokens whose probability sum is at least top_p. "
+            "Must be between 0.0 and 1.0."
+        ),
+    )
+
+    max_output_tokens: Optional[float] = Param(
+        None,
+        help=(
+            "Maximum number of tokens to include in a candidate. "
+            "Must be greater than zero. If unset, will default to 64."
+        ),
+    )
+
+    max_retries: Optional[float] = Param(
+        None, help="The maximum number of retries to make when generating."
+    )
+
+    timeout: Optional[float] = Param(
+        None, help="The maximum number of seconds to wait for a response."
+    )
4 changes: 2 additions & 2 deletions libs/ktem/ktem/llms/manager.py
@@ -58,15 +58,15 @@ def load_vendors(self):
             AzureChatOpenAI,
             ChatOpenAI,
             LCAnthropicChat,
-            LCGeminiChat,
+            LCChatGemini,
             LlamaCppChat,
         )
 
         self._vendors = [
             ChatOpenAI,
             AzureChatOpenAI,
             LCAnthropicChat,
-            LCGeminiChat,
+            LCChatGemini,
             LlamaCppChat,
         ]
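The manager imports the vendor class by name, so the rename must be applied consistently at every export site touched above. A quick sanity check, hypothetical and not part of the commit:

# Both export paths should resolve to the same renamed class.
from kotaemon.llms import LCChatGemini
from kotaemon.llms.chats import LCChatGemini as FromChats

assert LCChatGemini is FromChats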
