Skip to content

Commit 3873946

Browse files
committed
Create a new Orchestrator RPC GetNodeLiveOutputArtifactsByOutputKey
PiperOrigin-RevId: 524373686
1 parent 905c0de commit 3873946

File tree

2 files changed

+148
-18
lines changed

2 files changed

+148
-18
lines changed

tfx/orchestration/portable/mlmd/store_ext.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ def get_successful_node_executions(
4444
*,
4545
pipeline_id: str,
4646
node_id: str,
47+
pipeline_run_id: Optional[str] = None,
4748
order_by: mlmd.OrderByField = mlmd.OrderByField.ID,
4849
is_asc: bool = True,
50+
limit: Optional[int] = None,
4951
) -> List[mlmd.proto.Execution]:
5052
"""Gets all successful node executions."""
5153
node_context_name = compiler_utils.node_context_name(pipeline_id, node_id)
@@ -57,11 +59,19 @@ def get_successful_node_executions(
5759
'last_known_state = CACHED',
5860
]),
5961
])
62+
if pipeline_run_id:
63+
node_executions_query.append(
64+
q.And([
65+
f'contexts_1.type = "{constants.PIPELINE_RUN_CONTEXT_TYPE_NAME}"',
66+
f'contexts_1.name = "{pipeline_run_id}"',
67+
])
68+
)
6069
return store.get_executions(
6170
list_options=mlmd.ListOptions(
6271
filter_query=str(node_executions_query),
6372
order_by=order_by,
6473
is_asc=is_asc,
74+
limit=limit,
6575
)
6676
)
6777

@@ -130,7 +140,12 @@ def get_live_output_artifacts_of_node(
130140

131141

132142
def get_live_output_artifacts_of_node_by_output_key(
133-
store: mlmd.MetadataStore, *, pipeline_id: str, node_id: str
143+
store: mlmd.MetadataStore,
144+
*,
145+
pipeline_id: str,
146+
node_id: str,
147+
pipeline_run_id: Optional[str] = None,
148+
execution_limit: Optional[int] = None,
134149
) -> Dict[str, List[List[mlmd.proto.Artifact]]]:
135150
"""Get LIVE output artifacts of the given node grouped by output key.
136151
@@ -144,22 +159,36 @@ def get_live_output_artifacts_of_node_by_output_key(
144159
5. If no LIVE output artifacts are found for one execution, an empty list will be
145160
returned.
146161
162+
The value of execution_limit must be None or non-negative.
163+
1. If None or 0, live output artifacts from all executions will be returned.
164+
2. If the node has fewer executions than execution_limit, live output
165+
artifacts from all executions will be returned.
166+
3. If the node has at least execution_limit executions, only live
167+
output artifacts from the execution_limit latest executions will be
168+
returned.
169+
147170
Args:
148171
store: A MetadataStore object.
149172
pipeline_id: A pipeline ID.
150173
node_id: A node ID.
174+
pipeline_run_id: The pipeline run ID that the node belongs to. Only
175+
artifacts from that pipeline run are returned when it is set.
176+
execution_limit: Maximum number of latest executions from which live output
177+
artifacts will be returned.
151178
152179
Returns:
153180
A mapping from output key to all output artifacts from the given node.
154181
"""
155-
node_executions_ordered_by_desc_creation_time = (
156-
get_successful_node_executions(
157-
store,
158-
pipeline_id=pipeline_id,
159-
node_id=node_id,
160-
order_by=mlmd.OrderByField.CREATE_TIME,
161-
is_asc=False,
162-
)
182+
node_executions_ordered_by_desc_creation_time = get_successful_node_executions(
183+
store,
184+
pipeline_id=pipeline_id,
185+
node_id=node_id,
186+
pipeline_run_id=pipeline_run_id,
187+
order_by=mlmd.OrderByField.CREATE_TIME,
188+
# TODO(b/276893037): revisit MLMD performance degradation caused by
189+
# is_asc=False in b/274559409.
190+
is_asc=False,
191+
limit=execution_limit,
163192
)
164193
if not node_executions_ordered_by_desc_creation_time:
165194
return {}

tfx/orchestration/portable/mlmd/store_ext_test.py

Lines changed: 110 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def testGetSuccessfulNodeExecutions(self):
4949
)
5050
self.assertEqual(_ids(result), _ids([e1, e2]))
5151

52+
with self.subTest('With execution limit'):
53+
result = store_ext.get_successful_node_executions(
54+
self.store, pipeline_id='my-pipeline', node_id='my-node', limit=1
55+
)
56+
self.assertEqual(_ids(result), _ids([e1]))
57+
5258
with self.subTest('Bad pipeline_id'):
5359
result = store_ext.get_successful_node_executions(
5460
self.store, pipeline_id='not-exist', node_id='my-node'
@@ -115,8 +121,9 @@ def testGetLiveOutputArtifactsOfNode(self):
115121
)
116122
self.assertEqual(_sorted_ids(result), _sorted_ids([y2]))
117123

118-
def testGetLiveOutputArtifactsOfNodeByOutputKey(self):
119-
c = self.put_context('node', 'my-pipeline.my-node')
124+
def testGetLiveOutputArtifactsOfNodeByOutputKeySync(self):
125+
c1 = self.put_context('pipeline_run', 'run-20230413')
126+
c2 = self.put_context('node', 'my-pipeline.my-node')
120127
x1 = self.put_artifact('X')
121128
x2 = self.put_artifact('X')
122129
x3 = self.put_artifact('X')
@@ -130,24 +137,118 @@ def testGetLiveOutputArtifactsOfNodeByOutputKey(self):
130137
z3 = self.put_artifact('Z', state='ABANDONED')
131138

132139
self.put_execution(
133-
'E', inputs={'x': [x1]}, outputs={'y': [y1], 'z': [z1]}, contexts=[c]
140+
'E',
141+
inputs={'x': [x1]},
142+
outputs={'y': [y1], 'z': [z1]},
143+
contexts=[c1, c2],
134144
)
135145
self.put_execution(
136146
'E',
137147
inputs={'x': [x2]},
138148
outputs={'y': [y2, y3, y4], 'z': [z2]},
139-
contexts=[c],
149+
contexts=[c1, c2],
140150
)
141151
self.put_execution(
142-
'E', inputs={'x': [x3]}, outputs={'y': [y5], 'z': [z3]}, contexts=[c]
152+
'E',
153+
inputs={'x': [x3]},
154+
outputs={'y': [y5], 'z': [z3]},
155+
contexts=[c1, c2],
143156
)
144157

145-
result = store_ext.get_live_output_artifacts_of_node_by_output_key(
146-
self.store, pipeline_id='my-pipeline', node_id='my-node'
158+
with self.subTest('With execution limit=None'):
159+
result = store_ext.get_live_output_artifacts_of_node_by_output_key(
160+
self.store,
161+
pipeline_id='my-pipeline',
162+
node_id='my-node',
163+
pipeline_run_id='run-20230413',
164+
)
165+
self.assertDictEqual(
166+
result, {'y': [[y5], [y3, y4], [y1]], 'z': [[], [z2], [z1]]}
167+
)
168+
with self.subTest('With execution limit=2'):
169+
result = store_ext.get_live_output_artifacts_of_node_by_output_key(
170+
self.store,
171+
pipeline_id='my-pipeline',
172+
node_id='my-node',
173+
pipeline_run_id='run-20230413',
174+
execution_limit=2,
175+
)
176+
self.assertDictEqual(result, {'y': [[y5], [y3, y4]], 'z': [[], [z2]]})
177+
with self.subTest('With execution limit=0'):
178+
result = store_ext.get_live_output_artifacts_of_node_by_output_key(
179+
self.store,
180+
pipeline_id='my-pipeline',
181+
node_id='my-node',
182+
pipeline_run_id='run-20230413',
183+
execution_limit=0,
184+
)
185+
self.assertDictEqual(
186+
result, {'y': [[y5], [y3, y4], [y1]], 'z': [[], [z2], [z1]]}
187+
)
188+
189+
def testGetLiveOutputArtifactsOfNodeByOutputKeyAsync(self):
190+
c1 = self.put_context('node', 'my-pipeline.my-node')
191+
x1 = self.put_artifact('X')
192+
x2 = self.put_artifact('X')
193+
x3 = self.put_artifact('X')
194+
y1 = self.put_artifact('Y')
195+
y2 = self.put_artifact('Y', state='DELETED')
196+
y3 = self.put_artifact('Y')
197+
y4 = self.put_artifact('Y')
198+
y5 = self.put_artifact('Y')
199+
z1 = self.put_artifact('Z')
200+
z2 = self.put_artifact('Z')
201+
z3 = self.put_artifact('Z', state='ABANDONED')
202+
203+
self.put_execution(
204+
'E',
205+
inputs={'x': [x1]},
206+
outputs={'y': [y1], 'z': [z1]},
207+
contexts=[c1],
147208
)
148-
self.assertDictEqual(
149-
result, {'y': [[y5], [y3, y4], [y1]], 'z': [[], [z2], [z1]]}
209+
self.put_execution(
210+
'E',
211+
inputs={'x': [x2]},
212+
outputs={'y': [y2, y3, y4], 'z': [z2]},
213+
contexts=[c1],
150214
)
215+
self.put_execution(
216+
'E',
217+
inputs={'x': [x3]},
218+
outputs={'y': [y5], 'z': [z3]},
219+
contexts=[c1],
220+
)
221+
222+
with self.subTest('With execution limit=None'):
223+
result = store_ext.get_live_output_artifacts_of_node_by_output_key(
224+
self.store,
225+
pipeline_id='my-pipeline',
226+
node_id='my-node',
227+
pipeline_run_id='',
228+
)
229+
self.assertDictEqual(
230+
result, {'y': [[y5], [y3, y4], [y1]], 'z': [[], [z2], [z1]]}
231+
)
232+
with self.subTest('With execution limit=2'):
233+
result = store_ext.get_live_output_artifacts_of_node_by_output_key(
234+
self.store,
235+
pipeline_id='my-pipeline',
236+
node_id='my-node',
237+
pipeline_run_id='',
238+
execution_limit=2,
239+
)
240+
self.assertDictEqual(result, {'y': [[y5], [y3, y4]], 'z': [[], [z2]]})
241+
with self.subTest('With execution limit=0'):
242+
result = store_ext.get_live_output_artifacts_of_node_by_output_key(
243+
self.store,
244+
pipeline_id='my-pipeline',
245+
node_id='my-node',
246+
pipeline_run_id='',
247+
execution_limit=0,
248+
)
249+
self.assertDictEqual(
250+
result, {'y': [[y5], [y3, y4], [y1]], 'z': [[], [z2], [z1]]}
251+
)
151252

152253

153254
if __name__ == '__main__':

0 commit comments

Comments
 (0)