Skip to content

Commit 83edd3f

Browse files
committed
Reduce duplicate code
1 parent d360b18 commit 83edd3f

File tree

3 files changed

+38
-56
lines changed

3 files changed

+38
-56
lines changed

haystack/core/pipeline/async_pipeline.py

+4-26
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@
1111
from haystack.core.errors import PipelineMaxComponentRuns, PipelineRuntimeError
1212
from haystack.core.pipeline.base import (
1313
_COMPONENT_INPUT,
14-
_COMPONENT_NAME,
1514
_COMPONENT_OUTPUT,
16-
_COMPONENT_RUN,
17-
_COMPONENT_TYPE,
15+
_COMPONENT_VISITS,
1816
ComponentPriority,
1917
PipelineBase,
2018
)
@@ -57,28 +55,8 @@ async def _run_component_async( # pylint: disable=too-many-positional-arguments
5755
raise PipelineMaxComponentRuns(f"Max runs for '{component_name}' reached.")
5856

5957
instance: Component = component["instance"]
60-
with tracing.tracer.trace(
61-
_COMPONENT_RUN,
62-
tags={
63-
_COMPONENT_NAME: component_name,
64-
_COMPONENT_TYPE: instance.__class__.__name__,
65-
"haystack.component.input_types": {k: type(v).__name__ for k, v in component_inputs.items()},
66-
"haystack.component.input_spec": {
67-
key: {
68-
"type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
69-
"senders": value.senders,
70-
}
71-
for key, value in instance.__haystack_input__._sockets_dict.items() # type: ignore
72-
},
73-
"haystack.component.output_spec": {
74-
key: {
75-
"type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
76-
"receivers": value.receivers,
77-
}
78-
for key, value in instance.__haystack_output__._sockets_dict.items() # type: ignore
79-
},
80-
},
81-
parent_span=parent_span,
58+
with PipelineBase._create_component_span(
59+
component_name=component_name, instance=instance, inputs=component_inputs, parent_span=parent_span
8260
) as span:
8361
span.set_content_tag(_COMPONENT_INPUT, deepcopy(component_inputs))
8462
logger.info("Running component {component_name}", component_name=component_name)
@@ -97,7 +75,7 @@ async def _run_component_async( # pylint: disable=too-many-positional-arguments
9775
if not isinstance(outputs, dict):
9876
raise PipelineRuntimeError.from_invalid_output(component_name, instance.__class__, outputs)
9977

100-
span.set_tag("haystack.component.visits", component_visits[component_name])
78+
span.set_tag(_COMPONENT_VISITS, component_visits[component_name])
10179
span.set_content_tag(_COMPONENT_OUTPUT, deepcopy(outputs))
10280

10381
return outputs

haystack/core/pipeline/base.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import networkx # type:ignore
1414

15-
from haystack import logging
15+
from haystack import logging, tracing
1616
from haystack.core.component import Component, InputSocket, OutputSocket, component
1717
from haystack.core.errors import (
1818
DeserializationError,
@@ -54,11 +54,9 @@
5454

5555

5656
# Constants for tracing tags
57-
_COMPONENT_RUN = "haystack.component.run"
58-
_COMPONENT_NAME = "haystack.component.name"
59-
_COMPONENT_TYPE = "haystack.component.type"
6057
_COMPONENT_INPUT = "haystack.component.input"
6158
_COMPONENT_OUTPUT = "haystack.component.output"
59+
_COMPONENT_VISITS = "haystack.component.visits"
6260

6361

6462
class ComponentPriority(IntEnum):
@@ -776,6 +774,34 @@ def warm_up(self):
776774
logger.info("Warming up component {node}...", node=node)
777775
self.graph.nodes[node]["instance"].warm_up()
778776

777+
@staticmethod
778+
def _create_component_span(
779+
component_name: str, instance: Component, inputs: Dict[str, Any], parent_span: Optional[tracing.Span] = None
780+
):
781+
return tracing.tracer.trace(
782+
"haystack.component.run",
783+
tags={
784+
"haystack.component.name": component_name,
785+
"haystack.component.type": instance.__class__.__name__,
786+
"haystack.component.input_types": {k: type(v).__name__ for k, v in inputs.items()},
787+
"haystack.component.input_spec": {
788+
key: {
789+
"type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
790+
"senders": value.senders,
791+
}
792+
for key, value in instance.__haystack_input__._sockets_dict.items() # type: ignore
793+
},
794+
"haystack.component.output_spec": {
795+
key: {
796+
"type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
797+
"receivers": value.receivers,
798+
}
799+
for key, value in instance.__haystack_output__._sockets_dict.items() # type: ignore
800+
},
801+
},
802+
parent_span=parent_span,
803+
)
804+
779805
def _validate_input(self, data: Dict[str, Any]):
780806
"""
781807
Validates pipeline input data.

haystack/core/pipeline/pipeline.py

+4-26
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@
1010
from haystack.core.errors import PipelineRuntimeError
1111
from haystack.core.pipeline.base import (
1212
_COMPONENT_INPUT,
13-
_COMPONENT_NAME,
1413
_COMPONENT_OUTPUT,
15-
_COMPONENT_RUN,
16-
_COMPONENT_TYPE,
14+
_COMPONENT_VISITS,
1715
ComponentPriority,
1816
PipelineBase,
1917
)
@@ -51,28 +49,8 @@ def _run_component(
5149
"""
5250
instance: Component = component["instance"]
5351

54-
with tracing.tracer.trace(
55-
_COMPONENT_RUN,
56-
tags={
57-
_COMPONENT_NAME: component_name,
58-
_COMPONENT_TYPE: instance.__class__.__name__,
59-
"haystack.component.input_types": {k: type(v).__name__ for k, v in inputs.items()},
60-
"haystack.component.input_spec": {
61-
key: {
62-
"type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
63-
"senders": value.senders,
64-
}
65-
for key, value in instance.__haystack_input__._sockets_dict.items() # type: ignore
66-
},
67-
"haystack.component.output_spec": {
68-
key: {
69-
"type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
70-
"receivers": value.receivers,
71-
}
72-
for key, value in instance.__haystack_output__._sockets_dict.items() # type: ignore
73-
},
74-
},
75-
parent_span=parent_span,
52+
with PipelineBase._create_component_span(
53+
component_name=component_name, instance=instance, inputs=inputs, parent_span=parent_span
7654
) as span:
7755
# We deepcopy the inputs otherwise we might lose that information
7856
# when we delete them in case they're sent to other Components
@@ -87,7 +65,7 @@ def _run_component(
8765
if not isinstance(component_output, Mapping):
8866
raise PipelineRuntimeError.from_invalid_output(component_name, instance.__class__, component_output)
8967

90-
span.set_tag("haystack.component.visits", component_visits[component_name])
68+
span.set_tag(_COMPONENT_VISITS, component_visits[component_name])
9169
span.set_content_tag(_COMPONENT_OUTPUT, component_output)
9270

9371
return cast(Dict[Any, Any], component_output)

0 commit comments

Comments
 (0)