 from conversion.qparams import QParams
+from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
 import math
 import itertools
+import time
 
 def optimize(job, save_fn, model):
 
@@ -9,11 +11,19 @@ def optimize(job, save_fn, model):
     mlp_key_up = model.config.arch.mlp_key_up
     mlp_key_down = model.config.arch.mlp_key_down
 
-    error_norm = 2.4
-    max_step_size = 2
-    first_layer_bias = 10
-    bias_layers = 2
-    bias_iter = 10
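+    # Anneal-based solver: stage 1 sweeps the error norm across norm_interval,
+    # stage 2 refines it within a window of width norm_2ndstage around the best
+    # norm found, and the remaining samples re-anneal at that norm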
+    norm_interval = (1.5, 3.5)
+    norm_2ndstage = 0.15
+    anneal_temp_max = 2
+    anneal_temp_min = 0.0001
+    anneal_cooling_factor = 0.995
+    anneal_iter = 1000
+    anneal_samples = 80
+    anneal_stages = 3
+
+    # max_step_size = 2
+    # first_layer_bias = 4
+    # bias_layers = 2
+    # bias_iter = 0
 
     key = "model.layers.0"
     key_q = key + ".self_attn.q_proj"
@@ -57,21 +67,14 @@ def optimize(job, save_fn, model):
     numel = sum(m.numel() for m in model.modules[1 : num_modules + 1])
 
     target_bpw = job["bits"]
-    weight_budget = numel * target_bpw
+    weight_budget = int(numel * target_bpw)
 
     # Compile options
 
     measurement = job["measurement"]
-
-    def fn(x, idx):
-        if idx < bias_layers:
-            return 1 - ((1 - x) ** error_norm) * first_layer_bias
-        else:
-            return 1 - ((1 - x) ** error_norm)
-
-    weights = []
-    values = []
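+    # Each module contributes one slot of (total_bits, error) options for the
+    # solver; params keeps the corresponding full measurement records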
+    slots = []
     params = []
+
     for i in range(num_layers):
         if model.config.arch.parallel_decoder_blocks:
             m1 = measurement["model.layers." + str(i) + ".parallel_decoder"]["attn"]
@@ -80,162 +83,83 @@ def fn(x, idx):
             m1 = measurement["model.layers." + str(i) + ".self_attn"]
             m2 = measurement["model.layers." + str(i) + "." + mlp_mode]
         for m in [m1, m2]:
-            v = [fn(e["accuracy"], i) for e in m]
-            w = [e["total_bits"] for e in m]
-            weights.append(w)
-            values.append(v)
-            params.append(m)
-
-    print(" -- Pruning...")
-
-    # Sort options by weight, eliminate strictly worse options
-
-    for i in range(num_layers * 2):
-        combined = sorted(zip(weights[i], values[i], params[i]))
-        w_, v_, p_ = zip(*combined)
-        w_ = list(w_)
-        v_ = list(v_)
-        p_ = list(p_)
-        j = 1
-        while j < len(v_):
-            if v_[j] <= v_[j - 1]:
-                w_.pop(j)
-                v_.pop(j)
-                p_.pop(j)
-            else:
-                j += 1
-        weights[i] = w_
-        values[i] = v_
-        params[i] = p_
-
-    # Quick and dirty iterative solver
-
-    print(" -- Solving...")
-
-    f_solution = [0] * num_layers * 2
-    weight = sum(weights[i][0] for i in range(num_layers * 2))
-    value = 1
-    for i in range(num_layers * 2): value *= values[i][0]
-
-    iteration = 0
-
-    while True:
-        min_idx = -1
-        min_value = float("inf")
-        iteration += 1
-        for i in range(bias_layers if iteration < bias_iter else num_layers * 2):
-            s = f_solution[i]
-            if values[i][s] < min_value:
-                if s < len(weights[i]) - 1:
-                    added_w = weights[i][s + 1] - weights[i][s]
-                    if added_w + weight <= weight_budget:
-                        min_idx = i
-                        min_value = values[i][s]
-        if min_idx == -1: break
-        s = f_solution[min_idx]
-        weight += weights[min_idx][s + 1] - weights[min_idx][s]
-        value *= values[min_idx][s + 1] / values[min_idx][s]
-        f_solution[min_idx] += 1
-
-    bpw = weight / numel
-    print(f" -- Score: {value:.8f} bpw: {bpw:.4f}")
-
-    def improve(solution, s_weight, hold = None):
-
-        if hold is None: hold = []
-        best_idx = -1
-        best_ratio = 0
-        best_add_w = 0
-        best_add_v = 0
-        for idx in range(num_layers * 2):
-            if idx in hold: continue
-
-            si = solution[idx]
-            if si == len(weights[idx]) - 1: continue
-
-            add_w = weights[idx][si + 1] - weights[idx][si]
-            if s_weight + add_w > weight_budget: continue
-
-            add_v = values[idx][si + 1] / values[idx][si]
-            ratio = add_v / add_w
-            if ratio > best_ratio:
-                best_ratio = ratio
-                best_idx = idx
-                best_add_w = add_w
-                best_add_v = add_v
-
-        return best_idx, best_add_w, best_add_v
-
-    # while True:
-    #     b_idx, b_add_w, b_add_v = improve(f_solution, weight)
-    #     if b_idx == -1:
-    #         break
-    #
-    #     f_solution[b_idx] += 1
-    #     weight += b_add_w
-    #     value += b_add_v
-    #
-    # bpw = weight / numel
-    # print(f" -- Score: {math.exp(value):.8f} bpw: {bpw:.4f}")
-
-    best_value = value
-    prev_best_value = value
-    step_size = 1
-
-    while True:
-
-        for i, j in itertools.permutations(range(num_layers * 2), 2):
-
-            t_solution = f_solution.copy()
-            t_solution[i] = max(t_solution[i] - step_size, 0)
-            t_solution[j] = max(t_solution[j] - step_size, 0)
-
-            t_weight = sum(weights[k][t_solution[k]] for k in range(num_layers * 2))
-            t_value = 1
-            for k in range(num_layers * 2): t_value *= values[k][t_solution[k]]
-
-            while True:
-                b_idx, b_add_w, b_add_v = improve(t_solution, t_weight, [i, j])
-                if b_idx == -1:
-                    break
-                t_solution[b_idx] += 1
-                t_weight += b_add_w
-                t_value *= b_add_v
-
-            if t_value > best_value:
-                f_solution = t_solution
-                best_value = t_value
-                break
-
-        if best_value == prev_best_value:
-            step_size += 1
-            if step_size > max_step_size: break
-            continue
-
-        bpw = t_weight / numel
-        print(f" -- Score: {best_value:.8f} bpw: {bpw:.4f}")
-        prev_best_value = best_value
+            slot = []
+            param = []
+            for opt in m:
+                o = (int(opt["total_bits"]), 1 - opt["accuracy"])
+                slot.append(o)
+                param.append(opt)
+            slots.append(slot)
+            params.append(param)
+
+    # Find some solutions
+
+    last_update = 0
+    m = float("inf")
+    p = float("inf")
+    for i in range(anneal_stages * anneal_samples):
+        if time.time() - last_update > 1 or i == anneal_samples - 1:
+            print(f" -- Optimizing: {i + 1:4}/{anneal_stages * anneal_samples:4}")
+            last_update = time.time()
+
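+        # Stage 1: sweep the error norm linearly across norm_interval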
+        if i < anneal_samples:
+            t = i / (anneal_samples - 1)
+            norm = (1 - t) * norm_interval[0] + t * norm_interval[1]
+
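+        # Stage 2: finer sweep of width norm_2ndstage centered on the best
+        # norm found in stage 1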
+        elif i < anneal_samples * 2:
+            if i == anneal_samples:
+                norm_a = bestnorm - norm_2ndstage / 2
+                norm_b = bestnorm + norm_2ndstage / 2
+            t = i / (anneal_samples - 1) - 1
+            norm = (1 - t) * norm_a + t * norm_b
+
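+        # Stage 3: anneal repeatedly at the best norm, keeping the cheapest solution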
+        else:
+            norm = bestnorm
+
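+        # Simulated annealing pass in the C++ extension; as used below, si_ is
+        # the chosen option index per slot, p_ the cost being minimized and m_
+        # the worst per-module error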
+        s_, si_, p_, c_, m_ = ext_c.sim_anneal(slots,
+                                               weight_budget,
+                                               anneal_temp_max,
+                                               anneal_cooling_factor,
+                                               anneal_temp_min,
+                                               anneal_iter,
+                                               norm)
+
+        if i < anneal_samples * 2:
+            if m_ < m:
+                m = m_
+                bestnorm = norm
+        else:
+            if p_ < p:
+                s, si, p, m = s_, si_, p_, m_
+
+    solution_idx = si
+    print(f" -- max(err): {m:.6f}")
+    print(f" -- error_norm: {bestnorm:.6f}")
+
 
     # Save strategy
 
     print(" -- Quantization strategy:")
 
-    errp = 1
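+    # Accumulate per-module error in log space and track the single worst module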
+    logerr = 0
+    maxerr = 0
     job["strategy"] = {}
     for layer_ in range(num_layers):
 
         k1 = "model.layers." + str(layer_) + ".self_attn"
         k2 = "model.layers." + str(layer_) + "." + mlp_mode
-        p1 = params[layer_ * 2][f_solution[layer_ * 2]]
-        p2 = params[layer_ * 2 + 1][f_solution[layer_ * 2 + 1]]
+        p1 = params[layer_ * 2][solution_idx[layer_ * 2]]
+        p2 = params[layer_ * 2 + 1][solution_idx[layer_ * 2 + 1]]
 
         for (k, p, n) in zip((k1, k2), (p1, p2), (numel_attn, numel_mlp)):
             job["strategy"][k] = p
             bpw = p["total_bits"] / n
             err = 1 - p["accuracy"]
             print(f" -- {k:50} {bpw:1.4f} bpw - exp. error: {err:1.8f}")
-            errp *= (1 - err)
+            logerr += math.log(err)
+            maxerr = max(err, maxerr)
 
-    print(f" -- Total exp. error: {1 - errp:1.12f}")
+    print(f" -- sum(log(err)): {logerr:.6f}")
+    print(f" -- max(err): {maxerr:.6f}")
 
     xx = 0