Skip to content

Track ITEP-enabled models on APS with scuba logging (#2736) #2902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchrec/distributed/itep_embeddingbag.py
Original file line number Diff line number Diff line change
@@ -134,6 +134,7 @@ def __init__(
pruning_interval=module._itep_module.pruning_interval,
enable_pruning=module._itep_module.enable_pruning,
pg=env.process_group,
itep_logger=module._itep_module.itep_logger,
)

def prefetch(
60 changes: 60 additions & 0 deletions torchrec/modules/itep_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from abc import ABC, abstractmethod
from typing import Mapping, Optional, Tuple, Union

logger: logging.Logger = logging.getLogger(__name__)


class ITEPLogger(ABC):
    """
    Abstract interface for ITEP logging backends.

    Subclasses must implement every method declared here before they can be
    instantiated.
    """

    @abstractmethod
    def log_table_eviction_info(
        self,
        iteration: Optional[Union[bool, float, int]],
        rank: Optional[int],
        table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
        eviction_tables: Mapping[str, float],
    ) -> None:
        """
        Record eviction information for a collection of tables.

        Args:
            iteration: identifier of the iteration being reported, or None.
            rank: rank being reported, or None.
            table_to_sizes_mapping: name -> pair of integer sizes.
                NOTE(review): presumably (pruned length, buffer length);
                confirm against the caller.
            eviction_tables: name -> float value.
                NOTE(review): presumably the eviction ratio; confirm against
                the caller.
        """
        return None

    @abstractmethod
    def log_run_info(self) -> None:
        """
        Record run-level information. Takes no arguments.
        """
        return None


class ITEPLoggerDefault(ITEPLogger):
    """
    Default ITEPLogger implementation.

    Table eviction info is written to this module's Python ``logger`` at INFO
    level; run info is not recorded anywhere.
    """

    def __init__(
        self,
    ) -> None:
        """
        Initialize ITEPLoggerDefault. No state is set up.
        """
        pass

    def log_table_eviction_info(
        self,
        iteration: Optional[Union[bool, float, int]],
        rank: Optional[int],
        table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
        eviction_tables: Mapping[str, float],
    ) -> None:
        """
        Emit a single INFO log line containing all four arguments.

        Args:
            iteration: identifier of the iteration being reported, or None.
            rank: rank being reported, or None.
            table_to_sizes_mapping: name -> pair of integer sizes.
            eviction_tables: name -> float value.
        """
        # Pass the values as lazy %-style arguments so the (potentially large)
        # mappings are only rendered to text when INFO logging is enabled.
        # The rendered message is identical to the previous f-string output.
        logger.info(
            "iteration=%s, rank=%s, table_to_sizes_mapping=%s, eviction_tables=%s",
            iteration,
            rank,
            table_to_sizes_mapping,
            eviction_tables,
        )

    def log_run_info(
        self,
    ) -> None:
        """
        Intentionally does nothing: the default logger records no run info.
        """
        pass
45 changes: 43 additions & 2 deletions torchrec/modules/itep_modules.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,8 @@
from torchrec.distributed.embedding_types import ShardedEmbeddingTable, ShardingType
from torchrec.distributed.types import Shard, ShardedTensor, ShardedTensorMetadata
from torchrec.modules.embedding_modules import reorder_inverse_indices
from torchrec.modules.itep_logger import ITEPLogger, ITEPLoggerDefault

from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor

try:
@@ -71,8 +73,8 @@ def __init__(
pruning_interval: int = 1001, # Default pruning interval 1001 iterations
pg: Optional[dist.ProcessGroup] = None,
table_name_to_sharding_type: Optional[Dict[str, str]] = None,
itep_logger: Optional[ITEPLogger] = None,
) -> None:

super(GenericITEPModule, self).__init__()

if not table_name_to_sharding_type:
@@ -88,6 +90,11 @@ def __init__(
)
self.table_name_to_sharding_type = table_name_to_sharding_type

self.itep_logger: ITEPLogger = (
itep_logger if itep_logger is not None else ITEPLoggerDefault()
)
self.itep_logger.log_run_info()

# Map each feature to a physical address_lookup/row_util buffer
self.feature_table_map: Dict[str, int] = {}
self.table_name_to_idx: Dict[str, int] = {}
@@ -111,6 +118,8 @@ def print_itep_eviction_stats(
cur_iter: int,
) -> None:
table_name_to_eviction_ratio = {}
buffer_idx_to_eviction_ratio = {}
buffer_idx_to_sizes = {}

num_buffers = len(self.buffer_offsets_list) - 1
for buffer_idx in range(num_buffers):
@@ -127,6 +136,8 @@ def print_itep_eviction_stats(
table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = (
eviction_ratio
)
buffer_idx_to_eviction_ratio[buffer_idx] = eviction_ratio
buffer_idx_to_sizes[buffer_idx] = (pruned_length.item(), buffer_length)

# Sort the mapping by eviction ratio in descending order
sorted_mapping = dict(
@@ -136,6 +147,34 @@ def print_itep_eviction_stats(
reverse=True,
)
)

logged_eviction_mapping = {}
for idx in sorted_mapping.keys():
try:
logged_eviction_mapping[self.reversed_feature_table_map[idx]] = (
sorted_mapping[idx]
)
except KeyError:
# in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
pass

table_to_sizes_mapping = {}
for idx in buffer_idx_to_sizes.keys():
try:
table_to_sizes_mapping[self.reversed_feature_table_map[idx]] = (
buffer_idx_to_sizes[idx]
)
except KeyError:
# in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
pass

self.itep_logger.log_table_eviction_info(
iteration=None,
rank=None,
table_to_sizes_mapping=table_to_sizes_mapping,
eviction_tables=logged_eviction_mapping,
)

# Print the sorted mapping
logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")

@@ -277,8 +316,10 @@ def init_itep_state(self) -> None:
if self.current_device is None:
self.current_device = torch.device("cuda")

self.reversed_feature_table_map: Dict[int, str] = {
idx: feature_name for feature_name, idx in self.feature_table_map.items()
}
self.buffer_offsets_list = buffer_offsets

# Create buffers for address_lookup and row_util
self.create_itep_buffers(
buffer_size=buffer_size,
Loading