Commit 0d71b0a

Configurable embeddings upcast (#2621)
* fsdp embeddings should be float32 per comment
* patch peft to not upcast everything
* add tabs back to code check
* fix import
* add configurable option and fix check
* add check for dtypes
* move embeddings test to patch dir
* fix test
* fix comment and logic
1 parent 63aaccf commit 0d71b0a

File tree

6 files changed (+154, -2 lines)


docs/config.qmd (+2)

@@ -32,6 +32,8 @@ tokenizer_legacy:
 resize_token_embeddings_to_32x:
 # Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
 shrink_embeddings:
+# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs
+embeddings_skip_upcast:
 # Whether to load the model with randomly initialized weights. Useful for
 # pre-training a model from scratch or debugging purposes.
 random_init_weights:
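
The doc comment notes the option is useful for low-VRAM GPUs. As a rough illustration (not part of the commit), here is a back-of-the-envelope estimate of the extra memory the default float32 upcast of the embedding and lm_head weights costs, using a hypothetical Llama-7B-like shape:

```python
# Rough estimate of the extra VRAM used when the embedding and lm_head
# weights are held in float32 instead of bfloat16. The shape below
# (32k vocab, 4096 hidden) is a hypothetical example, not from this commit.

vocab_size = 32_000
hidden_size = 4_096
num_embedding_matrices = 2  # embed_tokens + lm_head (untied)

params = vocab_size * hidden_size * num_embedding_matrices
extra_bytes = params * (4 - 2)  # float32 = 4 bytes/param, bfloat16 = 2

print(f"extra memory from upcasting: {extra_bytes / 2**30:.2f} GiB")
# -> roughly 0.49 GiB for this shape; larger vocabularies cost proportionally more
```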

src/axolotl/monkeypatch/peft/__init__.py (new file, whitespace-only changes)

src/axolotl/monkeypatch/peft/utils.py (new file, +78)

"""
Patch prepare_model_for_kbit_training to not upcast everything
"""

import inspect
import logging

import peft

import axolotl
from axolotl.monkeypatch.utils import detab_code

LOG = logging.getLogger(__name__)

ORIGINAL_PREPARE_CODE = """
    for param in model.parameters():
        if (
            (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
        ) and param.__class__.__name__ != "Params4bit":
            param.data = param.data.to(torch.float32)
"""

PATCHED_PREPARE_CODE = """
    for name, param in model.named_parameters():
        if (
            (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
        ) and param.__class__.__name__ != "Params4bit" and all(embed_name not in name for embed_name in ["embed_tokens", "lm_head"]):
            param.data = param.data.to(torch.float32)
"""


def get_peft_prep_code() -> str:
    prepare = inspect.getsource(peft.utils.other.prepare_model_for_kbit_training)
    return prepare


def check_peft_prep_code_is_patchable() -> bool:
    prep_code = get_peft_prep_code()
    prep_code, _ = detab_code(prep_code)
    return ORIGINAL_PREPARE_CODE in prep_code


def patch_peft_prep_code():
    """
    monkeypatch prepare_model_for_kbit_training so it does not upcast the embedding layers
    """

    try:
        prep_code = get_peft_prep_code()
    except OSError:
        return
    # stash the original function source
    peft.utils.other._original_create_accelerator_and_postprocess = (  # pylint: disable=protected-access
        prep_code
    )
    prep_code, _ = detab_code(prep_code)
    if ORIGINAL_PREPARE_CODE not in prep_code:
        return

    prep_code = prep_code.replace(ORIGINAL_PREPARE_CODE, PATCHED_PREPARE_CODE)
    prep_code = prep_code.replace(
        "def prepare_model_for_kbit_training(",
        "def fixed_prepare_model_for_kbit_training(",
        1,
    )

    # import any names from peft.utils.other that appear in the function source
    # so the exec()'d code can resolve them
    items_to_import = []
    for item in dir(peft.utils.other):
        if item in prep_code:
            items_to_import.append(item)

    exec(  # pylint: disable=exec-used  # nosec B102
        "from peft.utils.other import (" + ", ".join(x for x in items_to_import) + ")",
        globals(),
    )
    exec(prep_code, globals())  # pylint: disable=exec-used  # nosec B102
    LOG.info("patching prepare_model_for_kbit_training to allow for overrides")
    peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training  # pylint: disable=protected-access  # pylint: disable=undefined-variable  # noqa: F821
    axolotl.utils.models.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training  # pylint: disable=protected-access  # pylint: disable=undefined-variable  # noqa: F821
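
The patch works by string-replacing the upcast loop inside PEFT's prepare_model_for_kbit_training with a version that skips any parameter whose name contains "embed_tokens" or "lm_head". Below is a minimal, standalone sketch of what the injected loop does; the toy model and its module names are illustrative only and not taken from PEFT or this commit.

```python
import torch
from torch import nn


class ToyCausalLM(nn.Module):
    """Tiny stand-in with the module names the patch keys on."""

    def __init__(self, vocab=100, hidden=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.mlp = nn.Linear(hidden, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)


model = ToyCausalLM().to(torch.bfloat16)

# the loop injected by PATCHED_PREPARE_CODE: upcast half-precision params to
# float32, except anything under embed_tokens or lm_head
for name, param in model.named_parameters():
    if (
        (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
    ) and param.__class__.__name__ != "Params4bit" and all(
        embed_name not in name for embed_name in ["embed_tokens", "lm_head"]
    ):
        param.data = param.data.to(torch.float32)

for name, param in model.named_parameters():
    print(name, param.dtype)
# embed_tokens.weight and lm_head.weight stay bfloat16; mlp.* become float32
```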

src/axolotl/utils/models.py (+10, -2)

@@ -566,6 +566,11 @@ def apply_patches(self) -> None:

         patch_accelerate_fsdp_utils()

+        if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
+            from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
+
+            patch_peft_prep_code()
+
         if self.cfg.flex_attention:
             from axolotl.monkeypatch.attention.flex_attn import (
                 patch_flex_make_mask,

@@ -1185,7 +1190,7 @@ def set_z3_leaf_modules(self) -> None:
             ],
         )

-    def prepare_model(self, qlora_fsdp) -> None:
+    def prepare_model(self, qlora_fsdp: bool) -> None:
         skip_prepare_model_for_kbit_training = False
         if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
             # Qwen doesn't play nicely with LoRA if this is enabled

@@ -1315,7 +1320,10 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
         # make sure these are fp32 per Ramesh et al. (2021)
         embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
         if not self.cfg.fsdp:
-            # FSDP doesn't like mixed Float and BFloat16
+            # we don't run this during FSDP because this will leave mixed
+            # float and bfloat16 dtypes in the model which FSDP doesn't like
+            if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
+                embedding_modules = []
             self.convert_embedding_modules_dtype(
                 embedding_modules,
                 dist_dtype=torch.float32,
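
The new flag is consulted in two places: apply_patches installs the PEFT monkeypatch whenever an adapter is configured with embeddings_skip_upcast, and load_model empties the embedding_modules list so convert_embedding_modules_dtype leaves the embeddings alone for 4-bit loads. A condensed paraphrase of that branching, written as a pure function for readability; the function itself is illustrative and not part of the codebase, though the cfg field names match the axolotl config.

```python
def embedding_modules_to_upcast(cfg: dict, embedding_modules: list[str]) -> list[str]:
    if cfg.get("fsdp"):
        # under FSDP the upcast is skipped entirely to avoid mixing
        # float32 and bfloat16 dtypes in the wrapped model
        return []
    if cfg.get("load_in_4bit") and cfg.get("embeddings_skip_upcast"):
        # new behaviour: leave embed_tokens/lm_head in their low-precision dtype
        return []
    # default behaviour: keep the embeddings in float32 per Ramesh et al. (2021)
    return embedding_modules


print(embedding_modules_to_upcast(
    {"load_in_4bit": True, "embeddings_skip_upcast": True},
    ["embed_tokens", "lm_head"],
))  # -> []
```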

src/axolotl/utils/schemas/config.py (+1)

@@ -82,6 +82,7 @@ class AxolotlInputConfig(
     mean_resizing_embeddings: bool | None = False
     # optionally shrink the embeddings when the tokenizer vocab size is smaller
     shrink_embeddings: bool | None = None
+    embeddings_skip_upcast: bool | None = None

     rl: RLType | None = None
     trl: TRLConfig | None = Field(
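
Because the new schema field defaults to None, existing configs keep the current upcast behaviour unless they opt in. A minimal sketch of the field in isolation, assuming a plain pydantic BaseModel as a stand-in for AxolotlInputConfig (the real class has many more fields and validators):

```python
from pydantic import BaseModel


class EmbeddingOptions(BaseModel):
    # stand-in for the two optional embedding flags in AxolotlInputConfig
    shrink_embeddings: bool | None = None
    embeddings_skip_upcast: bool | None = None


print(EmbeddingOptions().embeddings_skip_upcast)  # None -> default upcast behaviour
print(EmbeddingOptions(embeddings_skip_upcast=True).embeddings_skip_upcast)  # True -> skip the upcast
```
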
New test file (+63)

"""
Test case for handling embeddings when using peft
"""

import torch

from axolotl.train import setup_model_and_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault


class TestLlamaPeftEmbeddings:
    """
    test class for handling embeddings when using peft
    """

    def test_peft_embeddings_upcast(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "load_in_4bit": True,
                "adapter": "qlora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_target_linear": True,
                "trust_remote_code": True,
                "sequence_len": 512,
                "val_set_size": 0.01,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "max_steps": 2,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_8bit",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "sample_packing": False,
                "bf16": "auto",
                "save_safetensors": True,
                "embeddings_skip_upcast": True,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)

        model, _, _, _ = setup_model_and_tokenizer(cfg)

        # with embeddings_skip_upcast set, the embeddings are not upcast to float32;
        # embed_tokens and lm_head should both remain bfloat16
        assert model.base_model.model.model.embed_tokens.weight.dtype == torch.bfloat16
        assert model.base_model.model.lm_head.weight.dtype == torch.bfloat16
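
The assertions above only check the two embedding weights. For eyeballing every parameter dtype after setup, a small helper along these lines can be handy; it is a hypothetical convenience and not part of the test or the codebase.

```python
from collections import defaultdict

import torch


def dtype_report(model: torch.nn.Module) -> dict[torch.dtype, list[str]]:
    """Group parameter names by their dtype for quick inspection."""
    report: dict[torch.dtype, list[str]] = defaultdict(list)
    for name, param in model.named_parameters():
        report[param.dtype].append(name)
    return dict(report)


# With embeddings_skip_upcast: True, embed_tokens/lm_head should appear under
# torch.bfloat16; with the flag unset, they should appear under torch.float32.
```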
