✨ Added VLLMOpenAI

Signed-off-by: JonahSussman <sussmanjonah@gmail.com>

JonahSussman committed Feb 18, 2025
1 parent 47ee50b commit 87c609c
Showing 2 changed files with 27 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .trunk/configs/custom-words.txt
@@ -121,6 +121,7 @@ Scrapy
 sdist
 semconv
 sessionbean
+SHFT
 shurley
 sirupsen
 smallrye
@@ -145,6 +146,7 @@ tiiuae
 tracesdk
 upperbound
 venv
+vllm
 webassets
 webmvc
 SHFT
31 changes: 25 additions & 6 deletions kai/llm_interfacing/model_provider.py
@@ -5,8 +5,8 @@

 from langchain_aws import ChatBedrock
 from langchain_community.chat_models.fake import FakeListChatModel
-from langchain_core.language_models.base import LanguageModelInput
-from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_community.llms.vllm import VLLMOpenAI
+from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
 from langchain_core.messages import BaseMessage
 from langchain_core.runnables import ConfigurableField, RunnableConfig
 from langchain_deepseek import ChatDeepSeek
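
Why the import swap matters: VLLMOpenAI is a completion-style LLM, not a chat model, so it is not a BaseChatModel subclass; BaseLanguageModel is the narrowest common supertype covering both kinds of provider. A minimal sketch, assuming langchain-core and langchain-community are installed, that checks the hierarchy:

# VLLMOpenAI sits in LangChain's LLM branch rather than its chat-model branch,
# which is why the annotations later in this diff widen to BaseLanguageModel.
from langchain_community.llms.vllm import VLLMOpenAI
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel

assert issubclass(VLLMOpenAI, BaseLanguageModel)
assert not issubclass(VLLMOpenAI, BaseChatModel)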
@@ -36,7 +36,7 @@ def __init__(
         self.demo_mode: bool = demo_mode
         self.cache = cache
 
-        model_class: type[BaseChatModel]
+        model_class: type[BaseLanguageModel[Any]]
         defaults: dict[str, Any]
         model_args: dict[str, Any]
         model_id: str
@@ -173,11 +173,19 @@ def _get_request_payload(
                 model_args = deep_update(defaults, config.args)
                 model_id = model_args["model"]
 
+            case "VLLMOpenAI":
+                model_class = VLLMOpenAI
+
+                defaults = {}
+
+                model_args = deep_update(defaults, config.args)
+                model_id = model_args["model_name"]
+
             case _:
                 raise Exception(f"Unrecognized provider '{config.provider}'")
 
         self.provider_id: str = config.provider
-        self.llm: BaseChatModel = model_class(**model_args)
+        self.llm: BaseLanguageModel[Any] = model_class(**model_args)
         self.model_id: str = model_id
 
         if config.template is None:
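
For context, a hedged sketch of what the new "VLLMOpenAI" case ends up constructing. Since defaults is empty, every argument comes straight from config.args; the endpoint and model name below are placeholders, not values from this commit. vLLM serves an OpenAI-compatible API, which is why the wrapper takes OpenAI-style parameters:

from langchain_community.llms.vllm import VLLMOpenAI

# Standalone equivalent of model_class(**model_args) for this provider case.
llm = VLLMOpenAI(
    openai_api_key="EMPTY",  # vLLM servers typically ignore the key
    openai_api_base="http://localhost:8000/v1",  # placeholder endpoint
    model_name="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder model
)

print(llm.invoke("Hello"))  # completion-style: returns a plain str

Note the case reads model_args["model_name"] rather than model_args["model"], matching VLLMOpenAI's parameter name; the plain-str return value is also why invoke's result needs normalizing below.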
@@ -239,7 +247,9 @@ def invoke(
         invoke_llm = self.llm
 
         if not (self.cache and cache_path_resolver):
-            return invoke_llm.invoke(input, config, stop=stop, **kwargs)
+            return self.response_to_base_message(
+                invoke_llm.invoke(input, config, stop=stop, **kwargs)
+            )
 
         cache_path = cache_path_resolver.cache_path()
         cache_meta = cache_path_resolver.cache_meta()
@@ -263,4 +273,13 @@ def invoke(
             # only raise an exception when we are in demo mode
             if self.demo_mode:
                 raise e
-        return response
+
+        return self.response_to_base_message(response)
+
+    def response_to_base_message(self, llm_response: Any) -> BaseMessage:
+        if isinstance(llm_response, str):
+            return BaseMessage(content=llm_response)
+        elif isinstance(llm_response, BaseMessage):
+            return llm_response
+
+        raise ValueError(f"Unexpected LLM response type: {type(llm_response)}")
