Skip to content

Commit 8e897c0

Browse files
committed
Add a helper function add_downstream_artifact() to ModelRelations
PiperOrigin-RevId: 627617632
1 parent 9b7dfc9 commit 8e897c0

File tree

2 files changed

+113
-36
lines changed

2 files changed

+113
-36
lines changed

tfx/dsl/input_resolution/ops/latest_policy_model_op.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,26 @@ def __init__(self):
7575
self.infra_blessing_by_artifact_id = {}
7676
self.model_push_by_artifact_id = {}
7777

78+
def add_downstream_artifact(
79+
self, downstream_artifact: metadata_store_pb2.Artifact
80+
):
81+
"""Adds a downstream artifact to the ModelRelations."""
82+
artifact_type_name = downstream_artifact.type
83+
if _is_eval_blessed(artifact_type_name, downstream_artifact):
84+
self.model_blessing_by_artifact_id[downstream_artifact.id] = (
85+
downstream_artifact
86+
)
87+
88+
elif _is_infra_blessed(artifact_type_name, downstream_artifact):
89+
self.infra_blessing_by_artifact_id[downstream_artifact.id] = (
90+
downstream_artifact
91+
)
92+
93+
elif artifact_type_name == ops_utils.MODEL_PUSH_TYPE_NAME:
94+
self.model_push_by_artifact_id[downstream_artifact.id] = (
95+
downstream_artifact
96+
)
97+
7898
def meets_policy(self, policy: Policy) -> bool:
7999
"""Checks if ModelRelations contains artifacts that meet the Policy."""
80100
if policy == Policy.LATEST_EXPORTED:
@@ -398,7 +418,12 @@ def event_filter(event):
398418
return event_lib.is_valid_output_event(event)
399419

400420
mlmd_resolver = metadata_resolver.MetadataResolver(self.context.store)
401-
downstream_artifacts_and_types_by_model_ids = {}
421+
# Populate the ModelRelations associated with each Model artifact and its
422+
# children.
423+
model_relations_by_model_artifact_id = collections.defaultdict(
424+
ModelRelations
425+
)
426+
artifact_type_by_name: Dict[str, metadata_store_pb2.ArtifactType] = {}
402427

403428
# Split `model_artifact_ids` into batches with batch size = 100 while
404429
# fetching downstream artifacts, because
@@ -409,48 +434,24 @@ def event_filter(event):
409434
id_index : id_index + ops_utils.BATCH_SIZE
410435
]
411436
# Set `max_num_hops` to 50, which should be enough for this use case.
412-
batch_downstream_artifacts_by_model_ids = (
437+
batch_downstream_artifacts_and_types_by_model_ids = (
413438
mlmd_resolver.get_downstream_artifacts_by_artifact_ids(
414439
batch_model_artifact_ids,
415440
max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS,
416441
filter_query=filter_query,
417442
event_filter=event_filter,
418443
)
419444
)
420-
downstream_artifacts_and_types_by_model_ids.update(
421-
batch_downstream_artifacts_by_model_ids
422-
)
423-
# Populate the ModelRelations associated with each Model artifact and its
424-
# children.
425-
model_relations_by_model_artifact_id = collections.defaultdict(
426-
ModelRelations
427-
)
428-
429-
artifact_type_by_name = {}
430-
for (
431-
model_artifact_id,
432-
downstream_artifact_and_type,
433-
) in downstream_artifacts_and_types_by_model_ids.items():
434-
for downstream_artifact, artifact_type in downstream_artifact_and_type:
435-
artifact_type_by_name[artifact_type.name] = artifact_type
436-
model_relations = model_relations_by_model_artifact_id[
437-
model_artifact_id
438-
]
439-
artifact_type_name = downstream_artifact.type
440-
if _is_eval_blessed(artifact_type_name, downstream_artifact):
441-
model_relations.model_blessing_by_artifact_id[
442-
downstream_artifact.id
443-
] = downstream_artifact
444-
445-
elif _is_infra_blessed(artifact_type_name, downstream_artifact):
446-
model_relations.infra_blessing_by_artifact_id[
447-
downstream_artifact.id
448-
] = downstream_artifact
449-
450-
elif artifact_type_name == ops_utils.MODEL_PUSH_TYPE_NAME:
451-
model_relations.model_push_by_artifact_id[downstream_artifact.id] = (
452-
downstream_artifact
453-
)
445+
for (
446+
model_artifact_id,
447+
artifacts_and_types,
448+
) in batch_downstream_artifacts_and_types_by_model_ids.items():
449+
for downstream_artifact, artifact_type in artifacts_and_types:
450+
artifact_type_by_name[artifact_type.name] = artifact_type
451+
model_relations = model_relations_by_model_artifact_id[
452+
model_artifact_id
453+
]
454+
model_relations.add_downstream_artifact(downstream_artifact)
454455

455456
# Find the latest model and ModelRelations that meets the Policy.
456457
result = {}

tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from tfx.dsl.input_resolution import resolver_op
2121
from tfx.dsl.input_resolution.ops import latest_policy_model_op
2222
from tfx.dsl.input_resolution.ops import ops
23+
from tfx.dsl.input_resolution.ops import ops_utils
2324
from tfx.dsl.input_resolution.ops import test_utils
2425
from tfx.orchestration.portable.input_resolution import exceptions
2526

@@ -36,6 +37,81 @@
3637
_LATEST_PUSHED = latest_policy_model_op.Policy.LATEST_PUSHED
3738

3839

40+
class ModelRelationsTest(tf.test.TestCase):
41+
42+
def test_add_downstream_non_blessed_artifact_not_added(self):
43+
model_relations = latest_policy_model_op.ModelRelations()
44+
45+
self.assertEmpty(model_relations.model_blessing_by_artifact_id)
46+
self.assertEmpty(model_relations.infra_blessing_by_artifact_id)
47+
self.assertEmpty(model_relations.model_push_by_artifact_id)
48+
49+
artifact = metadata_store_pb2.Artifact(
50+
id=0,
51+
type=ops_utils.MODEL_BLESSING_TYPE_NAME,
52+
custom_properties={'blessed': metadata_store_pb2.Value(int_value=0)},
53+
)
54+
model_relations.add_downstream_artifact(artifact)
55+
56+
self.assertEmpty(model_relations.model_blessing_by_artifact_id)
57+
self.assertEmpty(model_relations.infra_blessing_by_artifact_id)
58+
self.assertEmpty(model_relations.model_push_by_artifact_id)
59+
60+
def test_add_downstream_artifact_model(self):
61+
model_relations = latest_policy_model_op.ModelRelations()
62+
63+
model_blessing_artifact = metadata_store_pb2.Artifact(
64+
id=0,
65+
type=ops_utils.MODEL_BLESSING_TYPE_NAME,
66+
custom_properties={'blessed': metadata_store_pb2.Value(int_value=1)},
67+
)
68+
model_relations.add_downstream_artifact(model_blessing_artifact)
69+
self.assertDictEqual(
70+
model_relations.model_blessing_by_artifact_id,
71+
{0: model_blessing_artifact},
72+
)
73+
self.assertEmpty(model_relations.infra_blessing_by_artifact_id)
74+
self.assertEmpty(model_relations.model_push_by_artifact_id)
75+
76+
infra_blessing_artifact = metadata_store_pb2.Artifact(
77+
id=1,
78+
type=ops_utils.MODEL_INFRA_BLESSSING_TYPE_NAME,
79+
custom_properties={
80+
'blessing_status': metadata_store_pb2.Value(
81+
string_value='INFRA_BLESSED'
82+
)
83+
},
84+
)
85+
model_relations.add_downstream_artifact(infra_blessing_artifact)
86+
self.assertDictEqual(
87+
model_relations.model_blessing_by_artifact_id,
88+
{0: model_blessing_artifact},
89+
)
90+
self.assertDictEqual(
91+
model_relations.infra_blessing_by_artifact_id,
92+
{1: infra_blessing_artifact},
93+
)
94+
self.assertEmpty(model_relations.model_push_by_artifact_id)
95+
96+
model_push_artifact = metadata_store_pb2.Artifact(
97+
id=2,
98+
type=ops_utils.MODEL_PUSH_TYPE_NAME,
99+
)
100+
model_relations.add_downstream_artifact(model_push_artifact)
101+
self.assertDictEqual(
102+
model_relations.model_blessing_by_artifact_id,
103+
{0: model_blessing_artifact},
104+
)
105+
self.assertDictEqual(
106+
model_relations.infra_blessing_by_artifact_id,
107+
{1: infra_blessing_artifact},
108+
)
109+
self.assertDictEqual(
110+
model_relations.model_push_by_artifact_id,
111+
{2: model_push_artifact},
112+
)
113+
114+
39115
class LatestPolicyModelOpTest(
40116
test_utils.ResolverTestCase,
41117
):

0 commit comments

Comments
 (0)