
Commit 1558a3a

Anish Khazane authored and facebook-github-bot committed
Track ITEP-enabled models on APS with Scuba logging (pytorch#2902)
Summary:
Pull Request resolved: pytorch#2902

Logs run info for APS models that enable ITEP.

Differential Revision: D74038683
1 parent 4eca985 commit 1558a3a

File tree

3 files changed: +104 additions, -2 deletions


torchrec/distributed/itep_embeddingbag.py (+1)

@@ -134,6 +134,7 @@ def __init__(
             pruning_interval=module._itep_module.pruning_interval,
             enable_pruning=module._itep_module.enable_pruning,
             pg=env.process_group,
+            itep_logger=module._itep_module.itep_logger,
         )

     def prefetch(

torchrec/modules/itep_logger.py (new file, +60)

@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Mapping, Optional, Tuple, Union
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class ITEPLogger(ABC):
+    @abstractmethod
+    def log_table_eviction_info(
+        self,
+        iteration: Optional[Union[bool, float, int]],
+        rank: Optional[int],
+        table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
+        eviction_tables: Mapping[str, float],
+    ) -> None:
+        pass
+
+    @abstractmethod
+    def log_run_info(
+        self,
+    ) -> None:
+        pass
+
+
+class ITEPLoggerDefault(ITEPLogger):
+    """
+    No-op logger used as the default.
+    """
+
+    def __init__(
+        self,
+    ) -> None:
+        """
+        Initialize ITEPLoggerDefault.
+        """
+        pass
+
+    def log_table_eviction_info(
+        self,
+        iteration: Optional[Union[bool, float, int]],
+        rank: Optional[int],
+        table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
+        eviction_tables: Mapping[str, float],
+    ) -> None:
+        logger.info(
+            f"iteration={iteration}, rank={rank}, table_to_sizes_mapping={table_to_sizes_mapping}, eviction_tables={eviction_tables}"
+        )
+
+    def log_run_info(
+        self,
+    ) -> None:
+        pass
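
The Scuba-backed logger referenced in the commit title lives outside this diff; for orientation, a backend plugs in by subclassing ITEPLogger and implementing both abstract methods. The sketch below is illustrative only: CustomITEPLogger and its sink object are assumptions for the example, not part of torchrec or this commit.

# Illustrative sketch only; `CustomITEPLogger` and `sink` are hypothetical.
from typing import Mapping, Optional, Tuple, Union

from torchrec.modules.itep_logger import ITEPLogger


class CustomITEPLogger(ITEPLogger):
    def __init__(self, sink) -> None:
        # `sink` is assumed to expose a `write(record: dict)` method.
        self._sink = sink

    def log_table_eviction_info(
        self,
        iteration: Optional[Union[bool, float, int]],
        rank: Optional[int],
        table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
        eviction_tables: Mapping[str, float],
    ) -> None:
        # Forward eviction stats as one structured record.
        self._sink.write(
            {
                "iteration": iteration,
                "rank": rank,
                "table_to_sizes_mapping": dict(table_to_sizes_mapping),
                "eviction_tables": dict(eviction_tables),
            }
        )

    def log_run_info(self) -> None:
        # Called once from GenericITEPModule.__init__ to mark an ITEP-enabled run.
        self._sink.write({"event": "itep_run_started"})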

torchrec/modules/itep_modules.py (+43, -2)

@@ -21,6 +21,8 @@
 from torchrec.distributed.embedding_types import ShardedEmbeddingTable, ShardingType
 from torchrec.distributed.types import Shard, ShardedTensor, ShardedTensorMetadata
 from torchrec.modules.embedding_modules import reorder_inverse_indices
+from torchrec.modules.itep_logger import ITEPLogger, ITEPLoggerDefault
+
 from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor

 try:
@@ -71,8 +73,8 @@ def __init__(
         pruning_interval: int = 1001,  # Default pruning interval 1001 iterations
         pg: Optional[dist.ProcessGroup] = None,
         table_name_to_sharding_type: Optional[Dict[str, str]] = None,
+        itep_logger: Optional[ITEPLogger] = None,
     ) -> None:
-
         super(GenericITEPModule, self).__init__()

         if not table_name_to_sharding_type:
@@ -88,6 +90,11 @@ def __init__(
             )
         self.table_name_to_sharding_type = table_name_to_sharding_type

+        self.itep_logger: ITEPLogger = (
+            itep_logger if itep_logger is not None else ITEPLoggerDefault()
+        )
+        self.itep_logger.log_run_info()
+
         # Map each feature to a physical address_lookup/row_util buffer
         self.feature_table_map: Dict[str, int] = {}
         self.table_name_to_idx: Dict[str, int] = {}
@@ -111,6 +118,8 @@ def print_itep_eviction_stats(
         cur_iter: int,
     ) -> None:
         table_name_to_eviction_ratio = {}
+        buffer_idx_to_eviction_ratio = {}
+        buffer_idx_to_sizes = {}

         num_buffers = len(self.buffer_offsets_list) - 1
         for buffer_idx in range(num_buffers):
@@ -127,6 +136,8 @@ def print_itep_eviction_stats(
            table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = (
                eviction_ratio
            )
+           buffer_idx_to_eviction_ratio[buffer_idx] = eviction_ratio
+           buffer_idx_to_sizes[buffer_idx] = (pruned_length.item(), buffer_length)

         # Sort the mapping by eviction ratio in descending order
         sorted_mapping = dict(
@@ -136,6 +147,34 @@ def print_itep_eviction_stats(
                 reverse=True,
             )
         )
+
+        logged_eviction_mapping = {}
+        for idx in sorted_mapping.keys():
+            try:
+                logged_eviction_mapping[self.reversed_feature_table_map[idx]] = (
+                    sorted_mapping[idx]
+                )
+            except KeyError:
+                # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
+                pass
+
+        table_to_sizes_mapping = {}
+        for idx in buffer_idx_to_sizes.keys():
+            try:
+                table_to_sizes_mapping[self.reversed_feature_table_map[idx]] = (
+                    buffer_idx_to_sizes[idx]
+                )
+            except KeyError:
+                # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
+                pass
+
+        self.itep_logger.log_table_eviction_info(
+            iteration=None,
+            rank=None,
+            table_to_sizes_mapping=table_to_sizes_mapping,
+            eviction_tables=logged_eviction_mapping,
+        )
+
         # Print the sorted mapping
         logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")

@@ -277,8 +316,10 @@ def init_itep_state(self) -> None:
         if self.current_device is None:
             self.current_device = torch.device("cuda")

+        self.reversed_feature_table_map: Dict[int, str] = {
+            idx: feature_name for feature_name, idx in self.feature_table_map.items()
+        }
         self.buffer_offsets_list = buffer_offsets
-
         # Create buffers for address_lookup and row_util
         self.create_itep_buffers(
             buffer_size=buffer_size,
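
As the diff shows, GenericITEPModule falls back to the no-op ITEPLoggerDefault when no itep_logger is supplied, calls log_run_info() once during construction, and later forwards per-table (pruned_length, buffer_length) sizes and eviction ratios through log_table_eviction_info from print_itep_eviction_stats. Below is a minimal sketch of the default logger exercised directly; the table name and numeric values are made-up examples, not output from this commit.

from torchrec.modules.itep_logger import ITEPLoggerDefault

itep_logger = ITEPLoggerDefault()
itep_logger.log_run_info()  # no-op in the default implementation
itep_logger.log_table_eviction_info(
    iteration=1001,                                   # example value, e.g. a pruning-interval boundary
    rank=0,
    table_to_sizes_mapping={"table_0": (512, 1024)},  # (pruned_length, buffer_length)
    eviction_tables={"table_0": 0.5},                 # eviction ratio per table
)
# The default logger emits a single logging.info line containing these values.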
