
Commit 4b371b4

Authored Aug 18, 2023
[Low-level-API] Add docs about LLAPI (huggingface#836)
* add docs about LLAPI
* address comments
1 parent 87c1d24 · commit 4b371b4

File tree

5 files changed: +209 −3 lines changed
 

README.md (+36)
@@ -355,6 +355,42 @@ any GPU memory savings. Please refer issue [[FSDP] FSDP with CPU offload consume
2. When using ZeRO3 with `zero3_init_flag=True`, if you find that GPU memory increases with training steps, you may need to update DeepSpeed past [deepspeed commit 42858a9891422abc](https://github.com/microsoft/DeepSpeed/commit/42858a9891422abcecaa12c1bd432d28d33eb0d4). The related issue is [[BUG] Peft Training with Zero.Init() and Zero3 will increase GPU memory every forward step](https://github.com/microsoft/DeepSpeed/issues/3002).

## 🤗 PEFT as a utility library

Inject trainable adapters into any `torch` model using the `inject_adapter_in_model` method. Note that the method makes no further changes to the model.

```python
import torch
from peft import inject_adapter_in_model, LoraConfig


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 10)
        self.linear = torch.nn.Linear(10, 10)
        self.lm_head = torch.nn.Linear(10, 10)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        x = self.linear(x)
        x = self.lm_head(x)
        return x


lora_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    target_modules=["linear"],
)

model = DummyModel()
model = inject_adapter_in_model(lora_config, model)

dummy_inputs = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]])
dummy_outputs = model(dummy_inputs)
```
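
To extract only the adapter weights from the injected model, you can use the `get_peft_model_state_dict` method (see the low-level API developer guide for details):

```python
from peft import get_peft_model_state_dict

# Contains only the adapter entries, not the base model weights.
peft_state_dict = get_peft_model_state_dict(model)
print(peft_state_dict.keys())
```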

## Backlog:
- [x] Add tests
- [x] Multi Adapter training and inference support

docs/source/_toctree.yml (+2)
@@ -32,6 +32,8 @@
   sections:
   - local: developer_guides/custom_models
     title: Working with custom models
+  - local: developer_guides/low_level_api
+    title: PEFT low level API

 - title: 🤗 Accelerate integrations
   sections:
docs/source/developer_guides/low_level_api.md (new file)

@@ -0,0 +1,103 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# PEFT as a utility library

In this section, we cover how you can leverage PEFT's low-level API to inject trainable adapters into any `torch` module.
The development of this API was motivated by the need for power users to use adapter methods such as LoRA, IA3, and AdaLoRA without relying on the modeling classes exposed in the PEFT library.

## Supported tuner types

Currently, the supported adapter types are the 'injectable' adapters, meaning adapters for which an in-place modification of the model is sufficient to correctly perform fine-tuning. As such, only [LoRA](./conceptual_guides/lora), AdaLoRA, and [IA3](./conceptual_guides/ia3) are currently supported by this API.
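
For instance, an IA3 adapter can be injected with the same call pattern as LoRA. Below is a minimal sketch, not taken from the original docs, assuming `model` has a submodule named `linear` (like the `DummyModel` defined in the next section):

```python
from peft import IA3Config, inject_adapter_in_model

# Sketch: inject an (IA)3 adapter instead of LoRA. `feedforward_modules`
# marks which of the targeted modules are treated as feed-forward layers.
ia3_config = IA3Config(target_modules=["linear"], feedforward_modules=["linear"])
model = inject_adapter_in_model(ia3_config, model)
```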
## `inject_adapter_in_model` method

To perform the adapter injection, simply use the `inject_adapter_in_model` method, which takes three arguments: the PEFT config, the model itself, and an optional adapter name. You can also attach multiple adapters to the model by calling `inject_adapter_in_model` multiple times with different adapter names (see the sketch after the printed model below).

Below is a basic example of how to inject LoRA adapters into the submodule `linear` of the module `DummyModel`.
```python
import torch
from peft import inject_adapter_in_model, LoraConfig


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 10)
        self.linear = torch.nn.Linear(10, 10)
        self.lm_head = torch.nn.Linear(10, 10)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        x = self.linear(x)
        x = self.lm_head(x)
        return x


lora_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    target_modules=["linear"],
)

model = DummyModel()
model = inject_adapter_in_model(lora_config, model)

dummy_inputs = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]])
dummy_outputs = model(dummy_inputs)
```

If you print the model, you will notice that the adapters have been correctly injected into the model:

```bash
DummyModel(
  (embedding): Embedding(10, 10)
  (linear): Linear(
    in_features=10, out_features=10, bias=True
    (lora_dropout): ModuleDict(
      (default): Dropout(p=0.1, inplace=False)
    )
    (lora_A): ModuleDict(
      (default): Linear(in_features=10, out_features=64, bias=False)
    )
    (lora_B): ModuleDict(
      (default): Linear(in_features=64, out_features=10, bias=False)
    )
    (lora_embedding_A): ParameterDict()
    (lora_embedding_B): ParameterDict()
  )
  (lm_head): Linear(in_features=10, out_features=10, bias=True)
)
```
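
As noted above, calling `inject_adapter_in_model` again with a different adapter name attaches an additional adapter. A minimal sketch, where the name `"other-adapter"` and the rank `r=8` are arbitrary illustrative choices:

```python
# Sketch: attach a second LoRA adapter to the same model under a
# different adapter name; "other-adapter" is an arbitrary name.
other_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=8,
    bias="none",
    target_modules=["linear"],
)
model = inject_adapter_in_model(other_config, model, adapter_name="other-adapter")
```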

Note that it is up to the user to take care of saving the adapters (in case they want to save the adapters only), as `model.state_dict()` will return the full state dict of the model.
If you want to extract the adapter state dict, you can use the `get_peft_model_state_dict` method:

```python
from peft import get_peft_model_state_dict

peft_state_dict = get_peft_model_state_dict(model)
print(peft_state_dict)
```
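
From there, a minimal sketch of persisting and restoring only the adapter weights, using `torch.save` together with `set_peft_model_state_dict` (the file name is illustrative):

```python
import torch

from peft import set_peft_model_state_dict

# Save only the adapter weights; "adapter_weights.pt" is an illustrative path.
torch.save(peft_state_dict, "adapter_weights.pt")

# Restore them into a freshly injected model.
fresh_model = inject_adapter_in_model(lora_config, DummyModel())
set_peft_model_state_dict(fresh_model, torch.load("adapter_weights.pt"))
```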

## Pros and cons

When should you use this API, and when shouldn't you? Let's discuss the pros and cons.

Pros:
- The model is modified in-place, meaning the model preserves all its original attributes and methods (see the check after the cons list below).
- Works for any torch module and any modality (vision, text, multi-modal).

Cons:
- You need to manually write the Hugging Face `from_pretrained` and `save_pretrained` utility methods if you want to easily save / load adapters from the Hugging Face Hub.
- You cannot use any of the utility methods provided by `PeftModel`, such as disabling adapters, merging adapters, etc.
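
To illustrate the first pro, a quick check (reusing `DummyModel` and `lora_config` from the example above) that injection preserves the model's class:

```python
model = inject_adapter_in_model(lora_config, DummyModel())
assert isinstance(model, DummyModel)  # original class, attributes and methods are preserved
```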

src/peft/mapping.py (+3 −3)
@@ -106,7 +106,7 @@ def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name
         return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)


-def inject_adapter_in_model(peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str):
+def inject_adapter_in_model(peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default"):
     r"""
     A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning
     methods and adaption prompt. Make sure to have the correct `target_names` set in the `peft_config` object. The API
@@ -117,8 +117,8 @@ def inject_adapter_in_model(peft_config: PeftConfig, model: torch.nn.Module, ada
             Configuration object containing the parameters of the Peft model.
         model (`torch.nn.Module`):
             The input model where the adapter will be injected.
-        adapter_name (`str`):
-            The name of the adapter to be injected.
+        adapter_name (`str`, `optional`, defaults to `"default"`):
+            The name of the adapter to be injected; if not provided, the default adapter name ("default") is used.
     """
     if peft_config.is_prompt_learning or peft_config.is_adaption_prompt:
         raise ValueError("`create_and_replace` does not support prompt learning and adaption prompt yet.")

tests/test_low_level_api.py (new file, +65)
@@ -0,0 +1,65 @@
#!/usr/bin/env python3

# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import torch

from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 10)
        self.linear = torch.nn.Linear(10, 10)
        self.lm_head = torch.nn.Linear(10, 10)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        x = self.linear(x)
        x = self.lm_head(x)
        return x


class TestPeft(unittest.TestCase):
    def setUp(self):
        self.model = DummyModel()

        lora_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=64,
            bias="none",
            target_modules=["linear"],
        )

        self.model = inject_adapter_in_model(lora_config, self.model)

    def test_inject_adapter_in_model(self):
        dummy_inputs = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]])
        _ = self.model(dummy_inputs)

        for name, module in self.model.named_modules():
            if name == "linear":
                self.assertTrue(hasattr(module, "lora_A"))
                self.assertTrue(hasattr(module, "lora_B"))

    def test_get_peft_model_state_dict(self):
        peft_state_dict = get_peft_model_state_dict(self.model)

        for key in peft_state_dict.keys():
            self.assertTrue("lora" in key)
