From 945ae22781798599027724e430fed535123f1411 Mon Sep 17 00:00:00 2001
From: jahatef
Date: Wed, 18 Dec 2024 13:55:39 -0500
Subject: [PATCH] fix activation checkpointing and logging

---
 megatron/logging.py            | 13 +++++++------
 megatron/model/gpt2_model.py   |  1 +
 megatron/model/rwkv/v6/rwkv.py |  8 +++-----
 megatron/mpu/mappings.py       |  3 ++-
 4 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/megatron/logging.py b/megatron/logging.py
index af8a41fe5..2be6e8f5f 100644
--- a/megatron/logging.py
+++ b/megatron/logging.py
@@ -97,14 +97,15 @@ def get_flops(neox_args, iter_time_s) -> float:
         num_heads = neox_args.num_attention_heads
 
         flops_per_iteration = (
-            batch_size
+            ckpt_activations_factor
+            * batch_size
             * seq_len
             * (
-                78 * hidden_size * hidden_size * num_layers
-                + 84 * hidden_size * num_layers
-                + 16 * hidden_size
-                + 12 * hidden_size * vocab_size
-                + 18 * hidden_size * hidden_size * num_layers / num_heads
+                26 * hidden_size * hidden_size * num_layers
+                + 928 * hidden_size * num_layers
+                + 8 * hidden_size
+                + 4 * hidden_size * vocab_size
+                + 6 * hidden_size * hidden_size * num_layers / num_heads
             )
         )
     elif "mamba" in neox_args.attention_config:
diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py
index 1b6aa9b54..2a2fb70ac 100644
--- a/megatron/model/gpt2_model.py
+++ b/megatron/model/gpt2_model.py
@@ -140,6 +140,7 @@ def __init__(
                 "GMLPBlock",
                 "ParallelTransformerLayerPipe",
                 "ParallelMambaResidualLayerPipe",
+                "RWKVResidualLayerPipe"
             ],
         )
 
diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py
index 0d77278bc..3ac81c922 100644
--- a/megatron/model/rwkv/v6/rwkv.py
+++ b/megatron/model/rwkv/v6/rwkv.py
@@ -282,7 +282,7 @@ def forward(self, x):
         r, k, v, g, w = self.jit_func(x)
 
         if self.neox_args.rwkv_fla:
-            x, _ = RUN_FLA_CHUNK(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H))
+            x, _ = RUN_FLA_CHUNK(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H),chunk_size=256)
         else:
             x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H))
 
@@ -418,11 +418,9 @@ def __init__(self, neox_args, init_method, layer_number):
             ],
             verbose=True,
             extra_cuda_cflags=[
-                "-res-usage",
-                "--use_fast_math",
+                "-ffast-math",
                 "-O3",
-                "-Xptxas -O3",
-                "--extra-device-vectorization",
+                "-fvectorize",
                 f"-D_N_={self.neox_args.head_size}",
                 f"-D_T_={self.neox_args.seq_length}",
             ],
diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py
index ceb89daa2..cc942d437 100644
--- a/megatron/mpu/mappings.py
+++ b/megatron/mpu/mappings.py
@@ -334,7 +334,8 @@ def reduce_scatter_to_sequence_parallel_region(input_, seq_dim=0):
     return _ReduceScatterToSequenceParallelRegion.apply(input_, seq_dim)
 
 
-def gather_from_sequence_parallel_region(input_, seq_dim=0):
+def gather_from_sequence_parallel_region(input_: torch.Tensor, seq_dim: int = 0):
+#def gather_from_sequence_parallel_region(input_, seq_dim=0):
     return _GatherFromSequenceParallelRegion.apply(input_, seq_dim)
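
Note (review context): below is a minimal standalone sketch of the revised RWKV
FLOPs estimate from the megatron/logging.py hunk, handy for sanity-checking the
new coefficients. The 3x/4x value of ckpt_activations_factor (forward plus
backward costs roughly 3x a forward pass, and recomputing activations under
checkpointing adds one extra forward) and the example config values are
assumptions for illustration, not part of the patch.

# sketch_rwkv_flops.py -- illustrative sketch, not part of the patch
def rwkv_flops_per_iteration(
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    num_layers: int,
    vocab_size: int,
    num_heads: int,
    checkpoint_activations: bool,
) -> float:
    # Assumed convention: fwd+bwd ~= 3x a forward pass; activation
    # checkpointing recomputes the forward, adding one more (4x total).
    ckpt_activations_factor = 4 if checkpoint_activations else 3
    return (
        ckpt_activations_factor
        * batch_size
        * seq_len
        * (
            26 * hidden_size * hidden_size * num_layers
            + 928 * hidden_size * num_layers
            + 8 * hidden_size
            + 4 * hidden_size * vocab_size
            + 6 * hidden_size * hidden_size * num_layers / num_heads
        )
    )

if __name__ == "__main__":
    # Illustrative config; prints total TFLOPs for one training iteration.
    flops = rwkv_flops_per_iteration(
        batch_size=32,
        seq_len=4096,
        hidden_size=2048,
        num_layers=24,
        vocab_size=50304,
        num_heads=32,
        checkpoint_activations=True,
    )
    print(f"{flops / 1e12:.1f} TFLOPs per iteration")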
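
Note (review context): the gpt2_model.py hunk matters because pipeline stages
are only activation-checkpointed when every module's class name appears in
checkpointable_layers, so RWKV layers were presumably being skipped before this
patch. A minimal sketch of that gating logic follows, assuming behavior along
the lines of DeepSpeed's PipelineModule (an assumption, not the library's
actual code):

# sketch_is_checkpointable.py -- illustrative sketch, not part of the patch
import torch.nn as nn

def is_checkpointable(funcs: list[nn.Module], checkpointable_layers: list[str]) -> bool:
    # A stage is eligible for activation checkpointing only if every module
    # in it is registered by class name, which is why "RWKVResidualLayerPipe"
    # must be added to the list in the hunk above.
    return all(f.__class__.__name__ in checkpointable_layers for f in funcs)

# Example (hypothetical stage contents):
# is_checkpointable(stage_modules, ["ParallelTransformerLayerPipe", "RWKVResidualLayerPipe"])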