Stronger function calling typing
GICodeWarrior committed Feb 15, 2025
1 parent 614f609 commit 4e9d83a
Showing 6 changed files with 96 additions and 75 deletions.
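
The diff below replaces bare `BaseModel` annotations in the structured-prediction path with the `Model` TypeVar, so the concrete class a caller passes as `output_cls` is the class a type checker sees coming back. A minimal sketch of the pattern with a toy stand-in function (the real TypeVar is defined in LlamaIndex core; the stand-in below is illustrative only):

```python
from typing import Type, TypeVar

from pydantic import BaseModel

# Bound TypeVar: Type[Model] accepts any Pydantic class, and that same class
# propagates into the return annotation for the caller.
Model = TypeVar("Model", bound=BaseModel)


def predict_structured(output_cls: Type[Model], raw_json: str) -> Model:
    """Toy stand-in for structured prediction: validate JSON into output_cls."""
    return output_cls.model_validate_json(raw_json)


class Song(BaseModel):
    name: str
    artist: str


# A type checker now infers `song: Song` instead of `song: BaseModel`.
song = predict_structured(Song, '{"name": "Tide", "artist": "Moon"}')
```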
30 changes: 19 additions & 11 deletions llama-index-core/llama_index/core/llms/llm.py
@@ -73,6 +73,7 @@

if TYPE_CHECKING:
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.program.utils import FlexibleModel
from llama_index.core.tools.types import BaseTool
from llama_index.core.llms.structured_llm import StructuredLLM

@@ -322,11 +323,11 @@ def _as_query_component(self, **kwargs: Any) -> QueryComponent:
@dispatcher.span
def structured_predict(
self,
output_cls: Type[BaseModel],
output_cls: Type[Model],
prompt: PromptTemplate,
llm_kwargs: Optional[Dict[str, Any]] = None,
**prompt_args: Any,
) -> BaseModel:
) -> Model:
r"""Structured predict.
Args:
@@ -372,17 +373,18 @@ class Test(BaseModel):
)

result = program(llm_kwargs=llm_kwargs, **prompt_args)
assert isinstance(result, output_cls)
dispatcher.event(LLMStructuredPredictEndEvent(output=result))
return result

@dispatcher.span
async def astructured_predict(
self,
output_cls: Type[BaseModel],
output_cls: Type[Model],
prompt: PromptTemplate,
llm_kwargs: Optional[Dict[str, Any]] = None,
**prompt_args: Any,
) -> BaseModel:
) -> Model:
r"""Async Structured predict.
Args:
@@ -429,17 +431,18 @@ class Test(BaseModel):
)

result = await program.acall(llm_kwargs=llm_kwargs, **prompt_args)
assert isinstance(result, output_cls)
dispatcher.event(LLMStructuredPredictEndEvent(output=result))
return result

@dispatcher.span
def stream_structured_predict(
self,
output_cls: Type[BaseModel],
output_cls: Type[Model],
prompt: PromptTemplate,
llm_kwargs: Optional[Dict[str, Any]] = None,
**prompt_args: Any,
) -> Generator[Union[Model, List[Model]], None, None]:
) -> Generator[Union[Model, "FlexibleModel"], None, None]:
r"""Stream Structured predict.
Args:
@@ -472,7 +475,7 @@ class Test(BaseModel):
print(partial_output.name)
```
"""
from llama_index.core.program.utils import get_program_for_llm
from llama_index.core.program.utils import FlexibleModel, get_program_for_llm

dispatcher.event(
LLMStructuredPredictStartEvent(
@@ -488,6 +491,7 @@ class Test(BaseModel):

result = program.stream_call(llm_kwargs=llm_kwargs, **prompt_args)
for r in result:
assert isinstance(r, (FlexibleModel, output_cls))
dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
yield r

@@ -496,11 +500,11 @@ class Test(BaseModel):
@dispatcher.span
async def astream_structured_predict(
self,
output_cls: Type[BaseModel],
output_cls: Type[Model],
prompt: PromptTemplate,
llm_kwargs: Optional[Dict[str, Any]] = None,
**prompt_args: Any,
) -> AsyncGenerator[Union[Model, List[Model]], None]:
) -> AsyncGenerator[Union[Model, "FlexibleModel"], None]:
r"""Async Stream Structured predict.
Args:
Expand Down Expand Up @@ -534,8 +538,11 @@ class Test(BaseModel):
```
"""

async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]:
from llama_index.core.program.utils import get_program_for_llm
async def gen() -> AsyncGenerator[Union[Model, "FlexibleModel"], None]:
from llama_index.core.program.utils import (
FlexibleModel,
get_program_for_llm,
)

dispatcher.event(
LLMStructuredPredictStartEvent(
@@ -551,6 +558,7 @@ async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]:

result = await program.astream_call(llm_kwargs=llm_kwargs, **prompt_args)
async for r in result:
assert isinstance(r, (FlexibleModel, output_cls))
dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
yield r

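
With the annotations above, `structured_predict` and its async and streaming variants are typed in terms of the caller's `output_cls`, and the streaming variants now declare that they may yield `FlexibleModel` partials until a complete object validates. A hedged usage sketch (the OpenAI LLM, model name, and prompt text are illustrative assumptions; any function-calling LLM would do):

```python
from pydantic import BaseModel

from llama_index.core.prompts import PromptTemplate
from llama_index.llms.openai import OpenAI


class Song(BaseModel):
    name: str
    artist: str


llm = OpenAI(model="gpt-4o-mini")  # assumes OPENAI_API_KEY is set
prompt = PromptTemplate("Write a song about {topic}.")

# The result is now typed as Song rather than BaseModel, so attribute access
# type-checks without casts or asserts at the call site.
song = llm.structured_predict(Song, prompt, topic="the sea")
print(song.name, song.artist)

# Streaming yields partial objects (FlexibleModel, with every field optional)
# until enough of the response has arrived to validate a full Song.
for partial in llm.stream_structured_predict(Song, prompt, topic="the sea"):
    print(partial.name)
```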
6 changes: 3 additions & 3 deletions llama-index-core/llama_index/core/output_parsers/pydantic.py
@@ -1,7 +1,7 @@
"""Pydantic output parser."""

import json
from typing import Any, List, Optional, Type
from typing import Any, Generic, List, Optional, Type

from llama_index.core.output_parsers.base import ChainableOutputParser
from llama_index.core.output_parsers.utils import extract_json_str
Expand All @@ -15,7 +15,7 @@
"""


class PydanticOutputParser(ChainableOutputParser):
class PydanticOutputParser(ChainableOutputParser, Generic[Model]):
"""Pydantic Output Parser.
Args:
Expand All @@ -36,7 +36,7 @@ def __init__(

@property
def output_cls(self) -> Type[Model]:
return self._output_cls # type: ignore
return self._output_cls

@property
def format_string(self) -> str:
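
Making `PydanticOutputParser` generic over `Model` lets `output_cls` and parse results keep the concrete type instead of decaying to `BaseModel`. A small self-contained sketch (the album data is made up for illustration):

```python
from pydantic import BaseModel

from llama_index.core.output_parsers import PydanticOutputParser


class Album(BaseModel):
    title: str
    year: int


# The parser is now PydanticOutputParser[Album], so parser.output_cls is
# Type[Album] rather than Type[BaseModel].
parser = PydanticOutputParser(output_cls=Album)

# format() appends the JSON-schema instructions to a query; parse() extracts
# the JSON block from raw LLM text and validates it into an Album.
prompt_with_schema = parser.format("Describe a classic jazz album.")
album = parser.parse('Sure: {"title": "Kind of Blue", "year": 1959}')
print(album.title, album.year)
```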
90 changes: 48 additions & 42 deletions llama-index-core/llama_index/core/program/function_program.py
@@ -14,7 +14,6 @@
)

from llama_index.core.bridge.pydantic import (
BaseModel,
ValidationError,
)
from llama_index.core.base.llms.types import ChatResponse
@@ -26,39 +25,22 @@
from llama_index.core.tools.function_tool import FunctionTool
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.program.utils import (
FlexibleModel,
process_streaming_objects,
num_valid_fields,
)

_logger = logging.getLogger(__name__)


def _parse_tool_outputs(
agent_response: AgentChatResponse,
allow_parallel_tool_calls: bool = False,
) -> Union[BaseModel, List[BaseModel]]:
"""Parse tool outputs."""
outputs = [cast(BaseModel, s.raw_output) for s in agent_response.sources]
if allow_parallel_tool_calls:
return outputs
else:
if len(outputs) > 1:
_logger.warning(
"Multiple outputs found, returning first one. "
"If you want to return all outputs, set output_multiple=True."
)

return outputs[0]


def get_function_tool(output_cls: Type[Model]) -> FunctionTool:
"""Get function tool."""
schema = output_cls.model_json_schema()
schema_description = schema.get("description", None)

# NOTE: this does not specify the schema in the function signature,
# so instead we'll directly provide it in the fn_schema in the ToolMetadata
def model_fn(**kwargs: Any) -> BaseModel:
def model_fn(**kwargs: Any) -> Model:
"""Model function."""
return output_cls(**kwargs)

@@ -70,7 +52,7 @@ def model_fn(**kwargs: Any) -> BaseModel:
)


class FunctionCallingProgram(BasePydanticProgram[BaseModel]):
class FunctionCallingProgram(BasePydanticProgram[Model]):
"""Function Calling Program.
Uses function calling LLMs to obtain a structured output.
@@ -122,7 +104,7 @@ def from_defaults(
prompt = PromptTemplate(prompt_template_str)

return cls(
output_cls=output_cls, # type: ignore
output_cls=output_cls,
llm=llm, # type: ignore
prompt=cast(PromptTemplate, prompt),
tool_choice=tool_choice,
Expand All @@ -131,7 +113,7 @@ def from_defaults(
)

@property
def output_cls(self) -> Type[BaseModel]:
def output_cls(self) -> Type[Model]:
return self._output_cls

@property
@@ -147,7 +129,7 @@ def __call__(
*args: Any,
llm_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> BaseModel:
) -> Union[Model, List[Model]]:
llm_kwargs = llm_kwargs or {}
tool = get_function_tool(self._output_cls)

@@ -161,17 +143,17 @@ def __call__(
allow_parallel_tool_calls=self._allow_parallel_tool_calls,
**llm_kwargs,
)
return _parse_tool_outputs(
return self._parse_tool_outputs(
agent_response,
allow_parallel_tool_calls=self._allow_parallel_tool_calls,
) # type: ignore
)

async def acall(
self,
*args: Any,
llm_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> BaseModel:
) -> Union[Model, List[Model]]:
llm_kwargs = llm_kwargs or {}
tool = get_function_tool(self._output_cls)

@@ -182,16 +164,34 @@ async def acall(
allow_parallel_tool_calls=self._allow_parallel_tool_calls,
**llm_kwargs,
)
return _parse_tool_outputs(
return self._parse_tool_outputs(
agent_response,
allow_parallel_tool_calls=self._allow_parallel_tool_calls,
) # type: ignore
)

def _parse_tool_outputs(
self,
agent_response: AgentChatResponse,
allow_parallel_tool_calls: bool = False,
) -> Union[Model, List[Model]]:
"""Parse tool outputs."""
outputs = [cast(Model, s.raw_output) for s in agent_response.sources]
if allow_parallel_tool_calls:
return outputs
else:
if len(outputs) > 1:
_logger.warning(
"Multiple outputs found, returning first one. "
"If you want to return all outputs, set output_multiple=True."
)

return outputs[0]

def _process_objects(
self,
chat_response: ChatResponse,
output_cls: Type[BaseModel],
cur_objects: Optional[List[BaseModel]] = None,
output_cls: Type[Model],
cur_objects: Optional[List[Model]] = None,
) -> Union[Model, List[Model]]:
"""Process stream."""
tool_calls = self._llm.get_tool_calls_from_response(
@@ -202,7 +202,7 @@ def _process_objects(
# TODO: change
if len(tool_calls) == 0:
# if no tool calls, return single blank output_class
return output_cls() # type: ignore
return output_cls()

tool_fn_args = [call.tool_kwargs for call in tool_calls]
objects = [
@@ -222,22 +222,24 @@ def _process_objects(
new_obj = self._output_cls.model_validate(obj.model_dump())
except ValidationError as e:
_logger.warning(f"Failed to parse object: {e}")
new_obj = obj # type: ignore
new_obj = obj
new_cur_objects.append(new_obj)

if self._allow_parallel_tool_calls:
return new_cur_objects # type: ignore
return new_cur_objects
else:
if len(new_cur_objects) > 1:
_logger.warning(
"Multiple outputs found, returning first one. "
"If you want to return all outputs, set output_multiple=True."
)
return new_cur_objects[0] # type: ignore
return new_cur_objects[0]

def stream_call( # type: ignore
def stream_call(
self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> Generator[Union[Model, List[Model]], None, None]:
) -> Generator[
Union[Model, List[Model], FlexibleModel, List[FlexibleModel]], None, None
]:
"""Stream object.
Returns a generator returning partials of the same object
@@ -273,14 +275,16 @@ def stream_call( # type: ignore
llm=self._llm,
)
cur_objects = objects if isinstance(objects, list) else [objects]
yield objects # type: ignore
yield objects
except Exception as e:
_logger.warning(f"Failed to parse streaming response: {e}")
continue

async def astream_call( # type: ignore
async def astream_call(
self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> AsyncGenerator[Union[Model, List[Model]], None]:
) -> AsyncGenerator[
Union[Model, List[Model], FlexibleModel, List[FlexibleModel]], None
]:
"""Stream objects.
Returns a generator returning partials of the same object
@@ -302,7 +306,9 @@ async def astream_call( # type: ignore
**(llm_kwargs or {}),
)

async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]:
async def gen() -> AsyncGenerator[
Union[Model, List[Model], FlexibleModel, List[FlexibleModel]], None
]:
cur_objects = None
async for partial_resp in chat_response_gen:
try:
Expand All @@ -315,7 +321,7 @@ async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]:
llm=self._llm,
)
cur_objects = objects if isinstance(objects, list) else [objects]
yield objects # type: ignore
yield objects
except Exception as e:
_logger.warning(f"Failed to parse streaming response: {e}")
continue
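
`FunctionCallingProgram` is now parameterized over the same `Model`, its `_parse_tool_outputs` helper became a method, and the streaming methods advertise that they may yield `FlexibleModel` partials before a fully valid object is available. A usage sketch under the same assumptions as above (OpenAI-style function-calling LLM; prompt and invoice text are illustrative):

```python
from pydantic import BaseModel

from llama_index.core.program import FunctionCallingProgram
from llama_index.llms.openai import OpenAI


class Invoice(BaseModel):
    number: str
    total: float


llm = OpenAI(model="gpt-4o-mini")  # assumes OPENAI_API_KEY is set

program = FunctionCallingProgram.from_defaults(
    output_cls=Invoice,
    prompt_template_str="Extract the invoice details from: {text}",
    llm=llm,
)

# Returns an Invoice (or a list of Invoice when allow_parallel_tool_calls=True).
invoice = program(text="Invoice #42, total due $123.45")
print(invoice)

# stream_call yields partial objects first (FlexibleModel, all fields optional),
# then fully validated Invoice instances once parsing succeeds.
for partial in program.stream_call(text="Invoice #42, total due $123.45"):
    print(partial)
```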