
Commit 111dd90

baskaryan and ccurme authored
openai[patch]: support structured output and tools (#30581)
Co-authored-by: ccurme <chester.curme@gmail.com>
1 parent 32f7695 commit 111dd90

File tree

2 files changed: +93 -2 lines changed

  • libs/partners/openai


libs/partners/openai/langchain_openai/chat_models/base.py

Lines changed: 58 additions & 2 deletions
@@ -1477,6 +1477,7 @@ def with_structured_output(
         ] = "function_calling",
         include_raw: bool = False,
         strict: Optional[bool] = None,
+        tools: Optional[list] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
         """Model wrapper that returns outputs formatted to match the given schema.
@@ -1537,6 +1538,51 @@ def with_structured_output(
             - None:
                 ``strict`` argument will not be passed to the model.
 
+            tools:
+                A list of tool-like objects to bind to the chat model. Requires that:
+
+                - ``method`` is ``"json_schema"`` (default).
+                - ``strict=True``
+                - ``include_raw=True``
+
+                If a model elects to call a
+                tool, the resulting ``AIMessage`` in ``"raw"`` will include tool calls.
+
+                .. dropdown:: Example
+
+                    .. code-block:: python
+
+                        from langchain.chat_models import init_chat_model
+                        from pydantic import BaseModel
+
+
+                        class ResponseSchema(BaseModel):
+                            response: str
+
+
+                        def get_weather(location: str) -> str:
+                            \"\"\"Get weather at a location.\"\"\"
+                            pass
+
+                        llm = init_chat_model("openai:gpt-4o-mini")
+
+                        structured_llm = llm.with_structured_output(
+                            ResponseSchema,
+                            tools=[get_weather],
+                            strict=True,
+                            include_raw=True,
+                        )
+
+                        structured_llm.invoke("What's the weather in Boston?")
+
+                    .. code-block:: python
+
+                        {
+                            "raw": AIMessage(content="", tool_calls=[...], ...),
+                            "parsing_error": None,
+                            "parsed": None,
+                        }
+
             kwargs: Additional keyword args aren't supported.
 
         Returns:
@@ -1558,6 +1604,9 @@ def with_structured_output(
 
             Support for ``strict`` argument added.
             Support for ``method`` = "json_schema" added.
+
+        .. versionchanged:: 0.3.12
+            Support for ``tools`` added.
         """  # noqa: E501
         if kwargs:
             raise ValueError(f"Received unsupported arguments {kwargs}")
@@ -1642,13 +1691,18 @@ def with_structured_output(
                     "Received None."
                 )
             response_format = _convert_to_openai_response_format(schema, strict=strict)
-            llm = self.bind(
+            bind_kwargs = dict(
                 response_format=response_format,
                 ls_structured_output_format={
                     "kwargs": {"method": method, "strict": strict},
                     "schema": convert_to_openai_tool(schema),
                 },
             )
+            if tools:
+                bind_kwargs["tools"] = [
+                    convert_to_openai_tool(t, strict=strict) for t in tools
+                ]
+            llm = self.bind(**bind_kwargs)
             if is_pydantic_schema:
                 output_parser = RunnableLambda(
                     partial(_oai_structured_outputs_parser, schema=cast(type, schema))
@@ -2776,14 +2830,16 @@ def _convert_to_openai_response_format(
 
 def _oai_structured_outputs_parser(
     ai_msg: AIMessage, schema: Type[_BM]
-) -> PydanticBaseModel:
+) -> Optional[PydanticBaseModel]:
     if parsed := ai_msg.additional_kwargs.get("parsed"):
         if isinstance(parsed, dict):
             return schema(**parsed)
         else:
             return parsed
     elif ai_msg.additional_kwargs.get("refusal"):
         raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
+    elif ai_msg.tool_calls:
+        return None
     else:
         raise ValueError(
             "Structured Output response does not have a 'parsed' field nor a 'refusal' "

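Taken together, these changes let one structured-output runnable also surface tool calls: when the model elects to call a tool, "parsed" comes back as None and the tool calls are available on the raw AIMessage. A minimal caller-side sketch of that branching (hypothetical tool and schema names; assumes langchain-openai >= 0.3.12 and an OPENAI_API_KEY in the environment):

# Hypothetical example, not part of this commit.
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Answer(BaseModel):
    answer: str


@tool
def get_weather(location: str) -> str:
    """Get the weather at a location."""
    return f"Sunny in {location}."


structured_llm = ChatOpenAI(model="gpt-4o-mini").with_structured_output(
    Answer,
    tools=[get_weather],
    strict=True,
    include_raw=True,
)

result = structured_llm.invoke("What's the weather in Boston?")
if result["parsed"] is None:
    # The model chose to call a tool; the calls live on the raw AIMessage.
    for tool_call in result["raw"].tool_calls:
        print(tool_call["name"], tool_call["args"])
else:
    # The model answered directly in the requested schema.
    print(result["parsed"].answer)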
libs/partners/openai/tests/integration_tests/chat_models/test_base.py

Lines changed: 35 additions & 0 deletions
@@ -1265,3 +1265,38 @@ class ResponseFormat(BaseModel):
     assert len(full.tool_calls) == 1
     tool_call = full.tool_calls[0]
     assert tool_call["name"] == "GenerateUsername"
+
+
+def test_tools_and_structured_output() -> None:
+    class ResponseFormat(BaseModel):
+        response: str
+        explanation: str
+
+    llm = ChatOpenAI(model="gpt-4o-mini").with_structured_output(
+        ResponseFormat, strict=True, include_raw=True, tools=[GenerateUsername]
+    )
+
+    expected_keys = {"raw", "parsing_error", "parsed"}
+    query = "Hello"
+    tool_query = "Generate a user name for Alice, black hair. Use the tool."
+    # Test invoke
+    ## Engage structured output
+    response = llm.invoke(query)
+    assert isinstance(response["parsed"], ResponseFormat)
+    ## Engage tool calling
+    response_tools = llm.invoke(tool_query)
+    ai_msg = response_tools["raw"]
+    assert isinstance(ai_msg, AIMessage)
+    assert ai_msg.tool_calls
+    assert response_tools["parsed"] is None
+
+    # Test stream
+    aggregated: dict = {}
+    for chunk in llm.stream(tool_query):
+        assert isinstance(chunk, dict)
+        assert all(key in expected_keys for key in chunk)
+        aggregated = {**aggregated, **chunk}
+    assert all(key in aggregated for key in expected_keys)
+    assert isinstance(aggregated["raw"], AIMessage)
+    assert aggregated["raw"].tool_calls
+    assert aggregated["parsed"] is None

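As the test shows, a tool call leaves "parsed" as None, so the caller is expected to run the tool and send its result back before a structured answer appears. A hedged continuation of the sketch after the first file (reusing structured_llm, get_weather, and result from that sketch; exact message-passing details may vary by langchain-core version):

# Continuation of the hypothetical sketch above, not part of this commit.
from langchain_core.messages import HumanMessage

messages = [HumanMessage("What's the weather in Boston?"), result["raw"]]

# Invoking a @tool with a ToolCall dict returns a ToolMessage tied to that
# call's id, which is what the model expects to see on the next turn.
for tool_call in result["raw"].tool_calls:
    messages.append(get_weather.invoke(tool_call))

# Second round: with the tool output in context, the model can now emit the
# structured answer, so "parsed" should be populated.
followup = structured_llm.invoke(messages)
print(followup["parsed"])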
0 commit comments
