Skip to content

Commit

Permalink
Merge branch 'main' into add-reinforce
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin-Anthony authored Dec 19, 2024
2 parents 82a479f + 8900d05 commit 095febc
Show file tree
Hide file tree
Showing 13 changed files with 1,132 additions and 180 deletions.
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ GPT-NeoX leverages many of the same features and technologies as the popular Meg
* Easy connections with the open source ecosystem, including Hugging Face's [tokenizers](https://github.com/huggingface/tokenizers) and [transformers](https://github.com/huggingface/transformers/) libraries, monitor experiments via [WandB](https://wandb.ai/site)/[Comet](https://www.comet.com/site/)/TensorBoard, and evaluation via our [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).

## News
**[10/9/2024]** We now support [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) integration

**[9/9/2024]** We now support preference learning via [DPO](https://arxiv.org/abs/2305.18290), [KTO](https://arxiv.org/abs/2402.01306), and reward modeling

**[9/9/2024]** We now support integration with [Comet ML](https://www.comet.com/site/), a machine learning monitoring platform
Expand Down Expand Up @@ -60,6 +62,7 @@ Prior to 3/9/2023, GPT-NeoX relied on [DeeperSpeed](https://github.com/EleutherA
* [Environment and Dependencies](#environment-and-dependencies)
+ [Host Setup](#host-setup)
+ [Flash Attention](#flash-attention)
+ [Transformer Engine](#transformer-engine)
+ [Multi-Node Launching](#multi-node-launching)
+ [Containerized Setup](#containerized-setup)
* [Usage](#usage)
Expand Down Expand Up @@ -130,7 +133,20 @@ This will automatically adapts building process over different GPU vendors (AMD,

### Flash Attention

To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed using versions different from our requirements file). Then set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.

### Transformer Engine

To use [Transformer Engine (TE)](https://github.com/NVIDIA/TransformerEngine), install the additional dependencies in `./requirements/requirements-transformer-engine.txt` or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed using versions different from our requirements file). See [this config](https://github.com/EleutherAI/gpt-neox/configs/1-3B-transformer-engine.yml) for an example of using TE on a 1.3B model. This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere and Hopper GPUs; see the repository for more details.


TE provides very efficient kernels for both A100 and H100 GPUs. We've run some sample ablations on A100:



and H100:




### Multi-Node Launching
Expand Down
105 changes: 105 additions & 0 deletions configs/1-3B-transformer-engine.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# GPT-2 pretraining setup
{
# parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
# across the node boundaries )
"pipe_parallel_size": 1,
"model_parallel_size": 1,

# model settings
"num_layers": 24,
"hidden_size": 2048,
"num_attention_heads": 16,
"seq_length": 2048,
"max_position_embeddings": 2048,
"norm": "layernorm",
"pos_emb": "rotary",
"no_weight_tying": true,
"gpt_j_residual": false,
"output_layer_parallelism": "column",

# Transformer Engine settings
"te_columnparallel": false,
"te_rowparallel": false,
"te_layernorm_mlp": true,
"te_mha": true,
"te_fp8_format": "hybrid",
"te_fp8_wgrad": true,
"te_fp8_amax_history_len": 1,
"te_fp8_amax_compute_algo": "most_recent",
"te_fp8_margin": 0,
"te_fp8_mha": false,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0002,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00002,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"train_micro_batch_size_per_gpu": 4,
"data_impl": "mmap",

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"fp16": {
"fp16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.01,
"checkpoint_factor": 10000,
"eval_interval": 1000,
"eval_iters": 10,

# logging
"log_interval": 100,
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,
}
1 change: 1 addition & 0 deletions configs/eleutherai_cluster.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"tensorboard_dir": "/mnt/ssd-1/tensorboard",
"log_dir": "/mnt/ssd-1/logs",
"wandb_team": "eleutherai",
#"wandb_run_name": "experiment"
"wandb_project": "neox",
"wandb_group": "example"
}
136 changes: 134 additions & 2 deletions megatron/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,12 @@ def human_readable_flops(num) -> str:
return "%.1f%s" % (num, "Yi")


def get_flops(neox_args, iter_time_s) -> float:
def get_actual_flops(neox_args, iter_time_s) -> float:
"""
This function finds the actual FLOPs achieved accounting for implementation and hardware details. Also used for HFU.
For more detail on flop calculations, see https://github.com/EleutherAI/cookbook/tree/main/calc and https://github.com/Zyphra/zcookbook/tree/main/calc
Use FLOPS calculation from Megatron-DeepSpeed:
https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253
They get it from https://arxiv.org/pdf/2104.04473.pdf
Expand Down Expand Up @@ -156,6 +160,83 @@ def get_flops(neox_args, iter_time_s) -> float:
return flops_per_iteration / (iter_time_s * world_size)


def get_forward_backward_flops(neox_args, iter_time_s) -> float:
"""
This function finds the estimated FLOPs required by a single forward+backward pass without accounting for implementation and hardware details. Also used for MFU.
Mostly duplicated from get_actual_flops with just a change in activation checkpointing for now, but these may diverge over time as implementation details accumulate so I think 2 separate functions are appropriate.
For more detail on flop calculations, see https://github.com/EleutherAI/cookbook/tree/main/calc and https://github.com/Zyphra/zcookbook/tree/main/calc
Use FLOPS calculation from Megatron-DeepSpeed:
https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253
They get it from https://arxiv.org/pdf/2104.04473.pdf
"""
world_size = torch.distributed.get_world_size()
vocab_size = neox_args.padded_vocab_size
batch_size = neox_args.train_batch_size
seq_len = neox_args.seq_length
hidden_size = neox_args.hidden_size
num_layers = neox_args.num_layers
fwd_bwd_factor = 3 # 1 for fwd, 2 for bwd and weight update
if "rwkv" in neox_args.attention_config:
num_heads = neox_args.num_attention_heads

flops_per_iteration = (
batch_size
* seq_len
* (
78 * hidden_size * hidden_size * num_layers
+ 84 * hidden_size * num_layers
+ 16 * hidden_size
+ 12 * hidden_size * vocab_size
+ 18 * hidden_size * hidden_size * num_layers / num_heads
)
)
elif "mamba" in neox_args.attention_config:
# from https://github.com/Zyphra/zcookbook/blob/main/calc/calc_mamba_flops.py
if neox_args.expansion_factor:
d_inner = neox_args.hidden_size * neox_args.expansion_factor
elif neox_args.intermediate_size:
d_inner = neox_args.intermediate_size
else:
d_inner = neox_args.hidden_size * 2 # default expansion factor
d_state = 16 # TODO make d_state an arg. Currently hardcoded in neox mamba definition and here
conv_dimension = 4 # TODO make conv_dimension an arg. Currently hardcoded in neox mamba definition and here
dt_rank = math.ceil(neox_args.hidden_size / 16)
ssm_flops = (
fwd_bwd_factor
* d_inner
* seq_len
* batch_size
* (11 * d_state + 4 * dt_rank + 1)
)
mamba_projectors_flops = (
fwd_bwd_factor * seq_len * batch_size * 6 * d_inner * hidden_size
)
mamba_conv_flops = (
fwd_bwd_factor * seq_len * batch_size * 2 * d_inner * conv_dimension
)
mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops
embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size
flops_per_iteration = mamba_flops * num_layers + embedding_flops
else:
flops_per_iteration = (
24
* fwd_bwd_factor
* batch_size
* seq_len
* num_layers
* (hidden_size**2)
* (
1.0
+ (seq_len / (6.0 * hidden_size))
+ (vocab_size / (16.0 * num_layers * hidden_size))
)
)
return flops_per_iteration / (iter_time_s * world_size)


def training_log(
neox_args,
timers,
Expand Down Expand Up @@ -350,6 +431,8 @@ def add_to_logging(name):
elapsed_time = timers("interval time").elapsed()
iteration_time = elapsed_time / neox_args.log_interval
samples_per_sec = neox_args.train_batch_size / iteration_time
steps_per_sec = 1 / iteration_time
tokens_per_sec = samples_per_sec * neox_args.seq_length
log_string = " samples/sec: {:.3f} |".format(samples_per_sec)
tb_wandb_log(
"runtime/samples_per_sec",
Expand All @@ -367,6 +450,22 @@ def add_to_logging(name):
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
tb_wandb_log(
"runtime/steps_per_sec",
steps_per_sec,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
tb_wandb_log(
"runtime/tokens_per_sec",
tokens_per_sec,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
log_string += " iteration {:8d}/{:8d} |".format(
iteration, neox_args.train_iters
)
Expand All @@ -390,7 +489,7 @@ def add_to_logging(name):
)

# log tflop / gpu
flops_per_s_per_gpu = get_flops(neox_args, iteration_time)
flops_per_s_per_gpu = get_actual_flops(neox_args, iteration_time)

log_string += (
f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |"
Expand All @@ -404,6 +503,39 @@ def add_to_logging(name):
comet_experiment=neox_args.comet_experiment,
)

if neox_args.peak_theoretical_tflops:
# Convert peak theoretical TFLOPS to FLOPS for consistent units
peak_theoretical_flops = neox_args.peak_theoretical_tflops * (10**12)

# Calculate MFU and HFU as percentages
mfu = (
get_forward_backward_flops(neox_args, iteration_time)
/ peak_theoretical_flops
) * 100
hfu = (flops_per_s_per_gpu / peak_theoretical_flops) * 100

# Add to log string
log_string += f" MFU: {mfu:.2f}% | HFU: {hfu:.2f}% |"

# Log to tracking systems
tb_wandb_log(
"runtime/model_flops_utilization",
mfu,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)

tb_wandb_log(
"runtime/hardware_flops_utilization",
hfu,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)

for key in total_loss_dict:
if key not in [skipped_iters_key, got_nan_key]:
v = (
Expand Down
5 changes: 5 additions & 0 deletions megatron/model/positional_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def _prepare_cache(self, seq_len, precision, base):
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)

self.emb = emb.reshape(emb.size(0), 1, 1, emb.size(1))

cos_cached = emb.cos()[:, None, None, :]
sin_cached = emb.sin()[:, None, None, :]

Expand All @@ -76,6 +78,9 @@ def _prepare_cache(self, seq_len, precision, base):
inv_freq.to(precision),
)

def get_emb(self):
return self.emb.to(self.precision).cuda()

def forward(self, x, seq_dim=0, seq_len=None):
if seq_len is None:
seq_len = x.shape[seq_dim]
Expand Down
Loading

0 comments on commit 095febc

Please sign in to comment.