From 87c609cbda37923fc474a6d0954e3a767f09f0d1 Mon Sep 17 00:00:00 2001
From: JonahSussman
Date: Tue, 18 Feb 2025 14:39:44 -0500
Subject: [PATCH] :sparkles: Added VLLMOpenAI

Signed-off-by: JonahSussman
---
 .trunk/configs/custom-words.txt       |  2 ++
 kai/llm_interfacing/model_provider.py | 31 +++++++++++++++++++++------
 2 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/.trunk/configs/custom-words.txt b/.trunk/configs/custom-words.txt
index f7bbf59c..2ec2976f 100644
--- a/.trunk/configs/custom-words.txt
+++ b/.trunk/configs/custom-words.txt
@@ -121,6 +121,7 @@ Scrapy
 sdist
 semconv
 sessionbean
+SHFT
 shurley
 sirupsen
 smallrye
@@ -145,6 +146,7 @@ tiiuae
 tracesdk
 upperbound
 venv
+vllm
 webassets
 webmvc
 SHFT
diff --git a/kai/llm_interfacing/model_provider.py b/kai/llm_interfacing/model_provider.py
index 83d4bafa..d4abe499 100644
--- a/kai/llm_interfacing/model_provider.py
+++ b/kai/llm_interfacing/model_provider.py
@@ -5,8 +5,8 @@
 
 from langchain_aws import ChatBedrock
 from langchain_community.chat_models.fake import FakeListChatModel
-from langchain_core.language_models.base import LanguageModelInput
-from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_community.llms.vllm import VLLMOpenAI
+from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
 from langchain_core.messages import BaseMessage
 from langchain_core.runnables import ConfigurableField, RunnableConfig
 from langchain_deepseek import ChatDeepSeek
@@ -36,7 +36,7 @@ def __init__(
         self.demo_mode: bool = demo_mode
         self.cache = cache
 
-        model_class: type[BaseChatModel]
+        model_class: type[BaseLanguageModel[Any]]
         defaults: dict[str, Any]
         model_args: dict[str, Any]
         model_id: str
@@ -173,11 +173,19 @@ def _get_request_payload(
                 model_args = deep_update(defaults, config.args)
                 model_id = model_args["model"]
 
+            case "VLLMOpenAI":
+                model_class = VLLMOpenAI
+
+                defaults = {}
+
+                model_args = deep_update(defaults, config.args)
+                model_id = model_args["model_name"]
+
             case _:
                 raise Exception(f"Unrecognized provider '{config.provider}'")
 
         self.provider_id: str = config.provider
-        self.llm: BaseChatModel = model_class(**model_args)
+        self.llm: BaseLanguageModel[Any] = model_class(**model_args)
         self.model_id: str = model_id
 
         if config.template is None:
@@ -239,7 +247,9 @@ def invoke(
         invoke_llm = self.llm
 
         if not (self.cache and cache_path_resolver):
-            return invoke_llm.invoke(input, config, stop=stop, **kwargs)
+            return self.response_to_base_message(
+                invoke_llm.invoke(input, config, stop=stop, **kwargs)
+            )
 
         cache_path = cache_path_resolver.cache_path()
         cache_meta = cache_path_resolver.cache_meta()
@@ -263,4 +273,13 @@ def invoke(
             # only raise an exception when we are in demo mode
             if self.demo_mode:
                 raise e
-        return response
+
+        return self.response_to_base_message(response)
+
+    def response_to_base_message(self, llm_response: Any) -> BaseMessage:
+        if isinstance(llm_response, str):
+            return BaseMessage(content=llm_response)
+        elif isinstance(llm_response, BaseMessage):
+            return llm_response
+
+        raise ValueError(f"Unexpected LLM response type: {type(llm_response)}")
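
Note on the new branch: the "VLLMOpenAI" case forwards config.args straight into
langchain_community.llms.vllm.VLLMOpenAI and takes model_name as the model id.
Because VLLMOpenAI is a completion-style model rather than a chat model, its
invoke() returns a plain string, which is what the new response_to_base_message()
shim normalizes. The sketch below is illustrative only and assumes a locally
running vLLM server exposing the OpenAI-compatible API; the base URL, API key,
and model name are placeholders, not values taken from this patch.

# Sketch only: mirrors what the new "VLLMOpenAI" case constructs from config.args.
# The endpoint, key, and model name are placeholders for a local vLLM deployment.
from langchain_community.llms.vllm import VLLMOpenAI

llm = VLLMOpenAI(
    openai_api_base="http://localhost:8000/v1",  # vLLM's OpenAI-compatible endpoint
    openai_api_key="EMPTY",                      # vLLM servers usually ignore the key
    model_name="tiiuae/falcon-7b",               # read back as ModelProvider.model_id
)

# Completion-style models return a plain str from invoke(), not a BaseMessage,
# which is why ModelProvider.invoke() now wraps responses before returning them.
raw = llm.invoke("Reply with a single word: hello")
print(type(raw).__name__)  # str

Widening the annotations from BaseChatModel to BaseLanguageModel[Any] is what
lets the same model_class(**model_args) call cover both the existing chat models
and this completion-style model.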