
Commit 29ee66d

PEFT compatible GEMM (#324)
1 parent ebe8fc3 commit 29ee66d
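
Adds a `WQLinearMMFunction` autograd `Function` so that `WQLinear_GEMM.forward` no longer runs under `torch.no_grad()`: in training mode the forward dispatches through the Function and its backward dequantizes the packed weights and returns a gradient for the activations only, so the quantized weights stay frozen while gradients can reach trainable modules such as LoRA adapters. Also adds `examples/awq_train.py`, a LoRA fine-tuning example using PEFT and the Hugging Face `Trainer`.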

File tree

2 files changed: +181 -27

awq/modules/linear/gemm.py
examples/awq_train.py


awq/modules/linear/gemm.py

+105 -27
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+from torch.autograd import Function
 from awq.utils.utils import get_best_device
 from awq.utils.packing_utils import dequantize_gemm
 
@@ -10,9 +11,94 @@
 except:
     AWQ_INSTALLED = False
 
+# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
+class WQLinearMMFunction(Function):
+    @staticmethod
+    # ctx is the first argument to forward
+    def forward(
+        ctx,
+        x,
+        qweight,
+        qzeros,
+        scales,
+        w_bit=4,
+        group_size=128,
+        bias=None,
+        out_features=0
+    ):
+        # The forward pass can use ctx.
+        ctx.save_for_backward(x, qweight, qzeros, scales, bias)
+        ctx.out_features = out_features
+
+        out_shape = x.shape[:-1] + (out_features, )
+        x = x.to(torch.float16)
+
+        if AWQ_INSTALLED:
+            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024
+
+            if FP16_MATMUL_HEURISTIC_CONDITION:
+                out = awq_ext.dequantize_weights_cuda(
+                    qweight,
+                    scales,
+                    qzeros,
+                    0,
+                    0,
+                    0,
+                    False
+                )
+                out = torch.matmul(x, out)
+            else:
+                out = awq_ext.gemm_forward_cuda(
+                    x.reshape(-1, x.shape[-1]),
+                    qweight,
+                    scales,
+                    qzeros,
+                    8
+                )
+        else:
+            out = dequantize_gemm(
+                qweight,
+                qzeros,
+                scales,
+                w_bit,
+                group_size
+            )
+            out = torch.matmul(x, out)
+
+        out = out + bias if bias is not None else out
+        out = out.reshape(out_shape)
+
+        # always want 3D tensor if tensor is 2D
+        if len(out.shape) == 2:
+            out = out.unsqueeze(0)
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, qweight, qzeros, scales, bias = ctx.saved_tensors
+
+        weights = awq_ext.dequantize_weights_cuda(
+            qweight,
+            scales,
+            qzeros,
+            1,
+            0,
+            0,
+            False
+        )
+
+        if ctx.needs_input_grad[0]:
+            # 2D matrix multiplication, unsqueeze to 3D
+            grad_input = grad_output.squeeze(0).mm(
+                weights.transpose(0, 1)
+            ).unsqueeze(0)
+
+        return grad_input, None, None, None, None, None, None, None
+
 
 class WQLinear_GEMM(nn.Module):
-    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
+    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False):
         super().__init__()
 
         if w_bit not in [4]:
@@ -22,6 +108,7 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         self.out_features = out_features
         self.w_bit = w_bit
         self.group_size = group_size if group_size != -1 else in_features
+        self.training = training
 
         # quick sanity check (make sure alignment)
         assert self.in_features % self.group_size == 0
@@ -145,45 +232,36 @@ def from_linear(
 
         return awq_linear
 
-    @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features,)
 
         input_dtype = x.dtype
         if input_dtype != torch.float16:
             x = x.half()
 
-        if AWQ_INSTALLED:
-            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
-
-            if FP16_MATMUL_HEURISTIC_CONDITION:
-                out = awq_ext.dequantize_weights_cuda(
-                    self.qweight,
-                    self.scales,
-                    self.qzeros,
-                    0,
-                    0,
-                    0,
-                    False,
-                )
-                out = torch.matmul(x, out)
-            else:
-                out = awq_ext.gemm_forward_cuda(
-                    x.reshape(-1, x.shape[-1]),
-                    self.qweight,
-                    self.scales,
-                    self.qzeros,
-                    8,
-                )
-        else:
-            out = dequantize_gemm(
+        if self.training:
+            out = WQLinearMMFunction.apply(
+                x,
                 self.qweight,
                 self.qzeros,
                 self.scales,
                 self.w_bit,
                 self.group_size,
+                self.bias,
+                self.out_features,
             )
-            out = torch.matmul(x, out)
+        else:
+            with torch.no_grad():
+                out = WQLinearMMFunction.apply(
+                    x,
+                    self.qweight,
+                    self.qzeros,
+                    self.scales,
+                    self.w_bit,
+                    self.group_size,
+                    self.bias,
+                    self.out_features,
+                )
 
         if input_dtype != torch.float16:
             out = out.to(dtype=input_dtype)
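
Not part of the commit: a minimal sketch of what the new autograd path makes possible. With `training=True` (the flag added above; `model.train()` toggles the same attribute), `forward` dispatches to `WQLinearMMFunction.apply`, whose `backward` dequantizes the packed weights and returns a gradient only for the input activations, so the quantized weights stay frozen while upstream trainable parameters (e.g. LoRA adapters) still receive gradients. The layer below is constructed directly with zero-filled buffers purely for illustration and assumes a CUDA device with the `awq_ext` kernels installed.

import torch
from awq.modules.linear.gemm import WQLinear_GEMM

# Illustrative layer; real layers are normally built via WQLinear_GEMM.from_linear
# or loaded through AutoAWQForCausalLM.from_quantized. The buffers are zeros here,
# so the numbers are meaningless -- the point is only that autograd runs end to end.
layer = WQLinear_GEMM(
    w_bit=4,
    group_size=128,
    in_features=128,
    out_features=256,
    bias=False,
    dev="cuda",
    training=True,  # new flag introduced by this commit
)

x = torch.randn(1, 8, 128, device="cuda", dtype=torch.float16, requires_grad=True)
out = layer(x)        # routes through WQLinearMMFunction.apply instead of a no_grad forward
out.sum().backward()  # backward() dequantizes qweight and produces a gradient for x only
print(x.grad.shape)   # torch.Size([1, 8, 128]); qweight/qzeros/scales receive no gradient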

examples/awq_train.py

+76
@@ -0,0 +1,76 @@
+import datasets
+from awq import AutoAWQForCausalLM
+from transformers import (
+    AutoTokenizer,
+    TrainingArguments,
+    Trainer,
+    DataCollatorForLanguageModeling
+)
+from peft import get_peft_model, LoraConfig, TaskType
+
+def prepare_split(tokenizer):
+    data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train")
+    prompt_template = "<s>[INST] {system} {prompt} [/INST] {output}</s>"
+
+    def format_prompt(x):
+        return prompt_template.format(
+            system="",
+            prompt=x["instruction"],
+            output=x["output"]
+        )
+
+    data = data.map(
+        lambda x: {"text": format_prompt(x)},
+    ).select_columns(["text"])
+    data = data.map(lambda x: tokenizer(x["text"]), batched=True)
+
+    return data
+
+model_path = "ybelkada/opt-125m-awq"
+
+# Load model
+model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+tokenizer.pad_token = tokenizer.eos_token
+
+# Prepare data
+data_train = prepare_split(tokenizer)
+
+# Config Lora
+lora_config = LoraConfig(
+    r=4,
+    lora_alpha=8,
+    lora_dropout=0.5,
+    bias="none",
+    task_type=TaskType.CAUSAL_LM,
+    inference_mode=False
+)
+
+model = get_peft_model(model.model, lora_config)
+
+model.print_trainable_parameters()
+
+training_arguments = TrainingArguments(
+    output_dir="./output",
+    per_device_train_batch_size=1,
+    optim="adamw_torch",
+    num_train_epochs=1,
+    learning_rate=1e-4,
+    # fp16=True,
+    evaluation_strategy="no",
+    save_strategy="epoch",
+    save_steps=100,
+    logging_steps=50,
+    eval_steps=None,
+    load_best_model_at_end=False
+)
+
+trainer = Trainer(
+    model=model,
+    train_dataset=data_train,
+    args=training_arguments,
+    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+)
+
+trainer.train()
+trainer.save_model("output")
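
Likewise not part of the commit: a rough sketch of consuming the adapter that `trainer.save_model("output")` writes, using the standard `peft` loading path. Model IDs, paths, and the prompt are illustrative.

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from peft import PeftModel

# Reload the quantized base model and attach the LoRA weights saved to "output" above.
base = AutoAWQForCausalLM.from_quantized("ybelkada/opt-125m-awq", fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained("ybelkada/opt-125m-awq")
model = PeftModel.from_pretrained(base.model, "output")
model.eval()

# Prompt roughly follows the template used in prepare_split above.
prompt = "<s>[INST]  Give three tips for staying healthy. [/INST]"
device = next(model.parameters()).device
inputs = tokenizer(prompt, return_tensors="pt").to(device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0]))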
