
Commit ed118b4

Merge pull request #416 from turboderp/dev
Merge dev branch

2 parents: dafb508 + b68c0bd
33 files changed: +12395, -189 lines

conversion/adaptivegptq.py (+5, -1)

@@ -69,10 +69,14 @@ def find_params(self, x):
         self.scale = qscale_tw * best_p
         self.qscale_max = qscale_max_t * best_p

+        # Make sure scales are rounded correctly for sanity test
+        prescale = torch.tensor([1 / 256], dtype = torch.half, device = self.scale.device)
+        self.scale = ((self.qscale * self.qscale).to(torch.half) * (self.qscale_max.half() * prescale)).float()
+

 class AdaptiveGPTQ:

-    percdamp: float = 0.07
+    percdamp: float = 0.12

     layer: nn.Linear
     device: torch.device
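
The added lines recompute `self.scale` through the same half-precision path that the stored tensors go through, so the conversion-time error measurement matches what the sanity test later reconstructs. A standalone sketch of the effect, using made-up stand-in values rather than the quantizer's actual attributes:

import torch

# Illustrative stand-ins for the quantizer's attributes (not the real values)
qscale = torch.tensor([3.0, 7.0, 12.0])      # integer-coded group scales
qscale_max = torch.tensor(0.02)              # per-tensor scale ceiling

# Scale as plain float32 math would produce it during conversion
scale_fp32 = (qscale * qscale) * (qscale_max / 256)

# Scale as it is actually reconstructed from the saved half-precision tensors
prescale = torch.tensor([1 / 256], dtype = torch.half)
scale_rounded = ((qscale * qscale).to(torch.half) * (qscale_max.half() * prescale)).float()

# The two differ by a few half-precision ULPs; quantizing against the rounded
# scale keeps the measured error consistent with what inference will see
print((scale_fp32 - scale_rounded).abs())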

conversion/measure.py (+1, -1)

@@ -387,7 +387,7 @@ def measure_quant(job, save_fn, model):
     overall_rolling_accuracy = 0

     last_snapshot_time = time.time()
-    snapshot_interval_s = 90
+    snapshot_interval_s = 180

     temp_filename = os.path.join(job["out_dir"], "hidden_states_temp.safetensors")
     states_filename = os.path.join(job["out_dir"], "hidden_states.safetensors")

conversion/optimize.py (+79, -155)

@@ -1,6 +1,8 @@
 from conversion.qparams import QParams
+from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
 import math
 import itertools
+import time

 def optimize(job, save_fn, model):

@@ -9,11 +11,19 @@ def optimize(job, save_fn, model):
     mlp_key_up = model.config.arch.mlp_key_up
     mlp_key_down = model.config.arch.mlp_key_down

-    error_norm = 2.4
-    max_step_size = 2
-    first_layer_bias = 10
-    bias_layers = 2
-    bias_iter = 10
+    norm_interval = (1.5, 3.5)
+    norm_2ndstage = 0.15
+    anneal_temp_max = 2
+    anneal_temp_min = 0.0001
+    anneal_cooling_factor = 0.995
+    anneal_iter = 1000
+    anneal_samples = 80
+    anneal_stages = 3
+
+    # max_step_size = 2
+    # first_layer_bias = 4
+    # bias_layers = 2
+    # bias_iter = 0

     key = "model.layers.0"
     key_q = key + ".self_attn.q_proj"
@@ -57,21 +67,14 @@ def optimize(job, save_fn, model):
     numel = sum(m.numel() for m in model.modules[1 : num_modules + 1])

     target_bpw = job["bits"]
-    weight_budget = numel * target_bpw
+    weight_budget = int(numel * target_bpw)

     # Compile options

     measurement = job["measurement"]
-
-    def fn(x, idx):
-        if idx < bias_layers:
-            return 1 - ((1 - x) ** error_norm) * first_layer_bias
-        else:
-            return 1 - ((1 - x) ** error_norm)
-
-    weights = []
-    values = []
+    slots = []
     params = []
+
     for i in range(num_layers):
         if model.config.arch.parallel_decoder_blocks:
             m1 = measurement["model.layers." + str(i) + ".parallel_decoder"]["attn"]
@@ -80,162 +83,83 @@ def fn(x, idx):
             m1 = measurement["model.layers." + str(i) + ".self_attn"]
             m2 = measurement["model.layers." + str(i) + "." + mlp_mode]
         for m in [m1, m2]:
-            v = [fn(e["accuracy"], i) for e in m]
-            w = [e["total_bits"] for e in m]
-            weights.append(w)
-            values.append(v)
-            params.append(m)
-
-    print(" -- Pruning...")
-
-    # Sort options by weight, eliminate strictly worse options
-
-    for i in range(num_layers * 2):
-        combined = sorted(zip(weights[i], values[i], params[i]))
-        w_, v_, p_ = zip(*combined)
-        w_ = list(w_)
-        v_ = list(v_)
-        p_ = list(p_)
-        j = 1
-        while j < len(v_):
-            if v_[j] <= v_[j - 1]:
-                w_.pop(j)
-                v_.pop(j)
-                p_.pop(j)
-            else:
-                j += 1
-        weights[i] = w_
-        values[i] = v_
-        params[i] = p_
-
-    # Quick and dirty iterative solver
-
-    print(" -- Solving...")
-
-    f_solution = [0] * num_layers * 2
-    weight = sum(weights[i][0] for i in range(num_layers * 2))
-    value = 1
-    for i in range(num_layers * 2): value *= values[i][0]
-
-    iteration = 0
-
-    while True:
-        min_idx = -1
-        min_value = float("inf")
-        iteration += 1
-        for i in range(bias_layers if iteration < bias_iter else num_layers * 2):
-            s = f_solution[i]
-            if values[i][s] < min_value:
-                if s < len(weights[i]) - 1:
-                    added_w = weights[i][s + 1] - weights[i][s]
-                    if added_w + weight <= weight_budget:
-                        min_idx = i
-                        min_value = values[i][s]
-        if min_idx == -1: break
-        s = f_solution[min_idx]
-        weight += weights[min_idx][s + 1] - weights[min_idx][s]
-        value *= values[min_idx][s + 1] / values[min_idx][s]
-        f_solution[min_idx] += 1
-
-    bpw = weight / numel
-    print(f" -- Score: {value:.8f} bpw: {bpw:.4f}")
-
-    def improve(solution, s_weight, hold = None):
-
-        if hold is None: hold = []
-        best_idx = -1
-        best_ratio = 0
-        best_add_w = 0
-        best_add_v = 0
-        for idx in range(num_layers * 2):
-            if idx in hold: continue
-
-            si = solution[idx]
-            if si == len(weights[idx]) - 1: continue
-
-            add_w = weights[idx][si + 1] - weights[idx][si]
-            if s_weight + add_w > weight_budget: continue
-
-            add_v = values[idx][si + 1] / values[idx][si]
-            ratio = add_v / add_w
-            if ratio > best_ratio:
-                best_ratio = ratio
-                best_idx = idx
-                best_add_w = add_w
-                best_add_v = add_v
-
-        return best_idx, best_add_w, best_add_v
-
-    # while True:
-    #     b_idx, b_add_w, b_add_v = improve(f_solution, weight)
-    #     if b_idx == -1:
-    #         break
-    #
-    #     f_solution[b_idx] += 1
-    #     weight += b_add_w
-    #     value += b_add_v
-    #
-    # bpw = weight / numel
-    # print(f" -- Score: {math.exp(value):.8f} bpw: {bpw:.4f}")
-
-    best_value = value
-    prev_best_value = value
-    step_size = 1
-
-    while True:
-
-        for i, j in itertools.permutations(range(num_layers * 2), 2):
-
-            t_solution = f_solution.copy()
-            t_solution[i] = max(t_solution[i] - step_size, 0)
-            t_solution[j] = max(t_solution[j] - step_size, 0)
-
-            t_weight = sum(weights[k][t_solution[k]] for k in range(num_layers * 2))
-            t_value = 1
-            for k in range(num_layers * 2): t_value *= values[k][t_solution[k]]
-
-            while True:
-                b_idx, b_add_w, b_add_v = improve(t_solution, t_weight, [i, j])
-                if b_idx == -1:
-                    break
-                t_solution[b_idx] += 1
-                t_weight += b_add_w
-                t_value *= b_add_v
-
-            if t_value > best_value:
-                f_solution = t_solution
-                best_value = t_value
-                break
-
-        if best_value == prev_best_value:
-            step_size += 1
-            if step_size > max_step_size: break
-            continue
-
-        bpw = t_weight / numel
-        print(f" -- Score: {best_value:.8f} bpw: {bpw:.4f}")
-        prev_best_value = best_value
+            slot = []
+            param = []
+            for opt in m:
+                o = (int(opt["total_bits"]), 1 - opt["accuracy"])
+                slot.append(o)
+                param.append(opt)
+            slots.append(slot)
+            params.append(param)
+
+    # Find some solutions
+
+    last_update = 0
+    m = float("inf")
+    p = float("inf")
+    for i in range(anneal_stages * anneal_samples):
+        if time.time() - last_update > 1 or i == anneal_samples - 1:
+            print(f" -- Optimizing: {i + 1:4}/{anneal_stages * anneal_samples:4}")
+            last_update = time.time()
+
+        if i < anneal_samples:
+            t = i / (anneal_samples - 1)
+            norm = (1 - t) * norm_interval[0] + t * norm_interval[1]
+
+        elif i < anneal_samples * 2:
+            if i == anneal_samples:
+                norm_a = bestnorm - norm_2ndstage / 2
+                norm_b = bestnorm + norm_2ndstage / 2
+            t = i / (anneal_samples - 1) - 1
+            norm = (1 - t) * norm_a + t * norm_b
+
+        else:
+            norm = bestnorm
+
+        s_, si_, p_, c_, m_ = ext_c.sim_anneal(slots,
+                                               weight_budget,
+                                               anneal_temp_max,
+                                               anneal_cooling_factor,
+                                               anneal_temp_min,
+                                               anneal_iter,
+                                               norm)
+
+        if i < anneal_samples * 2:
+            if m_ < m:
+                m = m_
+                bestnorm = norm
+        else:
+            if p_ < p:
+                s, si, p, m = s_, si_, p_, m_
+
+    solution_idx = si
+    print(f" -- max(err): {m:.6f}")
+    print(f" -- error_norm: {bestnorm:.6f}")
+

     # Save strategy

     print(" -- Quantization strategy:")

-    errp = 1
+    logerr = 0
+    maxerr = 0
     job["strategy"] = {}
     for layer_ in range(num_layers):

         k1 = "model.layers." + str(layer_) + ".self_attn"
         k2 = "model.layers." + str(layer_) + "." + mlp_mode
-        p1 = params[layer_ * 2][f_solution[layer_ * 2]]
-        p2 = params[layer_ * 2 + 1][f_solution[layer_ * 2 + 1]]
+        p1 = params[layer_ * 2][solution_idx[layer_ * 2]]
+        p2 = params[layer_ * 2 + 1][solution_idx[layer_ * 2 + 1]]

         for (k, p, n) in zip((k1, k2), (p1, p2), (numel_attn, numel_mlp)):
             job["strategy"][k] = p
             bpw = p["total_bits"] / n
             err = 1 - p["accuracy"]
             print(f" -- {k:50} {bpw:1.4f} bpw - exp. error: {err:1.8f}")
-            errp *= (1 - err)
+            logerr += math.log(err)
+            maxerr = max(err, maxerr)

-    print(f" -- Total exp. error: {1 - errp:1.12f}")
+    print(f" -- sum(log(err)): {logerr:.6f}")
+    print(f" -- max(err): {maxerr:.6f}")

     xx = 0
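
This change replaces the old greedy/permutation solver with simulated annealing over the per-layer quantization options, run in the C++ extension (`ext_c.sim_anneal`) and repeated over a sweep of error norms: the first `anneal_samples` runs sweep `norm` across `norm_interval`, the second stage narrows the sweep around the best norm found, and the third stage reruns at that norm and keeps the best solution. A rough Python sketch of the idea, not the extension's actual implementation; the cost function here (sum of per-slot error raised to `norm`, subject to the total bit budget) is an assumption about the objective:

import math, random

def sim_anneal_sketch(slots, weight_budget, temp_max, cooling, temp_min, iters, norm):
    # slots[i] is a list of (total_bits, err) options for one attn/mlp slot,
    # mirroring the (total_bits, 1 - accuracy) tuples built above
    solution = [0] * len(slots)                       # start from the first option everywhere
    weight = sum(s[0][0] for s in slots)
    cost = sum(s[0][1] ** norm for s in slots)

    temp = temp_max
    for _ in range(iters):
        i = random.randrange(len(slots))              # propose a different option for one slot
        j = random.randrange(len(slots[i]))
        old_bits, old_err = slots[i][solution[i]]
        new_bits, new_err = slots[i][j]
        if weight - old_bits + new_bits > weight_budget:
            continue                                  # stay within the global bit budget
        delta = new_err ** norm - old_err ** norm
        if delta < 0 or random.random() < math.exp(-delta / temp):
            solution[i] = j                           # Metropolis acceptance rule
            weight += new_bits - old_bits
            cost += delta
        temp = max(temp * cooling, temp_min)

    max_err = max(slots[k][solution[k]][1] for k in range(len(slots)))
    return solution, weight, cost, max_err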

conversion/quantize.py (+2, -2)

@@ -253,7 +253,7 @@ def quant_parallel_decoder(job, module, hidden_states, target_states, quantizers
 def quant(job, save_fn, model):

     last_snapshot_time = time.time()
-    snapshot_interval_s = 90
+    snapshot_interval_s = 180

     temp_filename = os.path.join(job["out_dir"], "hidden_states_temp.safetensors")
     states_filename = os.path.join(job["out_dir"], "hidden_states.safetensors")

@@ -526,4 +526,4 @@ def quant(job, save_fn, model):
             del job["invalid"]
             save_fn()

-            time_since_snapshot = time.time()
+            last_snapshot_time = time.time()
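
Both snapshot-interval changes (here and in conversion/measure.py above) and the final rename belong to the same time-based checkpointing pattern: job state is flushed to disk only when enough wall-clock time has passed. The old last line assigned the timestamp to the wrong variable, so the timer was never reset after a snapshot. A minimal sketch of the intended pattern (the `save_snapshot` function and the loop body are placeholders, not the converter's actual code):

import time

snapshot_interval_s = 180                     # flush at most once every three minutes
last_snapshot_time = time.time()

def save_snapshot():
    # placeholder for writing hidden_states_temp.safetensors and the job state
    pass

for work_item in range(1000):                 # placeholder for the per-module conversion loop
    time.sleep(0.01)                          # ... do some work ...

    time_since_snapshot = time.time() - last_snapshot_time
    if time_since_snapshot > snapshot_interval_s:
        save_snapshot()
        last_snapshot_time = time.time()      # reset the timer, not time_since_snapshot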

convert.py (+3)

@@ -8,6 +8,7 @@
 from conversion.optimize import optimize
 from conversion.compile import compile_model
 from conversion.qparams import qparams_headoptions
+import torch

 parser = argparse.ArgumentParser(description = "Convert model to ExLlamaV2")
 parser.add_argument("-i", "--in_dir", type = str, help = "Input directory", default = "")

@@ -29,6 +30,8 @@

 args = parser.parse_args()

+torch.set_printoptions(precision = 7, sci_mode = False, linewidth = 200)
+
 # Check some args

 if not args.in_dir:
doc/qcache_eval.md

+19-14
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,23 @@ The tl;dr:
1515
Token-level perplexity tests for various full-precision and quantized models using FP16, FP8 and Q4 cache
1616
modes. Dataset is The Pile, 10 rows of 512 tokens per test.
1717

18-
Model | Precision | FP16 cache | FP8 cache | Q4 cache
19-
--------|-----------|---------------|-----------|---------
20-
Mistral 7B Instruct | 3.0 bpw | 13.33 | 13.43 | 13.41
21-
-- | 3.5 bpw | 13.07 | 13.14 | 13.12
22-
-- | 4.0 bpw | 12.90 | 12.90 | 12.90
23-
-- | 5.0 bpw | 12.73 | 12.73 | 12.75
24-
-- | 6.0 bpw | 12.73 | 12.75 | 12.74
25-
-- | FP16 | 12.69 | 12.71 | 12.72
26-
Mixtral 8x7B | 3.5 bpw | 10.27 | 10.41 | 10.39
27-
-- | 4.0 bpw | 10.09 | 10.26 | 10.23
28-
-- | 5.0 bpw | 10.02 | 10.16 | 10.15
29-
Llama2 7B | 4.0 bpw | 11.43 | 11.92 | 11.74
30-
-- | 5.0 bpw | 11.13 | 11.40 | 11.31
31-
-- | FP16 | 10.91 | 11.24 | 11.16
18+
Results are updated for the new method which uses Hadamard rotations on the keys/values. Old results for version
19+
0.0.18 and prior kept for reference.
20+
21+
Model | Precision | FP16 cache | FP8 cache | Q4 cache (old) | Q4 cache
22+
--------|---------|-------------|-----------|-------|----------
23+
Mistral 7B Instruct | 3.0 bpw | **13.33** | 13.43 | 13.41 | **13.37**
24+
-- | 3.5 bpw | **13.07** | 13.14 | 13.12 | **13.09**
25+
-- | 4.0 bpw | **12.90** | 12.90 | 12.90 | **12.90**
26+
-- | 5.0 bpw | **12.73** | 12.73 | 12.75 | **12.75**
27+
-- | 6.0 bpw | **12.73** | 12.75 | 12.74 | **12.74**
28+
-- | FP16 | **12.69** | 12.71 | 12.72 | **12.69**
29+
Mixtral 8x7B | 3.5 bpw | **10.27** | 10.41 | 10.39 | **10.32**
30+
-- | 4.0 bpw | **10.09** | 10.26 | 10.23 | **10.19**
31+
-- | 5.0 bpw | **10.02** | 10.16 | 10.15 | **10.04**
32+
Llama2 7B | 4.0 bpw | **11.43** | 11.92 | 11.74 | **11.60**
33+
-- | 5.0 bpw | **11.13** | 11.40 | 11.31 | **11.19**
34+
-- | FP16 | **10.91** | 11.24 | 11.16 | **11.05**
3235

3336

3437
### HumanEval
@@ -37,6 +40,8 @@ The following are HumanEval tests on various full-precision and quantized models
3740
respectively. Number of samples per task is limited to 10 (still giving 39360 completions in total produced
3841
over about 24 hours.)
3942

43+
The following tests were done prior to the improvements in 0.0.18-dev.
44+
4045
#### pass@1
4146

4247
Model | Precision | FP16 cache | Q4 cache | diff
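
The doc update refers to the new Q4 cache mode, which rotates keys/values with a Hadamard transform before quantizing so that outlier channels are spread across the head dimension, which typically makes them easier to quantize. A small self-contained sketch of the principle (not exllamav2's actual cache kernels; the quantizer here is a crude per-row 4-bit round-trip):

import math
import torch

def hadamard(n):
    # Sylvester construction of an orthonormal Hadamard matrix; n must be a power of two
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat((torch.cat((h,  h), dim = 1),
                       torch.cat((h, -h), dim = 1)), dim = 0)
    return h / math.sqrt(n)

def q4_roundtrip(x):
    # Crude symmetric 4-bit quantization per row, just to show the effect of the rotation
    scale = x.abs().amax(dim = -1, keepdim = True).clamp(min = 1e-8) / 7
    return torch.clamp(torch.round(x / scale), -8, 7) * scale

head_dim = 128
keys = torch.randn(256, head_dim)         # stand-in for one head's cached keys
keys[:, 0] *= 20                          # a single outlier channel dominates the per-row range

h = hadamard(head_dim)
direct  = q4_roundtrip(keys)              # quantize the raw keys
rotated = q4_roundtrip(keys @ h) @ h.T    # rotate, quantize, rotate back

print("direct  MSE:", (direct  - keys).square().mean().item())
print("rotated MSE:", (rotated - keys).square().mean().item())

The rotation is orthogonal, so it changes nothing in exact arithmetic; it only redistributes the quantization error, which is why the rotated round-trip lands closer to the original keys.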
