make activation checkpointing work with fsdp
a-r-r-o-w committed Jan 31, 2025
1 parent 069a653 commit c7003ec
Showing 7 changed files with 111 additions and 36 deletions.
47 changes: 24 additions & 23 deletions finetrainers/trainer.py
@@ -59,11 +59,14 @@ def __init__(self, args: Args) -> None:
# Scheduler
self.scheduler = None

# Optimizer & LR scheduler
self.optimizer = None
self.lr_scheduler = None

# Trainer-specific conditions
self.caption_preprocessing_conditions: List[Processor] = []
self.caption_postprocessing_conditions: List[Processor] = []

self.state.model_name = self.args.model_name
self.state.condition_types = self.args.conditions

self._init_distributed()
@@ -72,13 +75,6 @@ def __init__(self, args: Args) -> None:
self._init_config_options()
self._init_non_model_conditions()

# Perform any patches needed for training
if len(self.args.layerwise_upcasting_modules) > 0:
perform_peft_patches()
# TODO(aryan): handle text encoders
# if any(["text_encoder" in component_name for component_name in self.args.layerwise_upcasting_modules]):
# perform_text_encoder_patches()

model_specification_cls = get_model_specifiction_cls(self.args.model_name, self.args.training_type)
self.model_specification: ModelSpecification = model_specification_cls(
pretrained_model_name_or_path=self.args.pretrained_model_name_or_path,
@@ -377,7 +373,7 @@ def register_saving_loading_hooks(self, transformer_lora_config):
# self.state.accelerator.register_save_state_pre_hook(save_model_hook)
# self.state.accelerator.register_load_state_pre_hook(load_model_hook)

def prepare_optimizer(self) -> None:
def prepare_trainable_parameters(self) -> None:
logger.info("Initializing trainable parameters")

if self.args.precompute_conditions:
Expand Down Expand Up @@ -406,9 +402,6 @@ def prepare_optimizer(self) -> None:
non_blocking=True,
)

if self.args.gradient_checkpointing:
self.transformer.enable_gradient_checkpointing()

if self.args.training_type == "lora":
transformer_lora_config = LoraConfig(
r=self.args.rank,
@@ -422,13 +415,10 @@ def prepare_optimizer(self) -> None:

# TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32
# even if layerwise upcasting. Would be nice to have a test as well

self.register_saving_loading_hooks(transformer_lora_config)

# ============ TODO(aryan): cleanup

# Setup distributed optimizer and lr scheduler
logger.info("Initializing optimizer and lr scheduler")

self.state.train_state = TrainState()

# Make sure the trainable params are in float32
@@ -465,21 +455,28 @@ def prepare_optimizer(self) -> None:
self.lr_scheduler = lr_scheduler

def prepare_for_training(self) -> None:
if self.state.parallel.context_parallel_enabled:
parallel_state = self.state.parallel
world_mesh = parallel_state.get_mesh()

if parallel_state.context_parallel_enabled:
raise NotImplementedError(
"Context parallelism is not supported yet. This will be supported in the future."
)

world_mesh = self.state.parallel.get_mesh()
# Enable gradient checkpointing
if self.args.gradient_checkpointing:
# TODO(aryan): support other checkpointing types
utils.apply_gradient_checkpointing(self.transformer, checkpointing_type="full")

if self.state.parallel.data_sharding_enabled:
if self.state.parallel.data_replication_enabled:
# Enable DDP, FSDP or HSDP
if parallel_state.data_sharding_enabled:
if parallel_state.data_replication_enabled:
logger.info("Applying HSDP to the model")
else:
logger.info("Applying FSDP to the model")

# Apply FSDP or HSDP
if self.state.parallel.data_replication_enabled or self.state.parallel.context_parallel_enabled:
if parallel_state.data_replication_enabled or parallel_state.context_parallel_enabled:
dp_mesh_names = ("dp_replicate", "dp_shard_cp")
else:
dp_mesh_names = ("dp_shard_cp",)
@@ -491,10 +488,10 @@ def prepare_for_training(self) -> None:
dp_mesh=world_mesh[dp_mesh_names],
param_dtype=param_dtype,
reduce_dtype=torch.float32,
pp_enabled=self.state.parallel.pipeline_parallel_enabled,
pp_enabled=parallel_state.pipeline_parallel_enabled,
cpu_offload=False, # TODO(aryan): needs to be tested and allowed for enabling later
)
elif self.state.parallel.data_replication_enabled:
elif parallel_state.data_replication_enabled:
logger.info("Applying DDP to the model")

if world_mesh.ndim > 1:
@@ -941,6 +938,10 @@ def _init_config_options(self) -> None:
if self.args.allow_tf32 and torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True

# Perform any patches needed for training
if len(self.args.layerwise_upcasting_modules) > 0:
perform_peft_patches()

def _init_non_model_conditions(self) -> None:
if ProcessorType.CAPTION_TEXT_DROPOUT in self.state.condition_types:
params = {"dropout_p": self.args.caption_dropout_p}
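Note (not part of the commit): the key ordering change in trainer.py is that activation checkpointing is now applied in prepare_for_training, before FSDP/HSDP sharding, instead of via transformer.enable_gradient_checkpointing() in prepare_trainable_parameters. A minimal sketch of that ordering on a toy module, using the same torch primitives as the repo's helpers; ToyTransformer and its sizes are made up, and fully_shard needs an initialized process group (e.g. a torchrun launch):

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

class ToyTransformer(nn.Module):
    def __init__(self, dim: int = 64, num_layers: int = 4) -> None:
        super().__init__()
        # Mirrors the diffusers-style "transformer_blocks" attribute the helpers look for.
        self.transformer_blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.transformer_blocks:
            x = block(x)
        return x

model = ToyTransformer()

# 1) Wrap each block for activation checkpointing first (roughly what
#    utils.apply_gradient_checkpointing(..., checkpointing_type="full") does).
for name, block in model.transformer_blocks.named_children():
    model.transformer_blocks.register_module(name, checkpoint_wrapper(block, preserve_rng_state=False))

# 2) Then shard the already-wrapped blocks and the root module (roughly what
#    apply_fsdp does, minus the device mesh and mixed-precision policy).
for block in model.transformer_blocks:
    fully_shard(block)
fully_shard(model)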
6 changes: 5 additions & 1 deletion finetrainers/utils/__init__.py
@@ -1,7 +1,8 @@
import inspect
from typing import Any, Dict, Optional, Set

from .checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from
from ._rename_this_file_checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from
from .checkpoint_utils import apply_activation_checkpointing
from .data_utils import determine_batch_size, should_perform_precomputation
from .diffusion_utils import (
default_flow_shift,
@@ -21,6 +22,9 @@
from .torch_utils import align_device_and_dtype, expand_tensor_dims, get_device_info, synchronize_device, unwrap_model


apply_gradient_checkpointing = apply_activation_checkpointing


def get_parameter_names(obj: Any, method_name: Optional[str] = None) -> Set[str]:
if method_name is not None:
obj = getattr(obj, method_name)
6 changes: 6 additions & 0 deletions finetrainers/utils/_common.py
@@ -0,0 +1,6 @@
DIFFUSERS_TRANSFORMER_BLOCK_NAMES = [
"transformer_blocks",
"single_transformer_blocks",
"temporal_transformer_blocks",
"blocks",
]
2 changes: 1 addition & 1 deletion finetrainers/utils/checkpointing.py → finetrainers/utils/_rename_this_file_checkpointing.py
@@ -4,7 +4,7 @@
from accelerate.logging import get_logger

from ..constants import FINETRAINERS_LOG_LEVEL
from ..utils.file_utils import delete_files, find_files
from .file_utils import delete_files, find_files


logger = get_logger("finetrainers")
71 changes: 71 additions & 0 deletions finetrainers/utils/checkpoint_utils.py
@@ -0,0 +1,71 @@
import collections
from enum import Enum

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

from ._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES


class CheckpointType(str, Enum):
FULL = "full"
OPS = "ops"
BLOCK_SKIP = "block_skip"


_SELECTIVE_ACTIVATION_CHECKPOINTING_OPS = {
torch.ops.aten.mm.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
}


def apply_activation_checkpointing(
module: torch.nn.Module, checkpointing_type: str = CheckpointType.FULL, n_layer: int = 1
) -> torch.nn.Module:
if checkpointing_type == CheckpointType.FULL:
module = _apply_activation_checkpointing_blocks(module)
elif checkpointing_type == CheckpointType.OPS:
module = _apply_activation_checkpointing_ops(module, _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS)
elif checkpointing_type == CheckpointType.BLOCK_SKIP:
module = _apply_activation_checkpointing_blocks(module, n_layer)
else:
raise ValueError(
f"Checkpointing type '{checkpointing_type}' not supported. Supported types are {CheckpointType.__members__.keys()}"
)
return module


def _apply_activation_checkpointing_blocks(module: torch.nn.Module, n_layer: int = None) -> torch.nn.Module:
for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
blocks: torch.nn.Module = getattr(module, transformer_block_name, None)
if blocks is None:
continue
for index, (layer_id, block) in enumerate(blocks.named_children()):
if n_layer is None or index % n_layer == 0:
block = checkpoint_wrapper(block, preserve_rng_state=False)
blocks.register_module(layer_id, block)
return module


def _apply_activation_checkpointing_ops(module: torch.nn.Module, ops) -> torch.nn.Module:
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

def _get_custom_policy(meta):
def _custom_policy(ctx, func, *args, **kwargs):
mode = "recompute" if ctx.is_recompute else "forward"
mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
meta[mm_count_key] += 1
# Saves output of all compute ops, except every second mm
to_save = func in ops and not (func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0)
return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE

return _custom_policy

def selective_checkpointing_context_fn():
meta = collections.defaultdict(int)
return create_selective_checkpoint_contexts(_get_custom_policy(meta))

return checkpoint_wrapper(module, context_fn=selective_checkpointing_context_fn, preserve_rng_state=False)
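
A minimal usage sketch for the new helper (not from the repo); it assumes a module exposing a diffusers-style transformer_blocks ModuleList, and TinyModel plus its sizes are illustrative:

import torch.nn as nn
from finetrainers.utils.checkpoint_utils import apply_activation_checkpointing

class TinyModel(nn.Module):
    def __init__(self, dim: int = 32, num_layers: int = 4) -> None:
        super().__init__()
        self.transformer_blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

# "full" wraps every block with checkpoint_wrapper; "block_skip" with n_layer=2 would wrap
# every second block, and "ops" applies selective per-op checkpointing to the whole module.
model = apply_activation_checkpointing(TinyModel(), checkpointing_type="full")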
13 changes: 3 additions & 10 deletions finetrainers/utils/parallel_utils.py
@@ -10,14 +10,7 @@
from torch.distributed._composable.replicate import replicate

from ..logging import logger


_DIFFUSERS_TRANSFORMER_BLOCK_NAMES = [
"transformer_blocks",
"single_transformer_blocks",
"temporal_transformer_blocks",
"blocks",
]
from ._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES


def apply_fsdp(
Expand All @@ -36,7 +29,7 @@ def apply_fsdp(
fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True)

def apply_fully_shard(blocks):
for layer_index, block in blocks:
for layer_index, block in enumerate(blocks):
if pp_enabled:
# For PP, do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
Expand All @@ -47,7 +40,7 @@ def apply_fully_shard(blocks):
reshard_after_forward = layer_index < len(blocks) - 1
fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)

for transformer_block_name in _DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
blocks = getattr(model, transformer_block_name, None)
if blocks is not None:
apply_fully_shard(blocks)
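Aside from moving the block-name list into _common.py, this hunk fixes a bug: iterating an nn.ModuleList yields the modules themselves, so the old "for layer_index, block in blocks:" fails when unpacking. A small standalone illustration of the corrected loop (the Linear blocks and sizes are placeholders):

import torch.nn as nn

blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))

for layer_index, block in enumerate(blocks):
    # Mirrors apply_fully_shard's reshard_after_forward logic: the last block is not
    # resharded after forward because backward starts with it immediately.
    reshard_after_forward = layer_index < len(blocks) - 1
    print(layer_index, type(block).__name__, reshard_after_forward)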
2 changes: 1 addition & 1 deletion train.py
@@ -28,7 +28,7 @@ def main():
trainer.prepare_dataset()
trainer.prepare_models()
trainer.prepare_precomputations()
trainer.prepare_optimizer()
trainer.prepare_trainable_parameters()
trainer.prepare_for_training()
trainer.train()
# trainer.evaluate()
