make activation checkpointing work with fsdp
Showing 7 changed files with 111 additions and 36 deletions.
@@ -0,0 +1,6 @@
DIFFUSERS_TRANSFORMER_BLOCK_NAMES = [
    "transformer_blocks",
    "single_transformer_blocks",
    "temporal_transformer_blocks",
    "blocks",
]
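These attribute names cover the containers under which diffusers transformer models expose their lists of blocks. As a minimal, hypothetical sketch of how the list can be probed on an arbitrary model (the helper name find_transformer_blocks is an illustration, not part of this commit):

import torch

def find_transformer_blocks(model: torch.nn.Module) -> dict:
    # Return every block container the model actually exposes, keyed by attribute name.
    # DIFFUSERS_TRANSFORMER_BLOCK_NAMES is the list defined in the file above.
    return {
        name: getattr(model, name)
        for name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES
        if getattr(model, name, None) is not None
    }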
@@ -0,0 +1,71 @@
import collections
from enum import Enum
from typing import Optional

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

from ._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES


class CheckpointType(str, Enum):
    FULL = "full"
    OPS = "ops"
    BLOCK_SKIP = "block_skip"


_SELECTIVE_ACTIVATION_CHECKPOINTING_OPS = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
}


def apply_activation_checkpointing(
    module: torch.nn.Module, checkpointing_type: str = CheckpointType.FULL, n_layer: int = 1
) -> torch.nn.Module:
    if checkpointing_type == CheckpointType.FULL:
        # Checkpoint every transformer block.
        module = _apply_activation_checkpointing_blocks(module)
    elif checkpointing_type == CheckpointType.OPS:
        # Selective checkpointing: only save the outputs of the listed ops.
        module = _apply_activation_checkpointing_ops(module, _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS)
    elif checkpointing_type == CheckpointType.BLOCK_SKIP:
        # Checkpoint every n_layer-th transformer block.
        module = _apply_activation_checkpointing_blocks(module, n_layer)
    else:
        raise ValueError(
            f"Checkpointing type '{checkpointing_type}' not supported. Supported types are {list(CheckpointType.__members__.keys())}"
        )
    return module


def _apply_activation_checkpointing_blocks(module: torch.nn.Module, n_layer: Optional[int] = None) -> torch.nn.Module:
    for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
        blocks: Optional[torch.nn.Module] = getattr(module, transformer_block_name, None)
        if blocks is None:
            continue
        for index, (layer_id, block) in enumerate(blocks.named_children()):
            if n_layer is None or index % n_layer == 0:
                block = checkpoint_wrapper(block, preserve_rng_state=False)
                blocks.register_module(layer_id, block)
    return module


def _apply_activation_checkpointing_ops(module: torch.nn.Module, ops) -> torch.nn.Module:
    from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

    def _get_custom_policy(meta):
        def _custom_policy(ctx, func, *args, **kwargs):
            mode = "recompute" if ctx.is_recompute else "forward"
            mm_count_key = f"{mode}_mm_count"
            if func == torch.ops.aten.mm.default:
                meta[mm_count_key] += 1
            # Save the output of all listed compute ops, except every second mm.
            to_save = func in ops and not (func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0)
            return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE

        return _custom_policy

    def selective_checkpointing_context_fn():
        meta = collections.defaultdict(int)
        return create_selective_checkpoint_contexts(_get_custom_policy(meta))

    return checkpoint_wrapper(module, context_fn=selective_checkpointing_context_fn, preserve_rng_state=False)
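For context on the commit title, a minimal usage sketch of how this helper could be combined with FSDP. The import path of the helper is not shown in the diff, and the choice of torch.distributed.fsdp.FullyShardedDataParallel with use_orig_params=True is an assumption for illustration; a process group is expected to be initialized beforehand.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# from <this repository's checkpointing module> import apply_activation_checkpointing, CheckpointType
# (module path not shown in the diff)

def shard_with_activation_checkpointing(transformer: torch.nn.Module) -> torch.nn.Module:
    # One common ordering: wrap the transformer blocks with checkpoint_wrapper first,
    # then shard the result, so each FSDP unit already contains its checkpointed block.
    transformer = apply_activation_checkpointing(transformer, checkpointing_type=CheckpointType.FULL)
    return FSDP(transformer, use_orig_params=True)  # use_orig_params=True is an assumption, not from the diff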