
Retain return type from @dispatcher.span #17817

Merged · 14 commits · Feb 17, 2025
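In short: `Dispatcher.span` was previously typed `(func: Callable) -> Any`, so every function decorated with `@dispatcher.span` lost its static return type. This PR threads a `TypeVar` through the decorator so the wrapped function's return type survives, then fixes the type errors that become visible once the checker can see through the decorator: `isinstance` narrowing in the chat engine and retriever code, `Generic[Model]` parameters on `PydanticProgramExtractor` and `PydanticOutputParser`, `Type[Model]` signatures on the `structured_predict` family in `llm.py`, and a stale annotation on an `asyncio.gather` call in the tree index builder.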
@@ -347,6 +347,7 @@ def stream_chat(
         )

         response = synthesizer.synthesize(message, context_nodes)
+        assert isinstance(response, StreamingResponse)

         def wrapped_gen(response: StreamingResponse) -> ChatResponseGen:
             full_response = ""
@@ -405,6 +406,7 @@ async def astream_chat(
         )

         response = await synthesizer.asynthesize(message, context_nodes)
+        assert isinstance(response, AsyncStreamingResponse)

         async def wrapped_gen(response: AsyncStreamingResponse) -> ChatResponseAsyncGen:
             full_response = ""
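The two `assert isinstance(...)` additions are needed because `synthesize`/`asynthesize` are declared to return a union of response types, while `wrapped_gen` accepts only the streaming variant; once `@dispatcher.span` stops erasing types, the checker flags the mismatch. A minimal sketch of the narrowing pattern, with stand-in classes rather than the real llama-index types:

```python
from typing import Union


class Response:
    """Stand-in for a non-streaming response type."""


class StreamingResponse(Response):
    """Stand-in for the streaming response type."""


def synthesize(streaming: bool = True) -> Union[Response, StreamingResponse]:
    # Declared as a union, so callers statically see both possibilities.
    return StreamingResponse() if streaming else Response()


response = synthesize()
# Narrows the union for the type checker, and fails fast at runtime
# if a non-streaming response ever slips through.
assert isinstance(response, StreamingResponse)
```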
@@ -45,7 +45,7 @@ async def _aget_retrieved_ids_and_texts(

         return (
             [node.node.node_id for node in retrieved_nodes],
-            [node.node.text for node in retrieved_nodes],
+            [node.text for node in retrieved_nodes],
         )


@@ -84,7 +84,7 @@ async def _aget_retrieved_ids_and_texts(
             node = scored_node.node
             if isinstance(node, ImageNode):
                 image_nodes.append(node)
-            if node.text:
+            if isinstance(node, TextNode):
                 text_nodes.append(node)

         if mode == "text":
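Besides satisfying the type checker (`BaseNode` itself does not promise a `.text` attribute), the switch from `if node.text:` to `isinstance(node, TextNode)` also removes a truthiness quirk: a `TextNode` whose text is an empty string was previously dropped. A self-contained sketch with simplified stand-in classes (the real schema's class hierarchy relates `ImageNode` and `TextNode`, which this sketch ignores):

```python
from dataclasses import dataclass


@dataclass
class BaseNode:
    """Stand-in base class; note it has no `text` attribute."""


@dataclass
class TextNode(BaseNode):
    text: str = ""


@dataclass
class ImageNode(BaseNode):
    image: bytes = b""


nodes: list[BaseNode] = [TextNode(text=""), TextNode(text="hi"), ImageNode()]

# isinstance() both proves `.text` exists to the type checker and, unlike
# the old truthiness test, keeps TextNodes whose text happens to be empty.
text_nodes = [n for n in nodes if isinstance(n, TextNode)]
assert len(text_nodes) == 2
```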
@@ -20,7 +20,7 @@
 (similar with contrastive learning)
 """

-from typing import Any, Callable, Dict, List, Optional, Sequence, cast
+from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, cast

 from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs
 from llama_index.core.bridge.pydantic import (
@@ -33,7 +33,7 @@
 from llama_index.core.prompts import PromptTemplate
 from llama_index.core.schema import BaseNode, TextNode
 from llama_index.core.settings import Settings
-from llama_index.core.types import BasePydanticProgram
+from llama_index.core.types import BasePydanticProgram, Model

 DEFAULT_TITLE_NODE_TEMPLATE = """\
 Context: {context_str}. Give a title that summarizes all of \
@@ -462,15 +462,15 @@ async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
     """


-class PydanticProgramExtractor(BaseExtractor):
+class PydanticProgramExtractor(BaseExtractor, Generic[Model]):
     """Pydantic program extractor.

     Uses an LLM to extract out a Pydantic object. Return attributes of that object
     in a dictionary.
     """

-    program: SerializeAsAny[BasePydanticProgram] = Field(
+    program: SerializeAsAny[BasePydanticProgram[Model]] = Field(
         ..., description="Pydantic program to extract."
     )
     input_key: str = Field(
@@ -500,7 +500,9 @@ async def _acall_program(self, node: BaseNode) -> Dict[str, Any]:
         )

         ret_object = await self.program.acall(**{self.input_key: extract_str})
-        return ret_object.dict()
+        assert not isinstance(ret_object, list)
+
+        return ret_object.model_dump()

     async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
         """Extract pydantic program."""
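Two things happen in `_acall_program`: the new `assert` narrows the program result, whose declared type allows either a single model or a list of models, down to the single-model case, and the deprecated Pydantic v1 `.dict()` is replaced by its v2 equivalent `.model_dump()`. A small sketch of the same narrow-then-dump pattern, using a made-up model and a stub in place of `BasePydanticProgram.acall`:

```python
from typing import List, Union

from pydantic import BaseModel


class NodeMetadata(BaseModel):
    """Made-up model for illustration."""

    title: str


def fake_program_call() -> Union[NodeMetadata, List[NodeMetadata]]:
    # Stub for BasePydanticProgram.acall, whose declared result may be
    # one model or a list of them.
    return NodeMetadata(title="example")


ret_object = fake_program_call()
assert not isinstance(ret_object, list)  # rule out the list branch
print(ret_object.model_dump())  # Pydantic v2 API; v1's .dict() is deprecated
```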
@@ -219,8 +219,7 @@ async def abuild_index_from_nodes(
             self._llm.apredict(self.summary_prompt, context_str=text_chunk)
             for text_chunk in text_chunks_progress
         ]
-        outputs: List[Tuple[str, str]] = await asyncio.gather(*tasks)
-        summaries = [output[0] for output in outputs]
+        summaries = await asyncio.gather(*tasks)

         event.on_end(payload={"summaries": summaries, "level": level})
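The deleted `List[Tuple[str, str]]` annotation looks like a leftover from an older `apredict` that returned an `(output, formatted_prompt)` pair; with `apredict` returning a plain `str`, `asyncio.gather` already produces the list of summaries, and `output[0]` would slice each summary down to its first character. A runnable sketch with a stub predictor:

```python
import asyncio
from typing import List


async def apredict(context_str: str) -> str:
    # Stub for LLM.apredict, which resolves to the completion string.
    return f"summary of {context_str!r}"


async def main() -> None:
    tasks = [apredict(chunk) for chunk in ("chunk a", "chunk b")]
    summaries: List[str] = await asyncio.gather(*tasks)
    # The removed `[output[0] for output in outputs]` would have reduced
    # each of these strings to its first character.
    print(summaries)


asyncio.run(main())
```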
@@ -2,7 +2,7 @@
 from functools import partial
 from contextlib import contextmanager
 from contextvars import Context, ContextVar, Token, copy_context
-from typing import Any, Callable, Generator, List, Optional, Dict, Protocol
+from typing import Any, Callable, Generator, List, Optional, Dict, Protocol, TypeVar
 import inspect
 import logging
 import uuid
@@ -26,6 +26,7 @@
 active_instrument_tags: ContextVar[Dict[str, Any]] = ContextVar(
     "instrument_tags", default={}
 )
+_R = TypeVar("_R")


 @contextmanager
@@ -239,7 +240,7 @@ def span_exit(
         else:
             c = c.parent

-    def span(self, func: Callable) -> Any:
+    def span(self, func: Callable[..., _R]) -> Callable[..., _R]:
         # The `span` decorator should be idempotent.
         try:
             if hasattr(func, DISPATCHER_SPAN_DECORATED_ATTR):
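This hunk is the heart of the PR: `span` now maps `Callable[..., _R]` to `Callable[..., _R]` instead of `Callable` to `Any`, so decorated functions keep their return types under static analysis. A minimal sketch of the pattern (not the actual dispatcher internals, which also handle idempotency and async functions):

```python
from functools import wraps
from typing import Any, Callable, TypeVar

_R = TypeVar("_R")


def span(func: Callable[..., _R]) -> Callable[..., _R]:
    # The TypeVar threads the wrapped function's return type through the
    # decorator; with `-> Any`, every decorated call site degraded to Any.
    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> _R:
        # ...span enter/exit bookkeeping elided...
        return func(*args, **kwargs)

    return wrapper


@span
def add(a: int, b: int) -> int:
    return a + b


total: int = add(1, 2)  # now type-checks as int, not Any
```

Note that `Callable[..., _R]` still erases parameter types; a `ParamSpec` would preserve those as well, at the cost of requiring newer `typing` support.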
29 changes: 19 additions & 10 deletions llama-index-core/llama_index/core/llms/llm.py
@@ -73,6 +73,7 @@

 if TYPE_CHECKING:
     from llama_index.core.chat_engine.types import AgentChatResponse
+    from llama_index.core.program.utils import FlexibleModel
     from llama_index.core.tools.types import BaseTool
     from llama_index.core.llms.structured_llm import StructuredLLM

@@ -322,11 +323,11 @@ def _as_query_component(self, **kwargs: Any) -> QueryComponent:
     @dispatcher.span
     def structured_predict(
         self,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: PromptTemplate,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **prompt_args: Any,
-    ) -> BaseModel:
+    ) -> Model:
         r"""Structured predict.

         Args:
@@ -372,17 +373,19 @@ class Test(BaseModel):
         )

         result = program(llm_kwargs=llm_kwargs, **prompt_args)
+        assert not isinstance(result, list)
+
         dispatcher.event(LLMStructuredPredictEndEvent(output=result))
         return result

     @dispatcher.span
     async def astructured_predict(
         self,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: PromptTemplate,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **prompt_args: Any,
-    ) -> BaseModel:
+    ) -> Model:
         r"""Async Structured predict.

         Args:
@@ -429,17 +432,19 @@ class Test(BaseModel):
         )

         result = await program.acall(llm_kwargs=llm_kwargs, **prompt_args)
+        assert not isinstance(result, list)
+
         dispatcher.event(LLMStructuredPredictEndEvent(output=result))
         return result

     @dispatcher.span
     def stream_structured_predict(
         self,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: PromptTemplate,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **prompt_args: Any,
-    ) -> Generator[Union[Model, List[Model]], None, None]:
+    ) -> Generator[Union[Model, "FlexibleModel"], None, None]:
         r"""Stream Structured predict.

         Args:
@@ -489,18 +494,19 @@ class Test(BaseModel):
         result = program.stream_call(llm_kwargs=llm_kwargs, **prompt_args)
         for r in result:
             dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
+            assert not isinstance(r, list)
             yield r

         dispatcher.event(LLMStructuredPredictEndEvent(output=r))

     @dispatcher.span
     async def astream_structured_predict(
         self,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: PromptTemplate,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **prompt_args: Any,
-    ) -> AsyncGenerator[Union[Model, List[Model]], None]:
+    ) -> AsyncGenerator[Union[Model, "FlexibleModel"], None]:
         r"""Async Stream Structured predict.

         Args:
@@ -534,8 +540,10 @@ class Test(BaseModel):
         ```
         """

-        async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]:
-            from llama_index.core.program.utils import get_program_for_llm
+        async def gen() -> AsyncGenerator[Union[Model, "FlexibleModel"], None]:
+            from llama_index.core.program.utils import (
+                get_program_for_llm,
+            )

             dispatcher.event(
                 LLMStructuredPredictStartEvent(
@@ -552,6 +560,7 @@ async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]:
             result = await program.astream_call(llm_kwargs=llm_kwargs, **prompt_args)
             async for r in result:
                 dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
+                assert not isinstance(r, list)
                 yield r

             dispatcher.event(LLMStructuredPredictEndEvent(output=r))
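The practical effect of the `Type[Model]`/`Model` signatures: `structured_predict(SomeModel, prompt)` is now statically typed as returning `SomeModel` rather than `BaseModel`, and the streaming variants yield `Union[Model, FlexibleModel]`, where `FlexibleModel` appears to stand for partially parsed objects emitted mid-stream. A hedged usage sketch; `Invoice` is a made-up model and `llm`/`prompt` are assumed to exist:

```python
from pydantic import BaseModel


class Invoice(BaseModel):
    """Made-up output model for illustration."""

    vendor: str
    total: float


# Assuming `llm` is an LLM instance and `prompt` a PromptTemplate:
# invoice = llm.structured_predict(Invoice, prompt, text=document_text)
# invoice.total  # type-checks: the result is an Invoice, not a bare BaseModel
#
# for partial in llm.stream_structured_predict(Invoice, prompt, text=document_text):
#     ...  # each `partial` is Union[Invoice, FlexibleModel]; narrow before use
```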
6 changes: 3 additions & 3 deletions llama-index-core/llama_index/core/output_parsers/pydantic.py
@@ -1,7 +1,7 @@
 """Pydantic output parser."""

 import json
-from typing import Any, List, Optional, Type
+from typing import Any, Generic, List, Optional, Type

 from llama_index.core.output_parsers.base import ChainableOutputParser
 from llama_index.core.output_parsers.utils import extract_json_str
@@ -15,7 +15,7 @@
 """


-class PydanticOutputParser(ChainableOutputParser):
+class PydanticOutputParser(ChainableOutputParser, Generic[Model]):
     """Pydantic Output Parser.

     Args:
@@ -36,7 +36,7 @@ def __init__(

     @property
     def output_cls(self) -> Type[Model]:
-        return self._output_cls  # type: ignore
+        return self._output_cls

     @property
     def format_string(self) -> str:
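Making `PydanticOutputParser` `Generic[Model]` is what lets `output_cls` drop its `# type: ignore`: with the class parameterized, `self._output_cls` is a `Type[Model]` by construction. A hedged usage sketch, assuming the constructor takes the output class as the `__init__` context in this diff suggests; `Album` is a made-up model:

```python
from pydantic import BaseModel


class Album(BaseModel):
    """Made-up output model for illustration."""

    title: str
    artist: str


# parser = PydanticOutputParser(Album)
# parser.output_cls   # now typed Type[Album], no ignore comment needed
# album = parser.parse('{"title": "Abbey Road", "artist": "The Beatles"}')
# `album` is statically an Album once the parser is parameterized.
```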