Skip to content

Commit ac0d4f3

Browse files
committed
allow non leaf outputs in supercomponents
1 parent 185e1c7 commit ac0d4f3

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

haystack/core/super_component/super_component.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def __init__(
129129
self._original_input_mapping = input_mapping
130130

131131
# Set output types based on pipeline and mapping
132-
pipeline_outputs = self.pipeline.outputs()
132+
pipeline_outputs = self.pipeline.outputs(include_components_with_connected_outputs=True)
133133
resolved_output_mapping = (
134134
output_mapping if output_mapping is not None else self._create_output_mapping(pipeline_outputs)
135135
)
@@ -189,9 +189,14 @@ def run(self, **kwargs: Any) -> Dict[str, Any]:
189189
"""
190190
filtered_inputs = {param: value for param, value in kwargs.items() if value != _delegate_default}
191191
pipeline_inputs = self._map_explicit_inputs(input_mapping=self.input_mapping, inputs=filtered_inputs)
192-
pipeline_outputs = self.pipeline.run(data=pipeline_inputs)
192+
include_outputs_from = self._get_include_outputs_from()
193+
pipeline_outputs = self.pipeline.run(data=pipeline_inputs, include_outputs_from=include_outputs_from)
193194
return self._map_explicit_outputs(pipeline_outputs, self.output_mapping)
194195

196+
def _get_include_outputs_from(self) -> set[str]:
197+
# Collecting the component names from output_mapping
198+
return {self._split_component_path(path)[0] for path in self.output_mapping.keys()}
199+
195200
async def run_async(self, **kwargs: Any) -> Dict[str, Any]:
196201
"""
197202
Runs the wrapped pipeline with the provided inputs async.

test/core/super_component/test_super_component.py

+10
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,13 @@ def from_dict(cls, data):
278278

279279
assert custom_serialized["type"] == "test_super_component.CustomSuperComponent"
280280
assert custom_super_component._to_super_component_dict() == serialized
281+
282+
def test_super_component_non_leaf_output(self, rag_pipeline):
283+
# 'retriever' is not a leaf, but should now be allowed
284+
output_mapping = {"retriever.documents": "retrieved_docs", "answer_builder.answers": "final_answers"}
285+
wrapper = SuperComponent(pipeline=rag_pipeline, output_mapping=output_mapping)
286+
wrapper.warm_up()
287+
result = wrapper.run(query="What is the capital of France?")
288+
assert "final_answers" in result # leaf output
289+
assert "retrieved_docs" in result # non-leaf output
290+
assert isinstance(result["retrieved_docs"][0], Document)

0 commit comments

Comments
 (0)