
Commit

add structured output to openai
bfdykstra committed Jan 6, 2025
1 parent 5343d0d commit 3ebc021
Showing 6 changed files with 260 additions and 3 deletions.
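The diffs below add a StructuredOutputChatOpenAI chat model plus a matching StructuredOutputLLMInterface return type. As a rough usage sketch (assuming the api_key/model parameters and the plain-string call path are inherited from ChatOpenAI; the Ticket schema and values are purely illustrative):

from pydantic import BaseModel

from kotaemon.llms import StructuredOutputChatOpenAI

class Ticket(BaseModel):
    title: str
    priority: int

llm = StructuredOutputChatOpenAI(
    api_key="sk-...",        # assumed: inherited from ChatOpenAI
    model="gpt-4o-mini",     # assumed: any model with structured-output support
    response_schema=Ticket,  # the new required Param introduced by this commit
)

result = llm("Extract a ticket from: the login page crashes on submit, fix ASAP")
print(result.parsed)  # a validated Ticket instance parsed by the OpenAI SDK

response_schema is the only new required parameter; everything else behaves like ChatOpenAI, except that streaming is stripped from the request.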
2 changes: 2 additions & 0 deletions libs/kotaemon/kotaemon/base/__init__.py
@@ -7,6 +7,7 @@
ExtractorOutput,
HumanMessage,
LLMInterface,
StructuredOutputLLMInterface,
RetrievedDocument,
SystemMessage,
)
@@ -21,6 +22,7 @@
"HumanMessage",
"RetrievedDocument",
"LLMInterface",
"StructuredOutputLLMInterface",
"ExtractorOutput",
"Param",
"Node",
4 changes: 4 additions & 0 deletions libs/kotaemon/kotaemon/base/schema.py
@@ -142,6 +142,10 @@ class LLMInterface(AIMessage):
messages: list[AIMessage] = Field(default_factory=list)
logprobs: list[float] = []

class StructuredOutputLLMInterface(LLMInterface):
    """LLM response that also carries the parsed structured output and any refusal message."""

    parsed: Any
    refusal: str = ""


class ExtractorOutput(Document):
"""
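StructuredOutputLLMInterface only adds the parsed and refusal fields on top of LLMInterface. A small sketch of how a populated instance reads (the Person model and values are made up; the remaining LLMInterface fields are assumed to have defaults):

from pydantic import BaseModel

from kotaemon.base import StructuredOutputLLMInterface

class Person(BaseModel):
    name: str
    age: int

out = StructuredOutputLLMInterface(
    content='{"name": "Ada", "age": 36}',
    parsed=Person(name="Ada", age=36),
)
print(out.parsed.age)  # 36
print(out.refusal)     # "" by default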
2 changes: 2 additions & 0 deletions libs/kotaemon/kotaemon/llms/__init__.py
@@ -6,6 +6,7 @@
AzureChatOpenAI,
ChatLLM,
ChatOpenAI,
StructuredOutputChatOpenAI,
EndpointChatLLM,
LCAnthropicChat,
LCAzureChatOpenAI,
@@ -30,6 +31,7 @@
"SystemMessage",
"AzureChatOpenAI",
"ChatOpenAI",
"StructuredOutputChatOpenAI",
"LCAnthropicChat",
"LCGeminiChat",
"LCCohereChat",
3 changes: 2 additions & 1 deletion libs/kotaemon/kotaemon/llms/chats/__init__.py
@@ -9,14 +9,15 @@
LCGeminiChat,
)
from .llamacpp import LlamaCppChat
from .openai import AzureChatOpenAI, ChatOpenAI
from .openai import AzureChatOpenAI, ChatOpenAI, StructuredOutputChatOpenAI

__all__ = [
"ChatOpenAI",
"AzureChatOpenAI",
"ChatLLM",
"EndpointChatLLM",
"ChatOpenAI",
"StructuredOutputChatOpenAI",
"LCAnthropicChat",
"LCGeminiChat",
"LCCohereChat",
92 changes: 90 additions & 2 deletions libs/kotaemon/kotaemon/llms/chats/openai.py
@@ -1,10 +1,11 @@
from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional
from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional, Type

from theflow.utils.modules import import_dotted_string

from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param
from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param, StructuredOutputLLMInterface

from .base import ChatLLM
from pydantic import BaseModel

if TYPE_CHECKING:
from openai.types.chat.chat_completion_message_param import (
@@ -328,6 +329,93 @@ def openai_response(self, client, **kwargs):
async def aopenai_response(self, client, **kwargs):
params = self.prepare_params(**kwargs)
return await client.chat.completions.create(**params)


class StructuredOutputChatOpenAI(ChatOpenAI):
"""OpenAI chat model that returns structured output"""
    response_schema: Type[BaseModel] = Param(
        help="class that subclasses pydantic's BaseModel", required=True
    )


def prepare_output(self, resp: dict) -> StructuredOutputLLMInterface:
"""Convert the OpenAI response into StructuredOutputLLMInterface"""
additional_kwargs = {}

        if "parsed" in resp["choices"][0]["message"]:
            additional_kwargs["parsed"] = resp["choices"][0]["message"][
"parsed"
]

if "tool_calls" in resp["choices"][0]["message"]:
additional_kwargs["tool_calls"] = resp["choices"][0]["message"][
"tool_calls"
]

if resp["choices"][0].get("logprobs") is None:
logprobs = []
else:
all_logprobs = resp["choices"][0]["logprobs"].get("content")
logprobs = (
[logprob["logprob"] for logprob in all_logprobs] if all_logprobs else []
)

output = StructuredOutputLLMInterface(
**additional_kwargs, # TODO: clarify how additional_kwargs is used - diff bw BaseChatOpenAI usage and here
candidates=[(_["message"]["content"] or "") for _ in resp["choices"]],
content=resp["choices"][0]["message"]["content"] or "",
total_tokens=resp["usage"]["total_tokens"],
prompt_tokens=resp["usage"]["prompt_tokens"],
completion_tokens=resp["usage"]["completion_tokens"],
messages=[
AIMessage(content=(_["message"]["content"]) or "")
for _ in resp["choices"]
],
            logprobs=logprobs,
        )

return output

def prepare_params(self, **kwargs):
if "tools_pydantic" in kwargs:
kwargs.pop("tools_pydantic")

params_ = {
"model": self.model,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"n": self.n,
"stop": self.stop,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"tool_choice": self.tool_choice,
"tools": self.tools,
"logprobs": self.logprobs,
"logit_bias": self.logit_bias,
"top_logprobs": self.top_logprobs,
"top_p": self.top_p,
"response_format": self.response_schema,
}
params = {k: v for k, v in params_.items() if v is not None}
params.update(kwargs)

        # structured output parsing does not support streaming
        params.pop("stream", None)

return params

def openai_response(self, client, **kwargs):
"""Get the openai response"""
params = self.prepare_params(**kwargs)

return client.beta.chat.completions.parse(**params)


async def aopenai_response(self, client, **kwargs):
"""Get the openai response"""
params = self.prepare_params(**kwargs)

return await client.beta.chat.completions.parse(**params)



class AzureChatOpenAI(BaseChatOpenAI):
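Both openai_response and aopenai_response hand the prepared params to the OpenAI SDK's structured-output endpoint, which accepts a Pydantic class as response_format and returns choices whose message carries parsed and refusal. Roughly, with an illustrative schema and model name:

from openai import OpenAI
from pydantic import BaseModel

class Ticket(BaseModel):
    title: str
    priority: int

client = OpenAI()
completion = client.beta.chat.completions.parse(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Extract a ticket: the login page crashes on submit"}],
    response_format=Ticket,
)
message = completion.choices[0].message
print(message.parsed)   # Ticket instance, or None if the model refused
print(message.refusal)  # refusal text when the model declines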
