Skip to content

Commit 9d3262c

Browse files
mahirshahccurme
andauthored
core: Propagate config_factories in RunnableBinding (#30603)
- **Description:** Propagates config_factories when calling decoration methods for RunnableBinding--e.g. bind, with_config, with_types, with_retry, and with_listeners. This ensures that configs attached to the original RunnableBinding are kept when creating the new RunnableBinding and the configs are merged during invocation. Picks up where #30551 left off. - **Issue:** #30531 Co-authored-by: ccurme <chester.curme@gmail.com>
1 parent 8a69de5 commit 9d3262c

File tree

2 files changed

+83
-12
lines changed

2 files changed

+83
-12
lines changed

libs/core/langchain_core/runnables/base.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5766,6 +5766,7 @@ def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
57665766
return self.__class__(
57675767
bound=self.bound,
57685768
config=self.config,
5769+
config_factories=self.config_factories,
57695770
kwargs={**self.kwargs, **kwargs},
57705771
custom_input_type=self.custom_input_type,
57715772
custom_output_type=self.custom_output_type,
@@ -5782,6 +5783,7 @@ def with_config(
57825783
bound=self.bound,
57835784
kwargs=self.kwargs,
57845785
config=cast("RunnableConfig", {**self.config, **(config or {}), **kwargs}),
5786+
config_factories=self.config_factories,
57855787
custom_input_type=self.custom_input_type,
57865788
custom_output_type=self.custom_output_type,
57875789
)
@@ -5817,22 +5819,23 @@ def with_listeners(
58175819
"""
58185820
from langchain_core.tracers.root_listeners import RootListenersTracer
58195821

5822+
def listener_config_factory(config: RunnableConfig) -> RunnableConfig:
5823+
return {
5824+
"callbacks": [
5825+
RootListenersTracer(
5826+
config=config,
5827+
on_start=on_start,
5828+
on_end=on_end,
5829+
on_error=on_error,
5830+
)
5831+
],
5832+
}
5833+
58205834
return self.__class__(
58215835
bound=self.bound,
58225836
kwargs=self.kwargs,
58235837
config=self.config,
5824-
config_factories=[
5825-
lambda config: {
5826-
"callbacks": [
5827-
RootListenersTracer(
5828-
config=config,
5829-
on_start=on_start,
5830-
on_end=on_end,
5831-
on_error=on_error,
5832-
)
5833-
],
5834-
}
5835-
],
5838+
config_factories=[listener_config_factory] + self.config_factories,
58365839
custom_input_type=self.custom_input_type,
58375840
custom_output_type=self.custom_output_type,
58385841
)
@@ -5847,6 +5850,7 @@ def with_types(
58475850
bound=self.bound,
58485851
kwargs=self.kwargs,
58495852
config=self.config,
5853+
config_factories=self.config_factories,
58505854
custom_input_type=(
58515855
input_type if input_type is not None else self.custom_input_type
58525856
),
@@ -5861,6 +5865,7 @@ def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
58615865
bound=self.bound.with_retry(**kwargs),
58625866
kwargs=self.kwargs,
58635867
config=self.config,
5868+
config_factories=self.config_factories,
58645869
)
58655870

58665871
@override

libs/core/tests/unit_tests/runnables/test_runnable.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,6 +1712,72 @@ async def test_with_listeners_async(mocker: MockerFixture) -> None:
17121712
assert mock_end.call_count == 1
17131713

17141714

1715+
def test_with_listener_propagation(mocker: MockerFixture) -> None:
1716+
prompt = (
1717+
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
1718+
+ "{question}"
1719+
)
1720+
chat = FakeListChatModel(responses=["foo"])
1721+
chain: Runnable = prompt | chat
1722+
mock_start = mocker.Mock()
1723+
mock_end = mocker.Mock()
1724+
chain_with_listeners = chain.with_listeners(on_start=mock_start, on_end=mock_end)
1725+
1726+
chain_with_listeners.with_retry().invoke({"question": "Who are you?"})
1727+
1728+
assert mock_start.call_count == 1
1729+
assert mock_start.call_args[0][0].name == "RunnableSequence"
1730+
assert mock_end.call_count == 1
1731+
1732+
mock_start.reset_mock()
1733+
mock_end.reset_mock()
1734+
1735+
chain_with_listeners.with_types(output_type=str).invoke(
1736+
{"question": "Who are you?"}
1737+
)
1738+
1739+
assert mock_start.call_count == 1
1740+
assert mock_start.call_args[0][0].name == "RunnableSequence"
1741+
assert mock_end.call_count == 1
1742+
1743+
mock_start.reset_mock()
1744+
mock_end.reset_mock()
1745+
1746+
chain_with_listeners.with_config({"tags": ["foo"]}).invoke(
1747+
{"question": "Who are you?"}
1748+
)
1749+
1750+
assert mock_start.call_count == 1
1751+
assert mock_start.call_args[0][0].name == "RunnableSequence"
1752+
assert mock_end.call_count == 1
1753+
1754+
mock_start.reset_mock()
1755+
mock_end.reset_mock()
1756+
1757+
chain_with_listeners.bind(stop=["foo"]).invoke({"question": "Who are you?"})
1758+
1759+
assert mock_start.call_count == 1
1760+
assert mock_start.call_args[0][0].name == "RunnableSequence"
1761+
assert mock_end.call_count == 1
1762+
1763+
mock_start.reset_mock()
1764+
mock_end.reset_mock()
1765+
1766+
mock_start_inner = mocker.Mock()
1767+
mock_end_inner = mocker.Mock()
1768+
1769+
chain_with_listeners.with_listeners(
1770+
on_start=mock_start_inner, on_end=mock_end_inner
1771+
).invoke({"question": "Who are you?"})
1772+
1773+
assert mock_start.call_count == 1
1774+
assert mock_start.call_args[0][0].name == "RunnableSequence"
1775+
assert mock_end.call_count == 1
1776+
assert mock_start_inner.call_count == 1
1777+
assert mock_start_inner.call_args[0][0].name == "RunnableSequence"
1778+
assert mock_end_inner.call_count == 1
1779+
1780+
17151781
@freeze_time("2023-01-01")
17161782
def test_prompt_with_chat_model(
17171783
mocker: MockerFixture,

0 commit comments

Comments
 (0)