import torch
import torch.nn as nn
+ from torch.autograd import Function
from awq.utils.utils import get_best_device
from awq.utils.packing_utils import dequantize_gemm

try:
    import awq_ext  # with CUDA kernels

    AWQ_INSTALLED = True
except:
    AWQ_INSTALLED = False

+ # Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
+ class WQLinearMMFunction(Function):
+     @staticmethod
+     # ctx is the first argument to forward
+     def forward(
+         ctx,
+         x,
+         qweight,
+         qzeros,
+         scales,
+         w_bit=4,
+         group_size=128,
+         bias=None,
+         out_features=0
+     ):
+         # The forward pass can use ctx.
+         ctx.save_for_backward(x, qweight, qzeros, scales, bias)
+         ctx.out_features = out_features
+
+         out_shape = x.shape[:-1] + (out_features,)
+         x = x.to(torch.float16)
+
+         if AWQ_INSTALLED:
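+             # large inputs: dequantize the packed weights and use a plain fp16
+             # matmul; otherwise call the fused GEMM kernel on the packed weights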
+             FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
+
+             if FP16_MATMUL_HEURISTIC_CONDITION:
+                 out = awq_ext.dequantize_weights_cuda(
+                     qweight,
+                     scales,
+                     qzeros,
+                     0,
+                     0,
+                     0,
+                     False
+                 )
+                 out = torch.matmul(x, out)
+             else:
+                 out = awq_ext.gemm_forward_cuda(
+                     x.reshape(-1, x.shape[-1]),
+                     qweight,
+                     scales,
+                     qzeros,
+                     8
+                 )
+         else:
+             out = dequantize_gemm(
+                 qweight,
+                 qzeros,
+                 scales,
+                 w_bit,
+                 group_size
+             )
+             out = torch.matmul(x, out)
+
+         out = out + bias if bias is not None else out
+         out = out.reshape(out_shape)
+
+         # always want 3D tensor if tensor is 2D
+         if len(out.shape) == 2:
+             out = out.unsqueeze(0)
+
+         return out
+
+     @staticmethod
+     def backward(ctx, grad_output):
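+         # only the gradient w.r.t. the activation input is returned; the packed
+         # qweight, qzeros, scales and bias stay frozen and receive None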
+         input, qweight, qzeros, scales, bias = ctx.saved_tensors
+
+         weights = awq_ext.dequantize_weights_cuda(
+             qweight,
+             scales,
+             qzeros,
+             1,
+             0,
+             0,
+             False
+         )
+
+         if ctx.needs_input_grad[0]:
+             # 2D matrix multiplication, unsqueeze to 3D
+             grad_input = grad_output.squeeze(0).mm(
+                 weights.transpose(0, 1)
+             ).unsqueeze(0)
+
+         return grad_input, None, None, None, None, None, None, None
+

class WQLinear_GEMM(nn.Module):
-     def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
+     def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False):
        super().__init__()

        if w_bit not in [4]:
@@ -22,6 +108,7 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features
+         self.training = training

        # quick sanity check (make sure alignment)
        assert self.in_features % self.group_size == 0
@@ -145,45 +232,36 @@ def from_linear(

        return awq_linear

-     @torch.no_grad()
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features,)

        input_dtype = x.dtype
        if input_dtype != torch.float16:
            x = x.half()

-         if AWQ_INSTALLED:
-             FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
-
-             if FP16_MATMUL_HEURISTIC_CONDITION:
-                 out = awq_ext.dequantize_weights_cuda(
-                     self.qweight,
-                     self.scales,
-                     self.qzeros,
-                     0,
-                     0,
-                     0,
-                     False,
-                 )
-                 out = torch.matmul(x, out)
-             else:
-                 out = awq_ext.gemm_forward_cuda(
-                     x.reshape(-1, x.shape[-1]),
-                     self.qweight,
-                     self.scales,
-                     self.qzeros,
-                     8,
-                 )
-         else:
-             out = dequantize_gemm(
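+         # during training, route through the autograd Function so gradients can
+         # flow back to the activations; inference uses the same kernels under no_grad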
+         if self.training:
+             out = WQLinearMMFunction.apply(
+                 x,
                self.qweight,
                self.qzeros,
                self.scales,
                self.w_bit,
                self.group_size,
+                 self.bias,
+                 self.out_features,
            )
-             out = torch.matmul(x, out)
+         else:
+             with torch.no_grad():
+                 out = WQLinearMMFunction.apply(
+                     x,
+                     self.qweight,
+                     self.qzeros,
+                     self.scales,
+                     self.w_bit,
+                     self.group_size,
+                     self.bias,
+                     self.out_features,
+                 )

        if input_dtype != torch.float16:
            out = out.to(dtype=input_dtype)
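
A minimal sketch of how the new training path could be exercised end to end. It assumes AutoAWQ is installed together with its awq_ext CUDA extension, that WQLinear_GEMM is importable from awq.modules.linear.gemm, and that from_linear(..., init_only=True) is available to allocate an empty quantized layer; the layer sizes and input shape below are illustrative only.

    import torch
    import torch.nn as nn
    from awq.modules.linear.gemm import WQLinear_GEMM

    # allocate an empty 4-bit layer; init_only skips real quantization, so the
    # packed buffers are zero-filled, which is enough to exercise forward/backward
    linear = nn.Linear(4096, 4096, bias=False).half().cuda()
    qlinear = WQLinear_GEMM.from_linear(linear, w_bit=4, group_size=128, init_only=True)
    qlinear.train()  # sets self.training, so forward() builds the autograd graph

    x = torch.randn(1, 16, 4096, dtype=torch.float16, device="cuda", requires_grad=True)
    out = qlinear(x)        # shape (1, 16, 4096)
    out.sum().backward()    # backward() dequantizes the weights on the fly
    print(x.grad.shape)     # gradients reach the activations; the packed weights stay frozen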