21
21
from torchrec .distributed .embedding_types import ShardedEmbeddingTable , ShardingType
22
22
from torchrec .distributed .types import Shard , ShardedTensor , ShardedTensorMetadata
23
23
from torchrec .modules .embedding_modules import reorder_inverse_indices
24
+ from torchrec .modules .itep_logger import ITEPLogger , ITEPLoggerDefault
25
+
24
26
from torchrec .sparse .jagged_tensor import _pin_and_move , _to_offsets , KeyedJaggedTensor
25
27
26
28
try :
@@ -71,8 +73,8 @@ def __init__(
71
73
pruning_interval : int = 1001 , # Default pruning interval 1001 iterations
72
74
pg : Optional [dist .ProcessGroup ] = None ,
73
75
table_name_to_sharding_type : Optional [Dict [str , str ]] = None ,
76
+ itep_logger : Optional [ITEPLogger ] = None ,
74
77
) -> None :
75
-
76
78
super (GenericITEPModule , self ).__init__ ()
77
79
78
80
if not table_name_to_sharding_type :
@@ -88,6 +90,11 @@ def __init__(
88
90
)
89
91
self .table_name_to_sharding_type = table_name_to_sharding_type
90
92
93
+ self .itep_logger : ITEPLogger = (
94
+ itep_logger if itep_logger is not None else ITEPLoggerDefault ()
95
+ )
96
+ self .itep_logger .log_run_info ()
97
+
91
98
# Map each feature to a physical address_lookup/row_util buffer
92
99
self .feature_table_map : Dict [str , int ] = {}
93
100
self .table_name_to_idx : Dict [str , int ] = {}
@@ -111,6 +118,8 @@ def print_itep_eviction_stats(
111
118
cur_iter : int ,
112
119
) -> None :
113
120
table_name_to_eviction_ratio = {}
121
+ buffer_idx_to_eviction_ratio = {}
122
+ buffer_idx_to_sizes = {}
114
123
115
124
num_buffers = len (self .buffer_offsets_list ) - 1
116
125
for buffer_idx in range (num_buffers ):
@@ -127,6 +136,8 @@ def print_itep_eviction_stats(
127
136
table_name_to_eviction_ratio [self .idx_to_table_name [buffer_idx ]] = (
128
137
eviction_ratio
129
138
)
139
+ buffer_idx_to_eviction_ratio [buffer_idx ] = eviction_ratio
140
+ buffer_idx_to_sizes [buffer_idx ] = (pruned_length .item (), buffer_length )
130
141
131
142
# Sort the mapping by eviction ratio in descending order
132
143
sorted_mapping = dict (
@@ -136,6 +147,34 @@ def print_itep_eviction_stats(
136
147
reverse = True ,
137
148
)
138
149
)
150
+
151
+ logged_eviction_mapping = {}
152
+ for idx in sorted_mapping .keys ():
153
+ try :
154
+ logged_eviction_mapping [self .reversed_feature_table_map [idx ]] = (
155
+ sorted_mapping [idx ]
156
+ )
157
+ except KeyError :
158
+ # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
159
+ pass
160
+
161
+ table_to_sizes_mapping = {}
162
+ for idx in buffer_idx_to_sizes .keys ():
163
+ try :
164
+ table_to_sizes_mapping [self .reversed_feature_table_map [idx ]] = (
165
+ buffer_idx_to_sizes [idx ]
166
+ )
167
+ except KeyError :
168
+ # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
169
+ pass
170
+
171
+ self .itep_logger .log_table_eviction_info (
172
+ iteration = None ,
173
+ rank = None ,
174
+ table_to_sizes_mapping = table_to_sizes_mapping ,
175
+ eviction_tables = logged_eviction_mapping ,
176
+ )
177
+
139
178
# Print the sorted mapping
140
179
logger .info (f"ITEP: table name to eviction ratio { sorted_mapping } " )
141
180
@@ -277,8 +316,10 @@ def init_itep_state(self) -> None:
277
316
if self .current_device is None :
278
317
self .current_device = torch .device ("cuda" )
279
318
319
+ self .reversed_feature_table_map : Dict [int , str ] = {
320
+ idx : feature_name for feature_name , idx in self .feature_table_map .items ()
321
+ }
280
322
self .buffer_offsets_list = buffer_offsets
281
-
282
323
# Create buffers for address_lookup and row_util
283
324
self .create_itep_buffers (
284
325
buffer_size = buffer_size ,
0 commit comments