Skip to content

Commit

Permalink
Retain return type from @dispatcher.span (#17817)
Browse files Browse the repository at this point in the history
  • Loading branch information
GICodeWarrior authored Feb 17, 2025
1 parent 85a0046 commit e157ebb
Show file tree
Hide file tree
Showing 18 changed files with 145 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def stream_chat(
)

response = synthesizer.synthesize(message, context_nodes)
assert isinstance(response, StreamingResponse)

def wrapped_gen(response: StreamingResponse) -> ChatResponseGen:
full_response = ""
Expand Down Expand Up @@ -405,6 +406,7 @@ async def astream_chat(
)

response = await synthesizer.asynthesize(message, context_nodes)
assert isinstance(response, AsyncStreamingResponse)

async def wrapped_gen(response: AsyncStreamingResponse) -> ChatResponseAsyncGen:
full_response = ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def _aget_retrieved_ids_and_texts(

return (
[node.node.node_id for node in retrieved_nodes],
[node.node.text for node in retrieved_nodes],
[node.text for node in retrieved_nodes],
)


Expand Down Expand Up @@ -84,7 +84,7 @@ async def _aget_retrieved_ids_and_texts(
node = scored_node.node
if isinstance(node, ImageNode):
image_nodes.append(node)
if node.text:
if isinstance(node, TextNode):
text_nodes.append(node)

if mode == "text":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
(similar with contrastive learning)
"""

from typing import Any, Callable, Dict, List, Optional, Sequence, cast
from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, cast

from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs
from llama_index.core.bridge.pydantic import (
Expand All @@ -33,7 +33,7 @@
from llama_index.core.prompts import PromptTemplate
from llama_index.core.schema import BaseNode, TextNode
from llama_index.core.settings import Settings
from llama_index.core.types import BasePydanticProgram
from llama_index.core.types import BasePydanticProgram, Model

DEFAULT_TITLE_NODE_TEMPLATE = """\
Context: {context_str}. Give a title that summarizes all of \
Expand Down Expand Up @@ -462,15 +462,15 @@ async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
"""


class PydanticProgramExtractor(BaseExtractor):
class PydanticProgramExtractor(BaseExtractor, Generic[Model]):
"""Pydantic program extractor.
Uses an LLM to extract out a Pydantic object. Return attributes of that object
in a dictionary.
"""

program: SerializeAsAny[BasePydanticProgram] = Field(
program: SerializeAsAny[BasePydanticProgram[Model]] = Field(
..., description="Pydantic program to extract."
)
input_key: str = Field(
Expand Down Expand Up @@ -500,7 +500,9 @@ async def _acall_program(self, node: BaseNode) -> Dict[str, Any]:
)

ret_object = await self.program.acall(**{self.input_key: extract_str})
return ret_object.dict()
assert not isinstance(ret_object, list)

return ret_object.model_dump()

async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
"""Extract pydantic program."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,7 @@ async def abuild_index_from_nodes(
self._llm.apredict(self.summary_prompt, context_str=text_chunk)
for text_chunk in text_chunks_progress
]
outputs: List[Tuple[str, str]] = await asyncio.gather(*tasks)
summaries = [output[0] for output in outputs]
summaries = await asyncio.gather(*tasks)

event.on_end(payload={"summaries": summaries, "level": level})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import partial
from contextlib import contextmanager
from contextvars import Context, ContextVar, Token, copy_context
from typing import Any, Callable, Generator, List, Optional, Dict, Protocol
from typing import Any, Callable, Generator, List, Optional, Dict, Protocol, TypeVar
import inspect
import logging
import uuid
Expand All @@ -26,6 +26,7 @@
active_instrument_tags: ContextVar[Dict[str, Any]] = ContextVar(
"instrument_tags", default={}
)
_R = TypeVar("_R")


@contextmanager
Expand Down Expand Up @@ -239,7 +240,7 @@ def span_exit(
else:
c = c.parent

def span(self, func: Callable) -> Any:
def span(self, func: Callable[..., _R]) -> Callable[..., _R]:
# The `span` decorator should be idempotent.
try:
if hasattr(func, DISPATCHER_SPAN_DECORATED_ATTR):
Expand Down
29 changes: 19 additions & 10 deletions llama-index-core/llama_index/core/llms/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@

if TYPE_CHECKING:
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.program.utils import FlexibleModel
from llama_index.core.tools.types import BaseTool
from llama_index.core.llms.structured_llm import StructuredLLM

Expand Down Expand Up @@ -322,11 +323,11 @@ def _as_query_component(self, **kwargs: Any) -> QueryComponent:
@dispatcher.span
def structured_predict(
self,
output_cls: Type[BaseModel],
output_cls: Type[Model],
prompt: PromptTemplate,
llm_kwargs: Optional[Dict[str, Any]] = None,
**prompt_args: Any,
) -> BaseModel:
) -> Model:
r"""Structured predict.
Args:
Expand Down Expand Up @@ -372,17 +373,19 @@ class Test(BaseModel):
)

result = program(llm_kwargs=llm_kwargs, **prompt_args)
assert not isinstance(result, list)

dispatcher.event(LLMStructuredPredictEndEvent(output=result))
return result

@dispatcher.span
async def astructured_predict(
self,
output_cls: Type[BaseModel],
output_cls: Type[Model],
prompt: PromptTemplate,
llm_kwargs: Optional[Dict[str, Any]] = None,
**prompt_args: Any,
) -> BaseModel:
) -> Model:
r"""Async Structured predict.
Args:
Expand Down Expand Up @@ -429,17 +432,19 @@ class Test(BaseModel):
)

result = await program.acall(llm_kwargs=llm_kwargs, **prompt_args)
assert not isinstance(result, list)

dispatcher.event(LLMStructuredPredictEndEvent(output=result))
return result

@dispatcher.span
def stream_structured_predict(
self,
output_cls: Type[BaseModel],
output_cls: Type[Model],
prompt: PromptTemplate,
llm_kwargs: Optional[Dict[str, Any]] = None,
**prompt_args: Any,
) -> Generator[Union[Model, List[Model]], None, None]:
) -> Generator[Union[Model, "FlexibleModel"], None, None]:
r"""Stream Structured predict.
Args:
Expand Down Expand Up @@ -489,18 +494,19 @@ class Test(BaseModel):
result = program.stream_call(llm_kwargs=llm_kwargs, **prompt_args)
for r in result:
dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
assert not isinstance(r, list)
yield r

dispatcher.event(LLMStructuredPredictEndEvent(output=r))

@dispatcher.span
async def astream_structured_predict(
self,
output_cls: Type[BaseModel],
output_cls: Type[Model],
prompt: PromptTemplate,
llm_kwargs: Optional[Dict[str, Any]] = None,
**prompt_args: Any,
) -> AsyncGenerator[Union[Model, List[Model]], None]:
) -> AsyncGenerator[Union[Model, "FlexibleModel"], None]:
r"""Async Stream Structured predict.
Args:
Expand Down Expand Up @@ -534,8 +540,10 @@ class Test(BaseModel):
```
"""

async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]:
from llama_index.core.program.utils import get_program_for_llm
async def gen() -> AsyncGenerator[Union[Model, "FlexibleModel"], None]:
from llama_index.core.program.utils import (
get_program_for_llm,
)

dispatcher.event(
LLMStructuredPredictStartEvent(
Expand All @@ -552,6 +560,7 @@ async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]:
result = await program.astream_call(llm_kwargs=llm_kwargs, **prompt_args)
async for r in result:
dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
assert not isinstance(r, list)
yield r

dispatcher.event(LLMStructuredPredictEndEvent(output=r))
Expand Down
6 changes: 3 additions & 3 deletions llama-index-core/llama_index/core/output_parsers/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Pydantic output parser."""

import json
from typing import Any, List, Optional, Type
from typing import Any, Generic, List, Optional, Type

from llama_index.core.output_parsers.base import ChainableOutputParser
from llama_index.core.output_parsers.utils import extract_json_str
Expand All @@ -15,7 +15,7 @@
"""


class PydanticOutputParser(ChainableOutputParser):
class PydanticOutputParser(ChainableOutputParser, Generic[Model]):
"""Pydantic Output Parser.
Args:
Expand All @@ -36,7 +36,7 @@ def __init__(

@property
def output_cls(self) -> Type[Model]:
return self._output_cls # type: ignore
return self._output_cls

@property
def format_string(self) -> str:
Expand Down
Loading

0 comments on commit e157ebb

Please sign in to comment.