
Commit ae8322d

Merge remote-tracking branch 'turboderp/master'
2 parents: 8eb2694 + ee0e84b


96 files changed: +7444 −1595 lines

README.md (+55 −19)

@@ -3,17 +3,57 @@
 ExLlamaV2 is an inference library for running local LLMs on modern consumer GPUs.
 
 
-## Overview of differences compared to V1
+## New in v0.1.0:
+
+- ExLlamaV2 now supports paged attention via [Flash Attention](https://github.com/Dao-AILab/flash-attention) 2.5.7+
+- New generator with dynamic batching, smart prompt caching, K/V cache deduplication and simplified API
+
+![alt_text](doc/dynamic_gen.gif)
+
+## Dynamic generator examples
+
+The dynamic generator supports all inference, sampling and speculative decoding features of the previous two
+generators, consolidated into one API (with the exception of FP8 cache, though the Q4 cache mode is supported and
+performs better anyway, see [here](doc/qcache_eval.md).)
+
+- Single generation:
+```python
+output = generator.generate(prompt = "Hello, my name is", max_new_tokens = 200)
+```
+- Batched generation:
+```python
+outputs = generator.generate(
+    prompt = [
+        "Hello, my name is",
+        "Once upon a time,",
+        "Large language models are",
+    ],
+    max_new_tokens = 200
+)
+```
+- Streamed generation with `asyncio`:
+```python
+job = ExLlamaV2DynamicJobAsync(
+    generator,
+    input_ids = tokenizer.encode("You can lead a horse to water"),
+    banned_strings = ["make it drink"],
+    gen_settings = ExLlamaV2Sampler.Settings.greedy(),
+    max_new_tokens = 200
+)
+async for result in job:
+    text = result.get("text", "")
+    print(text, end = "")
+```
+See the full, updated examples [here](https://github.com/turboderp/exllamav2/tree/master/examples).
+
 
-- Faster, better kernels
-- Cleaner and more versatile codebase
-- Support for a new quant format (see below)
 
 
 ## Performance
 
-Some quick tests to compare performance with V1. There may be more performance optimizations in the future, and
-speeds will vary across GPUs, with slow CPUs still being a potential bottleneck:
+Some quick tests to compare performance with ExLlama V1. There may be more performance optimizations in the future,
+and speeds will vary across GPUs, with slow CPUs still being a potential bottleneck:
 
 | Model      | Mode         | Size  | grpsz | act | 3090Ti  | 4090        |
 |------------|--------------|-------|-------|-----|---------|-------------|
@@ -33,13 +73,11 @@ speeds will vary across GPUs, with slow CPUs still being a potential bottleneck:
 ## How to
 
 To install from the repo you'll need the CUDA Toolkit and either gcc on Linux or (Build Tools for) Visual Studio
-on Windows). Also make sure you have an appropriate version of [PyTorch](https://pytorch.org/get-started/locally/),
-then run:
+on Windows). Also make sure you have an appropriate version of [PyTorch](https://pytorch.org/get-started/locally/), then run:
 
 ```
 git clone https://github.com/turboderp/exllamav2
 cd exllamav2
-# Optionally, create and activate a new conda environment
 pip install -r requirements.txt
 pip install .
 
@@ -50,13 +88,11 @@ python test_inference.py -m <path_to_model> -p "Once upon a time,"
 A simple console chatbot is included. Run it with:
 
 ```
-python examples/chat.py -m <path_to_model> -mode llama
-# Append the '--gpu_split auto' flag for multi-GPU inference
+python examples/chat.py -m <path_to_model> -mode llama -gs auto
 ```
 
 
-The `-mode` argument chooses the prompt format to use. `llama` is for the Llama(2)-chat finetunes, while `codellama`
-probably works better for CodeLlama-instruct. `raw` will produce a simple chatlog-style chat that works with base
+The `-mode` argument chooses the prompt format to use. `raw` will produce a simple chatlog-style chat that works with base
 models and various other finetunes. Run with `-modes` for a list of all available prompt formats. You can also provide
 a custom system prompt with `-sp`.
 
@@ -100,8 +136,11 @@ C++ extension in the process. Instead, the extension will be built the first tim
 
 ### Method 2: Install from release (with prebuilt extension)
 
-Releases are available [here](https://github.com/turboderp/exllamav2/releases), with prebuilt wheels that contain the
-extension binaries. Make sure to grab the right version, matching your platform, Python version (`cp`) and CUDA version.
+Releases are available [here](https://github.com/turboderp/exllamav2/releases), with prebuilt wheels that contain the extension binaries. Make sure to grab
+the right version, matching your platform, Python version (`cp`) and CUDA version. Crucially, you must also match
+the prebuilt wheel with your PyTorch version, since the Torch C++ extension ABI breaks with every new version of
+PyTorch.
+
 Either download an appropriate wheel or install directly from the appropriate URL:
 
 ```
@@ -113,15 +152,12 @@ can also be installed this way, and it will build the extension while installing
 
 ### Method 3: Install from PyPI
 
-A PyPI package is available as well. It can be installed with:
+A PyPI package is available as well. This is the same as the JIT version (see above). It can be installed with:
 
 ```
 pip install exllamav2
 ```
 
-The version available through PyPI is the JIT version (see above). Still working on a solution for distributing
-prebuilt wheels via PyPI.
-
 
 ## EXL2 quantization
 
conversion/adaptivegptq.py (−1)

@@ -631,7 +631,6 @@ def pack(self, key, qparams):
 
         qst_packed = torch.zeros((qst.shape[0], qst.shape[1] * qparams.scale_bits // 32), dtype = torch.int32, device = self.device)
         if qparams.scale_bits == 4: ext_c.pack_rows_4(qst, qst_packed)
-        # if qparams.scale_bits == 6: ext_c.pack_rows_6(qst, qst_packed) # TODO:
         output[key + ".q_scale"] = qst_packed
 
         qwt_packed = []

conversion/bot_status.py (+17)

@@ -0,0 +1,17 @@
+
+import json
+
+def print_stage(
+    job: dict,
+    stage: str,
+    progress: int,
+    max_progress: int,
+):
+    if not job["status_output"]: return
+
+    status = {
+        "stage": stage,
+        "completion": round(progress / max_progress, 4)
+    }
+
+    print("[STATUS]" + json.dumps(status) + "[/STATUS]")
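
Editor's note: the new `print_stage` helper wraps a small JSON payload in `[STATUS]...[/STATUS]` markers on stdout so that a supervising process can track conversion progress. A minimal, hypothetical consumer (not part of this commit; the `convert.py` invocation below is a placeholder) could parse those markers like this:

```python
# Hypothetical consumer sketch: parse [STATUS]...[/STATUS] lines emitted by print_stage.
import json, re, subprocess

STATUS_RE = re.compile(r"\[STATUS\](\{.*?\})\[/STATUS\]")

# Placeholder command: however convert.py is launched with status output enabled.
cmd = ["python", "convert.py", "-i", "/path/to/model", "-o", "/path/to/workdir"]
proc = subprocess.Popen(cmd, stdout = subprocess.PIPE, text = True)

for line in proc.stdout:
    m = STATUS_RE.search(line)
    if m:
        status = json.loads(m.group(1))
        # e.g. {"stage": "Measuring", "completion": 0.3125}
        print(f"{status['stage']}: {status['completion'] * 100:.1f}%")
```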

conversion/compile.py (+10)

@@ -1,6 +1,7 @@
 from exllamav2.model import \
 (
     ExLlamaV2Embedding,
+    ExLlamaV2PosEmbedding,
     ExLlamaV2Attention,
     ExLlamaV2MLP,
     ExLlamaV2MoEMLP,
@@ -16,6 +17,7 @@
 import os, glob, shutil, json
 from safetensors import safe_open
 from safetensors.torch import save_file
+from conversion.bot_status import print_stage
 
 def _tsize(t):
 
@@ -69,6 +71,10 @@ def compile_model(job, save_fn, model):
 
             d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)
 
+        if isinstance(module, ExLlamaV2PosEmbedding):
+
+            d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)
+
         if isinstance(module, ExLlamaV2Attention):
 
             d = get_f_module(job, module.input_layernorm); out_dict.update(d); current_size += _dsize(d)
@@ -126,6 +132,8 @@ def compile_model(job, save_fn, model):
 
         if current_size > shard_bytes or index == len(model.modules):
 
+            print_stage(job, "Compiling", index, len(model.modules))
+
             save_dict = {}
             dont_save_dict = {}
             this_shard_size = 0
@@ -237,3 +245,5 @@ def compile_model(job, save_fn, model):
 
     with open(config_json, "w") as f:
         f.write(json.dumps(config_dict, indent = 4))
+
+    print_stage(job, "Compiling", len(model.modules), len(model.modules))

conversion/measure.py (+30 −12)

@@ -1,6 +1,7 @@
 from exllamav2.model import \
 (
     ExLlamaV2Embedding,
+    ExLlamaV2PosEmbedding,
     ExLlamaV2Attention,
     ExLlamaV2MLP,
     ExLlamaV2MoEMLP,
@@ -19,6 +20,7 @@
 import os, time, math, json
 import torch.nn.functional as F
 import gc
+from conversion.bot_status import print_stage
 
 # graceful exiting
 import signal
@@ -68,6 +70,8 @@ def list_live_tensors():
 
 def embeddings(job, save_fn, model, measure = False):
 
+    print_stage(job, "Embeddings", 0, 1)
+
     module = model.modules[0]
     assert isinstance(module, ExLlamaV2Embedding)
 
@@ -82,6 +86,8 @@ def embeddings(job, save_fn, model, measure = False):
     embeddings_dict = { f"row.{i:05}": hidden_state[i:i+1, :, :].contiguous() for i in range(hidden_state.shape[0]) }
     save_file(embeddings_dict, os.path.join(job["out_dir"], "hidden_states.safetensors"))
 
+    print_stage(job, "Embeddings", 1, 1)
+
 
 # Test quantization options
 
@@ -119,18 +125,18 @@ def test_quant(source: ExLlamaV2Linear,
 
 def test_error(module, hidden_states, target_states, cache, attn_params):
 
-    rfn_sum = 0
+    rfn_sum = torch.tensor(0.0).cuda()
     rfn_count = 0
     for x, xref in zip(hidden_states, target_states):
         x = x.cuda()
         xref = xref.cuda()
         xtest = module.forward(x, cache, attn_params)
         xtest = xtest[0].float()
         xref = xref[0].float()
-        rfn_sum += (torch.linalg.norm(xtest - xref, 'fro') / torch.linalg.norm(xref, 'fro')).item()
+        rfn_sum += torch.linalg.norm(xtest - xref, 'fro') / torch.linalg.norm(xref, 'fro')
         rfn_count += 1
 
-    return max(1e-6, 1 - (rfn_sum / rfn_count))
+    return max(1e-6, 1 - (rfn_sum.item() / rfn_count))
 
 
 def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params, keep_q = False):
@@ -376,7 +382,7 @@ def print_status_box(*content_lines):
     print('-' * box_width)
 
 @torch.inference_mode()
-def measure_quant(job, save_fn, model):
+def measure_quant(job, save_fn, model, hidden_state_offload_layers):
 
     # vars for status box
     time_spent_list = []
@@ -412,12 +418,15 @@ def measure_quant(job, save_fn, model):
 
     hidden_states = []
    with safe_open(states_filename, framework = "pt", device = "cpu") as f:
-        for k in sorted(f.keys()):
-            hidden_states.append(f.get_tensor(k))
+        for i, k in enumerate(sorted(f.keys())):
+            t = f.get_tensor(k)
+            hidden_states.append(t.to("cuda:0") if i < hidden_state_offload_layers else t)
 
     index = job["last_module_idx"]
     while True:
 
+        print_stage(job, "Measuring", index, len(model.modules))
+
         # sig handler should catch it faster in most cases
         if interrupted:
             print("Measurement process was interrupted. Please decide:")
@@ -487,6 +496,9 @@ def measure_quant(job, save_fn, model):
         elif isinstance(module, ExLlamaV2RMSNorm) or isinstance(module, ExLlamaV2LayerNorm):
             mode = "norm"
 
+        elif isinstance(module, ExLlamaV2PosEmbedding):
+            mode = "pos_emb"
+
         # Reference forward pass
 
         cache = None
@@ -504,18 +516,19 @@
 
             x = hidden_states[i].to("cuda:0")
             outputs = module.forward(x, cache, attn_params, intermediates = True)
+            target_device = "cuda:0" if i < hidden_state_offload_layers else "cpu"
 
             # Hessians
 
             if mode == "self_attn":
                 quantizers["q_proj"].add_batch(outputs["post_norm"]) # Reuse H for K and V
                 quantizers["o_proj"].add_batch(outputs["attn_output"])
-                target_states.append(outputs["hidden_states"].to("cpu"))
+                target_states.append(outputs["hidden_states"].to(target_device))
 
             if mode == "mlp":
                 quantizers["up_proj"].add_batch(outputs["post_norm"]) # Reuse H for gate_proj
                 quantizers["down_proj"].add_batch(outputs["pre_down"])
-                target_states.append(outputs["hidden_states"].to("cpu"))
+                target_states.append(outputs["hidden_states"].to(target_device))
 
             if mode == "block_sparse_moe":
                 for j in range(model.config.num_experts):
@@ -526,16 +539,19 @@
                             uncalibrated_experts[j] += 1
                     else:
                         uncalibrated_experts[j] += 1
-                target_states.append(outputs["hidden_states"].to("cpu"))
+                target_states.append(outputs["hidden_states"].to(target_device))
 
             if mode == "parallel_decoder":
                 quantizers["q_proj"].add_batch(outputs["post_norm"]) # Reuse H for K, V, up_proj and gate_proj
                 quantizers["o_proj"].add_batch(outputs["attn_output"])
                 quantizers["down_proj"].add_batch(outputs["pre_down"])
                 hidden_states[i] = outputs["post_norm"]
-                target_states_attn.append(outputs["hidden_states_attn"].to("cpu"))
-                target_states_mlp.append(outputs["hidden_states_mlp"].to("cpu"))
-                target_states.append(outputs["hidden_states"].to("cpu"))
+                target_states_attn.append(outputs["hidden_states_attn"].to(target_device))
+                target_states_mlp.append(outputs["hidden_states_mlp"].to(target_device))
+                target_states.append(outputs["hidden_states"].to(target_device))
+
+            if mode == "pos_emb":
+                target_states.append(outputs["hidden_states"].to(target_device))
 
             # For MoE layers, warn if any layer received less than 10% of a calibration batch
 
@@ -647,6 +663,8 @@
 
             last_snapshot_time = time.time()
 
+    print_stage(job, "Measuring", len(model.modules), len(model.modules))
+
     # Export measurement
 
     exp_measurement = { "measurement": job["measurement"],
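
Editor's note: two of the measure.py changes above share one goal, cutting host-device traffic during measurement. `test_error` now accumulates the relative Frobenius-norm error in a GPU tensor and synchronizes once via `.item()` at the end instead of once per layer, and the new `hidden_state_offload_layers` argument keeps the first N hidden states resident on `cuda:0` rather than round-tripping them through system RAM. A standalone, illustrative sketch of the deferred-sync accumulation pattern (not the project's code) is shown below:

```python
# Illustrative sketch of the deferred-sync accumulation used in test_error:
# summing on the GPU and calling .item() once avoids a host sync per iteration.
import torch

def mean_relative_error(pairs):
    # pairs: iterable of (test, reference) 2-D activation matrices already on the GPU
    err_sum = torch.tensor(0.0, device = "cuda")
    count = 0
    for xtest, xref in pairs:
        # relative Frobenius-norm error, kept on-device
        err_sum += torch.linalg.norm(xtest - xref, "fro") / torch.linalg.norm(xref, "fro")
        count += 1
    return err_sum.item() / count  # single device-to-host sync here
```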
