
Commit 1a1cfe3

fix INT8 prepare function (huggingface#389)
* fix INT8 prepare function
* remove unused function args
* fix related tests, examples and docs
1 parent 8e53e16 commit 1a1cfe3

7 files changed: +16 -36 lines

docs/source/task_guides/int8-asr.mdx (+2 -3)

@@ -178,15 +178,14 @@ model.config.suppress_tokens = []
 
 To get the model ready for `int8` quantization, use the utility function [`prepare_model_for_int8_training`](https://github.com/huggingface/peft/blob/34027fe813756897767b9a6f19ae7f1c4c7b418c/src/peft/utils/other.py#L35) to handle the following:
 
-- casts the `LayerNorm` to full precision (`fp32`) for stability
+- casts all the non `int8` modules to full precision (`fp32`) for stability
 - adds a forward hook to the input embedding layer to calculate the gradients of the input hidden states
 - enables gradient checkpointing for more memory-efficient training
-- casts the output logits to `fp32` for smoother sampling
 
 ```py
 from peft import prepare_model_for_int8_training
 
-model = prepare_model_for_int8_training(model, output_embedding_layer_name="proj_out")
+model = prepare_model_for_int8_training(model)
 ```
 
 Let's also apply LoRA to the training to make it even more efficient. Load a [`~peft.LoraConfig`] and configure the following parameters:
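To make the updated call concrete, here is a minimal end-to-end sketch in the spirit of this task guide. The checkpoint name and the `load_in_8bit`/`device_map` loading flags are illustrative assumptions, and the LoRA hyperparameters simply mirror the updated GPU test in this commit; none of this is part of the diff itself.

```py
# Hypothetical sketch: 8-bit load -> prepare -> LoRA (assumed checkpoint and loading flags).
from transformers import WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2", load_in_8bit=True, device_map="auto"
)

# updated signature: no output_embedding_layer_name argument anymore
model = prepare_model_for_int8_training(model)

# LoRA hyperparameters mirroring tests/test_gpu_examples.py in this commit
config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
model = get_peft_model(model, config)
model.print_trainable_parameters()
```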

examples/int8_training/Finetune_flan_t5_large_bnb_peft.ipynb (+2 -3)

@@ -328,10 +328,9 @@
 },
 "source": [
 "Some pre-processing needs to be done before training such an int8 model using `peft`, therefore let's import a utility function `prepare_model_for_int8_training` that will: \n",
-"- Cast the layer norm in `float32` for stability purposes\n",
+"- Cast all the non `int8` modules to full precision (`fp32`) for stability\n",
 "- Add a `forward_hook` to the input embedding layer to enable gradient computation of the input hidden states\n",
-"- Enable gradient checkpointing for more memory-efficient training\n",
-"- Cast the output logits in `float32` for smoother sampling during the sampling procedure"
+"- Enable gradient checkpointing for more memory-efficient training"
 ]
 },
 {
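As a sanity check on the behavior described in that cell, a hedged inspection sketch; it assumes `model` was loaded with `load_in_8bit=True` and has already passed through `prepare_model_for_int8_training`, and the assertions are illustrative rather than part of the notebook.

```py
import torch

# After prepare_model_for_int8_training(model):
# - every parameter is frozen,
# - fp16/bf16 parameters have been upcast to fp32,
# - the quantized int8 linear weights themselves are left untouched.
for name, param in model.named_parameters():
    assert not param.requires_grad, f"{name} should be frozen"
    assert param.dtype not in (torch.float16, torch.bfloat16), f"{name} should have been upcast to fp32"
```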

examples/int8_training/Finetune_opt_bnb_peft.ipynb (+2 -3)

@@ -377,10 +377,9 @@
 "### Prepare model for training\n",
 "\n",
 "Some pre-processing needs to be done before training such an int8 model using `peft`, therefore let's import a utility function `prepare_model_for_int8_training` that will: \n",
-"- Cast the layer norm in `float32` for stability purposes\n",
+"- Cast all the non `int8` modules to full precision (`fp32`) for stability\n",
 "- Add a `forward_hook` to the input embedding layer to enable gradient computation of the input hidden states\n",
-"- Enable gradient checkpointing for more memory-efficient training\n",
-"- Cast the output logits in `float32` for smoother sampling during the sampling procedure"
+"- Enable gradient checkpointing for more memory-efficient training"
 ]
 },
 {

examples/int8_training/peft_adalora_whisper_large_training.py (+1 -1)

@@ -561,7 +561,7 @@ def main():
     if args.use_peft:
         from peft import prepare_model_for_int8_training
 
-        model = prepare_model_for_int8_training(model, output_embedding_layer_name="proj_out")
+        model = prepare_model_for_int8_training(model)
 
         # as Whisper model uses Conv layer in encoder, checkpointing disables grad computation
         # to avoid this, make the inputs trainable
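The "make the inputs trainable" comment above usually resolves to a forward hook like the sketch below. The hooked module (`model.model.encoder.conv1`) is an assumption based on the Whisper encoder layout and is not something this diff touches; the same hook pattern appears in the `prepare_model_for_int8_training` hunk further down.

```py
def make_inputs_require_grad(module, input, output):
    # force the conv output to carry gradients so gradient checkpointing
    # can backpropagate through the otherwise frozen encoder input
    output.requires_grad_(True)

# assumed hook target for Whisper; adapt for other architectures
model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)
```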

examples/int8_training/peft_bnb_whisper_large_v2_training.ipynb (+3 -2)

@@ -1133,6 +1133,7 @@
 ]
 },
 {
+"attachments": {},
 "cell_type": "markdown",
 "id": "bR-_yaEOPsfQ",
 "metadata": {
@@ -1141,7 +1142,7 @@
 "source": [
 "### Post-processing on the model\n",
 "\n",
-"Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons."
+"Finally, we need to apply some post-processing on the 8-bit model to enable training: let's freeze all our layers and cast all non `int8` layers in `float32` for stability."
 ]
 },
 {
@@ -1155,7 +1156,7 @@
 "source": [
 "from peft import prepare_model_for_int8_training\n",
 "\n",
-"model = prepare_model_for_int8_training(model, output_embedding_layer_name=\"proj_out\")"
+"model = prepare_model_for_int8_training(model)"
 ]
 },
 {
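With the simplified call above, the only remaining option is `use_gradient_checkpointing`, visible in the new signature in `src/peft/utils/other.py` below. A small hedged sketch of opting out of it; the manual re-enable uses the standard Transformers API and is illustrative, not part of the notebook.

```py
from peft import prepare_model_for_int8_training

# gradient checkpointing is on by default; it can be disabled at prepare time
model = prepare_model_for_int8_training(model, use_gradient_checkpointing=False)

# ...and, if needed, enabled later via the regular Transformers API
# (the input-grad hook shown in the Whisper example is then your responsibility)
model.gradient_checkpointing_enable()
```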

src/peft/utils/other.py (+5 -23)

@@ -32,9 +32,7 @@ def bloom_model_postprocess_past_key_value(past_key_values):
     return tuple(zip(keys, values))
 
 
-def prepare_model_for_int8_training(
-    model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=["layer_norm"]
-):
+def prepare_model_for_int8_training(model, use_gradient_checkpointing=True):
     r"""
     This method wraps the entire protocol for preparing a model before running a training. This includes:
         1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
@@ -50,10 +48,10 @@ def prepare_model_for_int8_training(
         # freeze base model's layers
         param.requires_grad = False
 
-        if loaded_in_8bit:
-            # cast layer norm in fp32 for stability for 8bit models
-            if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-                param.data = param.data.to(torch.float32)
+    # cast all non INT8 parameters to fp32
+    for param in model.parameters():
+        if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+            param.data = param.data.to(torch.float32)
 
     if loaded_in_8bit and use_gradient_checkpointing:
         # For backward compatibility
@@ -69,22 +67,6 @@ def make_inputs_require_grad(module, input, output):
         # enable gradient checkpointing for memory efficiency
         model.gradient_checkpointing_enable()
 
-    if hasattr(model, output_embedding_layer_name):
-        output_embedding_layer = getattr(model, output_embedding_layer_name)
-        input_dtype = output_embedding_layer.weight.dtype
-
-        class CastOutputToFloat(torch.nn.Sequential):
-            r"""
-            Manually cast to the expected dtype of the lm_head as sometimes there is a final layer norm that is casted
-            in fp32
-
-            """
-
-            def forward(self, x):
-                return super().forward(x.to(input_dtype)).to(torch.float32)
-
-        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
-
     return model
 
 
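Read as a whole, the post-change function boils down to the condensed sketch below. It paraphrases the hunks above for readability; the `is_loaded_in_8bit` attribute check and the `enable_input_require_grads` fallback are assumptions about the unchanged parts of the file, which the diff only shows in fragments.

```py
import torch

def prepare_model_for_int8_training(model, use_gradient_checkpointing=True):
    # condensed sketch of the updated logic, not the verbatim file contents
    loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False)  # assumed flag set by 8-bit loading

    # freeze the base model's layers
    for param in model.parameters():
        param.requires_grad = False

    # cast all non-INT8 parameters to fp32 (replaces the old layer-norm-only cast)
    for param in model.parameters():
        if param.dtype in (torch.float16, torch.bfloat16):
            param.data = param.data.to(torch.float32)

    if loaded_in_8bit and use_gradient_checkpointing:
        # make the inputs require grads so checkpointing can backprop through frozen layers
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable()

    return model
```

Note that the removed `CastOutputToFloat` wrapper has no replacement: the output embedding layer is no longer wrapped, so logits are no longer force-cast to `fp32`.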

tests/test_gpu_examples.py (+1 -1)

@@ -402,7 +402,7 @@ def prepare_dataset(batch):
         model.config.forced_decoder_ids = None
         model.config.suppress_tokens = []
 
-        model = prepare_model_for_int8_training(model, output_embedding_layer_name="proj_out")
+        model = prepare_model_for_int8_training(model)
 
         config = LoraConfig(
             r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
