@@ -75,6 +75,26 @@ def __init__(self):
75
75
self .infra_blessing_by_artifact_id = {}
76
76
self .model_push_by_artifact_id = {}
77
77
78
+ def add_downstream_artifact (
79
+ self , downstream_artifact : metadata_store_pb2 .Artifact
80
+ ):
81
+ """Adds a downstream artifact to the ModelRelations."""
82
+ artifact_type_name = downstream_artifact .type
83
+ if _is_eval_blessed (artifact_type_name , downstream_artifact ):
84
+ self .model_blessing_by_artifact_id [downstream_artifact .id ] = (
85
+ downstream_artifact
86
+ )
87
+
88
+ elif _is_infra_blessed (artifact_type_name , downstream_artifact ):
89
+ self .infra_blessing_by_artifact_id [downstream_artifact .id ] = (
90
+ downstream_artifact
91
+ )
92
+
93
+ elif artifact_type_name == ops_utils .MODEL_PUSH_TYPE_NAME :
94
+ self .model_push_by_artifact_id [downstream_artifact .id ] = (
95
+ downstream_artifact
96
+ )
97
+
78
98
def meets_policy (self , policy : Policy ) -> bool :
79
99
"""Checks if ModelRelations contains artifacts that meet the Policy."""
80
100
if policy == Policy .LATEST_EXPORTED :
@@ -398,7 +418,12 @@ def event_filter(event):
398
418
return event_lib .is_valid_output_event (event )
399
419
400
420
mlmd_resolver = metadata_resolver .MetadataResolver (self .context .store )
401
- downstream_artifacts_and_types_by_model_ids = {}
421
+ # Populate the ModelRelations associated with each Model artifact and its
422
+ # children.
423
+ model_relations_by_model_artifact_id = collections .defaultdict (
424
+ ModelRelations
425
+ )
426
+ artifact_type_by_name : Dict [str , metadata_store_pb2 .ArtifactType ] = {}
402
427
403
428
# Split `model_artifact_ids` into batches with batch size = 100 while
404
429
# fetching downstream artifacts, because
@@ -409,48 +434,24 @@ def event_filter(event):
409
434
id_index : id_index + ops_utils .BATCH_SIZE
410
435
]
411
436
# Set `max_num_hops` to 50, which should be enough for this use case.
412
- batch_downstream_artifacts_by_model_ids = (
437
+ batch_downstream_artifacts_and_types_by_model_ids = (
413
438
mlmd_resolver .get_downstream_artifacts_by_artifact_ids (
414
439
batch_model_artifact_ids ,
415
440
max_num_hops = ops_utils .LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS ,
416
441
filter_query = filter_query ,
417
442
event_filter = event_filter ,
418
443
)
419
444
)
420
- downstream_artifacts_and_types_by_model_ids .update (
421
- batch_downstream_artifacts_by_model_ids
422
- )
423
- # Populate the ModelRelations associated with each Model artifact and its
424
- # children.
425
- model_relations_by_model_artifact_id = collections .defaultdict (
426
- ModelRelations
427
- )
428
-
429
- artifact_type_by_name = {}
430
- for (
431
- model_artifact_id ,
432
- downstream_artifact_and_type ,
433
- ) in downstream_artifacts_and_types_by_model_ids .items ():
434
- for downstream_artifact , artifact_type in downstream_artifact_and_type :
435
- artifact_type_by_name [artifact_type .name ] = artifact_type
436
- model_relations = model_relations_by_model_artifact_id [
437
- model_artifact_id
438
- ]
439
- artifact_type_name = downstream_artifact .type
440
- if _is_eval_blessed (artifact_type_name , downstream_artifact ):
441
- model_relations .model_blessing_by_artifact_id [
442
- downstream_artifact .id
443
- ] = downstream_artifact
444
-
445
- elif _is_infra_blessed (artifact_type_name , downstream_artifact ):
446
- model_relations .infra_blessing_by_artifact_id [
447
- downstream_artifact .id
448
- ] = downstream_artifact
449
-
450
- elif artifact_type_name == ops_utils .MODEL_PUSH_TYPE_NAME :
451
- model_relations .model_push_by_artifact_id [downstream_artifact .id ] = (
452
- downstream_artifact
453
- )
445
+ for (
446
+ model_artifact_id ,
447
+ artifacts_and_types ,
448
+ ) in batch_downstream_artifacts_and_types_by_model_ids .items ():
449
+ for downstream_artifact , artifact_type in artifacts_and_types :
450
+ artifact_type_by_name [artifact_type .name ] = artifact_type
451
+ model_relations = model_relations_by_model_artifact_id [
452
+ model_artifact_id
453
+ ]
454
+ model_relations .add_downstream_artifact (downstream_artifact )
454
455
455
456
# Find the latest model and ModelRelations that meets the Policy.
456
457
result = {}
0 commit comments