Commit 54d2ac1

Mixtral fixes 20240124 (axolotl-ai-cloud#1192) [skip ci]
* mixtral nccl fixes
* make sure to patch for z3

1 parent af02430

14 files changed: +71 -13 lines

Diff for: README.md

+3-3
@@ -861,7 +861,7 @@ tokens:
 fsdp:
 fsdp_config:

-# Deepspeed config path. e.g., deepspeed/zero3.json
+# Deepspeed config path. e.g., deepspeed_configs/zero3.json
 deepspeed:

 # Advanced DDP Arguments

@@ -982,11 +982,11 @@ for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usa
 We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.

 ```yaml
-deepspeed: deepspeed/zero1.json
+deepspeed: deepspeed_configs/zero1.json
 ```

 ```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
+accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
 ```

 ##### FSDP
4 files renamed without changes.

Diff for: examples/llama-2/fft_optimized.yml

+1-1
@@ -62,7 +62,7 @@ evals_per_epoch: 4
 eval_table_size:
 saves_per_epoch: 1
 debug:
-deepspeed: #deepspeed/zero2.json # multi-gpu only
+deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
 weight_decay: 0.1
 fsdp:
 fsdp_config:

Diff for: examples/mistral/Mistral-7b-example/code.ipynb

+1-1
@@ -942,7 +942,7 @@
 "not only optimizer states but also gradients and parameters across GPUs. The bf16 indicate mixed precision training using bfloat16.\n",
 "For more information read axolotl's readme\n",
 "\"\"\"\n",
-"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed/zero3_bf16.json"
+"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed_configs/zero3_bf16.json"
 ]
 }
 ],

Diff for: examples/mistral/Mistral-7b-example/config.yml

+1-1
@@ -65,7 +65,7 @@ eval_table_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 #default deepspeed, can use more aggresive if needed like zero2, zero3
-deepspeed: deepspeed/zero1.json
+deepspeed: deepspeed_configs/zero1.json
 weight_decay: 0.0
 fsdp:
 fsdp_config:

Diff for: examples/mistral/README.md

+1-1
@@ -8,5 +8,5 @@ accelerate launch -m axolotl.cli.train examples/mistral/config.yml

 If you run into CUDA OOM, use deepspeed with config zero2.json:
 ```shell
-accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
+accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed_configs/zero2.json
 ```

Diff for: examples/mistral/mixtral.yml

+1-1
@@ -84,7 +84,7 @@ eval_table_size:
 eval_table_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
-deepspeed: deepspeed/zero2.json
+deepspeed: deepspeed_configs/zero2.json
 weight_decay: 0.0
 fsdp:
 fsdp_config:

Diff for: examples/phi/README.md

+1-1
@@ -3,7 +3,7 @@
 Due to some nuances with the phi code, please use deepspeed when training phi for full finetune.

 ```shell
-accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed/zero1.json
+accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed_configs/zero1.json

 # OR

Diff for: src/axolotl/monkeypatch/mixtral/__init__.py

+50-1
@@ -1,12 +1,61 @@
 """
 Patches to support multipack for mixtral
 """
+import torch
 import transformers

 from axolotl.monkeypatch.utils import get_unpad_data


-def replace_mixtral_attn_with_multipack_flash_attn():
+def patch_mixtral_moe_forward_zero3() -> None:
+    import torch.nn.functional as F
+
+    def mlp_forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states
+
+    # Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
+    def moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        topk_weight, topk_idx = torch.topk(
+            routing_weights, self.top_k, dim=-1, sorted=False
+        )
+        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        topk_weight = topk_weight.to(hidden_states.dtype)
+
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+        y = torch.empty_like(hidden_states)  # pylint: disable=invalid-name
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+        y = (  # pylint: disable=invalid-name
+            y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)
+        ).sum(dim=1)
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+    from transformers.models.mixtral.modeling_mixtral import (
+        MixtralBLockSparseTop2MLP,
+        MixtralSparseMoeBlock,
+    )
+
+    MixtralBLockSparseTop2MLP.forward = mlp_forward
+    MixtralSparseMoeBlock.forward = moe_forward
+
+
+def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
     transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
         get_unpad_data
     )
+    if for_zero3:
+        patch_mixtral_moe_forward_zero3()
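To make the routed forward in `moe_forward` above easier to follow, here is a standalone sketch of the same top-k routing math on dummy tensors. The sizes (batch 2, sequence 3, hidden 8, 4 experts, top_k 2) and the plain `Linear` "experts" are illustrative assumptions, not the real gate and `MixtralBLockSparseTop2MLP` modules. The point of the dense per-expert loop is that every expert's parameters are touched on every rank, even when an expert receives no tokens, which keeps ZeRO-3's parameter all-gathers (and the underlying NCCL collectives) in sync.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only; Mixtral itself routes top_k=2 over 8 experts.
batch_size, sequence_length, hidden_dim = 2, 3, 8
num_experts, top_k = 4, 2

hidden_states = torch.randn(batch_size, sequence_length, hidden_dim)
gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)
experts = torch.nn.ModuleList(
    torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)
)

flat = hidden_states.view(-1, hidden_dim)        # (batch * seq, hidden)
router_logits = gate(flat)                       # (batch * seq, n_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)  # renormalize top-k

# Duplicate each token top_k times, then run the dense loop: every expert is
# called on every rank, even when its routed subset is empty.
expanded = flat.repeat_interleave(top_k, dim=0)  # (batch * seq * top_k, hidden)
y = torch.empty_like(expanded)
flat_topk_idx = topk_idx.view(-1)
for i in range(num_experts):
    mask = flat_topk_idx == i
    y[mask] = experts[i](expanded[mask])

# Weight each expert output by its routing weight and merge back per token.
out = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
print(out.view(batch_size, sequence_length, hidden_dim).shape)  # torch.Size([2, 3, 8])
```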

Diff for: src/axolotl/train.py

+1-1
@@ -15,7 +15,7 @@
 from peft import PeftModel
 from pkg_resources import get_distribution  # type: ignore
 from transformers import PreTrainedModel, PreTrainedTokenizer
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
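The import move above tracks transformers relocating its DeepSpeed helpers under `transformers.integrations`. For code that has to run against both older and newer transformers releases, a fallback import is a common pattern; the sketch below is such a shim for illustration only, not something this commit adds.

```python
try:
    # Newer transformers releases expose the helper here.
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
except ImportError:
    # Older releases kept it under transformers.deepspeed.
    from transformers.deepspeed import is_deepspeed_zero3_enabled
```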

Diff for: src/axolotl/utils/models.py

+11-2
@@ -21,7 +21,7 @@
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN

@@ -333,7 +333,10 @@ def load_model(
         )

         LOG.info("patching mixtral with flash attention")
-        replace_mixtral_attn_with_multipack_flash_attn()
+        mixtral_patch_kwargs = {}
+        if is_deepspeed_zero3_enabled():
+            mixtral_patch_kwargs["for_zero3"] = True
+        replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)

     if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.falcon import (

@@ -646,6 +649,12 @@ def load_model(
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False

+    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
+        from deepspeed.utils import set_z3_leaf_modules
+        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+
     if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
         # Qwen doesn't play nicely with LoRA if this is enabled
         skip_prepare_model_for_kbit_training = True
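Condensing the ZeRO-3 hunks above into one standalone sketch: when a DeepSpeed ZeRO-3 config is active, the sparse MoE block is registered as a ZeRO-3 leaf module so its expert parameters are gathered as one unit rather than per expert, which avoids collective hangs when ranks route tokens to different experts. The model name and the bare `from_pretrained` call are illustrative assumptions; axolotl wires this up inside `load_model` as shown in the diff.

```python
from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# Assumes training was launched with a ZeRO-3 DeepSpeed config,
# e.g. accelerate launch ... --deepspeed deepspeed_configs/zero3.json
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

if is_deepspeed_zero3_enabled():
    # Mark each MixtralSparseMoeBlock as a single ZeRO-3 "leaf" so DeepSpeed
    # gathers the whole block's parameters together during forward/backward.
    set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
```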
