
Commit 3b06cb4

Merge pull request oobabooga#6421 from oobabooga/dev
Merge dev branch
2 parents f98431c + cca9d6e commit 3b06cb4

31 files changed: +467 −307 lines

README.md (+22 -26)

@@ -10,33 +10,29 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.

 ## Features

-* Multiple backends for text generation in a single UI and API, including [Transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp) (through [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)), [ExLlamaV2](https://github.com/turboderp/exllamav2), [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). [AutoAWQ](https://github.com/casper-hansen/AutoAWQ), [HQQ](https://github.com/mobiusml/hqq), and [AQLM](https://github.com/Vahe1994/AQLM) are also supported through the Transformers loader.
-* OpenAI-compatible API server with Chat and Completions endpoints – see the [examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples).
-* Automatic prompt formatting for each model using the Jinja2 template in its metadata.
-* Three chat modes: `instruct`, `chat-instruct`, and `chat`, allowing for both instruction-following and casual conversations with characters. `chat-instruct` mode automatically applies the model's template to the chat prompt, ensuring high-quality outputs without manual setup.
-* "Past chats" menu to quickly switch between conversations and start new ones.
-* Free-form generation in the Default/Notebook tabs without being limited to chat turns. Send formatted chat conversations from the Chat tab to these tabs.
-* Multiple sampling parameters and generation options for sophisticated text generation control.
-* Easy switching between different models through the UI without restarting, using the "Model" tab.
-* Simple LoRA fine-tuning tool to customize models with your data.
-* All in one folder. The requirements are installed in a self-contained `installer_files` folder that doesn't interfere with the system's environment.
-* Extensions support, including numerous built-in and user-contributed extensions. See [the wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [the extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details.
+- Supports multiple text generation backends in one UI/API, including [Transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), and [ExLlamaV2](https://github.com/turboderp/exllamav2). [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ), [AutoAWQ](https://github.com/casper-hansen/AutoAWQ), [HQQ](https://github.com/mobiusml/hqq), and [AQLM](https://github.com/Vahe1994/AQLM) are also supported but you need to install them manually.
+- OpenAI-compatible API with Chat and Completions endpoints – see [examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples).
+- Automatic prompt formatting using Jinja2 templates.
+- Three chat modes: `instruct`, `chat-instruct`, and `chat`, with automatic prompt templates in `chat-instruct`.
+- "Past chats" menu to quickly switch between conversations.
+- Free-form text generation in the Default/Notebook tabs without being limited to chat turns. You can send formatted conversations from the Chat tab to these.
+- Multiple sampling parameters and generation options for sophisticated text generation control.
+- Switch between different models easily in the UI without restarting.
+- Simple LoRA fine-tuning tool.
+- Requirements installed in a self-contained `installer_files` directory that doesn't interfere with the system environment.
+- Extension support, with numerous built-in and user-contributed extensions available. See the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details.

 ## How to install

-1) Clone or [download](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) the repository.
-2) Run the `start_linux.sh`, `start_windows.bat`, `start_macos.sh`, or `start_wsl.bat` script depending on your OS.
+1) Clone or [download the repository](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip).
+2) Run the script that matches your OS: `start_linux.sh`, `start_windows.bat`, `start_macos.sh`, or `start_wsl.bat`.
 3) Select your GPU vendor when asked.
 4) Once the installation ends, browse to `http://localhost:7860`.
 5) Have fun!

-To restart the web UI in the future, run the `start_` script again.
+To restart the web UI later, just run the same `start_` script. If you need to reinstall, delete the `installer_files` folder created during setup and run the script again.

-This script creates an `installer_files` folder where it sets up the project's requirements. If you need to reinstall the requirements, just delete that folder and start the web UI again.
-
-The script accepts command-line flags, such as `./start_linux.sh --help`. Alternatively, you can edit the `CMD_FLAGS.txt` file with a text editor and add your flags there, such as `--api` in case you need to use the API.
-
-To get updates in the future, run `update_wizard_linux.sh`, `update_wizard_windows.bat`, `update_wizard_macos.sh`, or `update_wizard_wsl.bat`.
+You can use command-line flags, like `./start_linux.sh --help`, or add them to `CMD_FLAGS.txt` (such as `--api` to enable API use). To update the project, run `update_wizard_linux.sh`, `update_wizard_windows.bat`, `update_wizard_macos.sh`, or `update_wizard_wsl.bat`.

 <details>
 <summary>
@@ -80,12 +76,12 @@ conda activate textgen

 | System | GPU | Command |
 |--------|---------|---------|
-| Linux/WSL | NVIDIA | `pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121` |
-| Linux/WSL | CPU only | `pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cpu` |
-| Linux | AMD | `pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/rocm5.6` |
-| MacOS + MPS | Any | `pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2` |
-| Windows | NVIDIA | `pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121` |
-| Windows | CPU only | `pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2` |
+| Linux/WSL | NVIDIA | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121` |
+| Linux/WSL | CPU only | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cpu` |
+| Linux | AMD | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/rocm6.1` |
+| MacOS + MPS | Any | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1` |
+| Windows | NVIDIA | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121` |
+| Windows | CPU only | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1` |

 The up-to-date commands can be found here: https://pytorch.org/get-started/locally/.

@@ -150,7 +146,7 @@ Then browse to
 1) For Kepler GPUs and older, you will need to install CUDA 11.8 instead of 12:

 ```
-pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118
+pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu118
 conda install -y -c "nvidia/label/cuda-11.8.0" cuda-runtime
 ```

docs/12 - OpenAI API.md (+25 -1)

@@ -19,7 +19,7 @@ Add `--api` to your command-line flags.

 ### Examples

-For the documentation with all the parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file.
+For the documentation with all the endpoints, parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file.

 The official examples in the [OpenAI documentation](https://platform.openai.com/docs/api-reference) should also work, and the same parameters apply (although the API here has more optional parameters).

@@ -114,6 +114,30 @@ curl -k http://127.0.0.1:5000/v1/internal/logits \
 }'
 ```

+#### List models
+
+```shell
+curl -k http://127.0.0.1:5000/v1/internal/model/list \
+  -H "Content-Type: application/json"
+```
+
+#### Load model
+
+```shell
+curl -k http://127.0.0.1:5000/v1/internal/model/load \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model_name": "model_name",
+    "args": {
+      "load_in_4bit": true,
+      "n_gpu_layers": 12
+    },
+    "settings": {
+      "instruction_template": "Alpaca"
+    }
+  }'
+```
+
 #### Python chat example

 ```python
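
The new "Load model" endpoint can also be called from Python. Below is a minimal sketch, assuming the third-party `requests` package is installed and the server is running locally with `--api` on the default port 5000; the URL and payload mirror the curl example above.

```python
import requests

# Mirrors the "Load model" curl example above; adjust model_name and args as needed.
url = "http://127.0.0.1:5000/v1/internal/model/load"
payload = {
    "model_name": "model_name",
    "args": {"load_in_4bit": True, "n_gpu_layers": 12},
    "settings": {"instruction_template": "Alpaca"},
}

response = requests.post(url, json=payload)
response.raise_for_status()  # raises if the server reported an error while loading
print(response.status_code)
```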

extensions/Training_PRO/script.py (+1 -1)

@@ -241,7 +241,7 @@ def ui():
             stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')

         with gr.Column():
-            max_length = gr.Slider(label='max_length', minimum=0, maximum=shared.settings['truncation_length_max'], value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
+            max_length = gr.Number(label='max_length', precision=0, step=256, value=0, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')

     with gr.Row():
         start_current_evaluation = gr.Button("Evaluate loaded model")

extensions/openai/completions.py (+3 -2)

@@ -154,8 +154,9 @@ def convert_history(history):
                 elif item['type'] == 'text' and isinstance(item['text'], str):
                     content = item['text']

-            if image_url and content:
+            if image_url:
                 new_history.append({"image_url": image_url, "role": "user"})
+            if content:
                 new_history.append({"content": content, "role": "user"})
         else:
             new_history.append(entry)
@@ -234,7 +235,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         raise InvalidRequestError(message="messages: missing content", param='messages')

     # Chat Completions
-    object_type = 'chat.completions' if not stream else 'chat.completions.chunk'
+    object_type = 'chat.completion' if not stream else 'chat.completion.chunk'
     created_time = int(time.time())
     cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
     resp_list = 'data' if is_legacy else 'choices'
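
For context on the `convert_history` change: with the old `if image_url and content` check, a user turn that contained only an image part (no text part) was dropped from the history. A multimodal message in the OpenAI-style format that this parser handles looks roughly like the sketch below (an illustrative shape, not copied from the project's docs). The second hunk simply renames the response object types to `chat.completion` / `chat.completion.chunk`, matching the names used by the OpenAI API.

```python
# Roughly the request shape convert_history() parses. After this change, the
# image part and the text part are appended to the internal history
# independently, so an image-only message is no longer discarded.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
            {"type": "text", "text": "What is in this picture?"},
        ],
    }
]
```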

extensions/openai/typing.py (+2)

@@ -37,6 +37,8 @@ class GenerationOptions(BaseModel):
     dry_base: float = 1.75
     dry_allowed_length: int = 2
     dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"'
+    xtc_threshold: float = 0.1
+    xtc_probability: float = 0
     truncation_length: int = 0
     max_tokens_second: int = 0
     prompt_lookup_num_tokens: int = 0
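
These two fields expose the XTC ("exclude top choices") sampler parameters over the API. As a rough sketch of the idea as it is usually described (not the web UI's actual implementation): on each step, with probability `xtc_probability`, all candidate tokens whose probability is at least `xtc_threshold` are removed except the least likely of them, which steers generation away from the most predictable continuations.

```python
import random

def xtc_filter(probs, xtc_threshold=0.1, xtc_probability=0.0):
    """Toy sketch of XTC ("exclude top choices") filtering.

    probs: dict mapping token -> probability.
    With probability xtc_probability, drop every token whose probability is
    >= xtc_threshold except the least likely of those, then renormalize.
    Illustration of the idea only, not the project's sampler code.
    """
    if random.random() >= xtc_probability:
        return probs  # sampler not triggered on this step

    top = [tok for tok, p in probs.items() if p >= xtc_threshold]
    if len(top) < 2:
        return probs  # nothing to exclude; keep at least one viable token

    # Keep only the least probable of the "top choices", drop the rest.
    keep = min(top, key=lambda tok: probs[tok])
    filtered = {tok: p for tok, p in probs.items() if tok not in top or tok == keep}

    total = sum(filtered.values())
    return {tok: p / total for tok, p in filtered.items()}

print(xtc_filter({"the": 0.5, "a": 0.3, "zebra": 0.2}, xtc_probability=1.0))
```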

js/main.js (+11)

@@ -600,4 +600,15 @@ headerBar.addEventListener("click", (e) => {
   }
 });

+//------------------------------------------------
+// Add a confirmation dialog when leaving the page
+// Useful to avoid data loss
+//------------------------------------------------
+window.addEventListener('beforeunload', function (event) {
+  // Cancel the event
+  event.preventDefault();
+  // Chrome requires returnValue to be set
+  event.returnValue = '';
+});
+
 moveToChatTab();

modules/LoRA.py (+3 -1)

@@ -1,7 +1,6 @@
 from pathlib import Path

 import torch
-from peft import PeftModel
 from transformers import is_torch_xpu_available

 import modules.shared as shared
@@ -85,6 +84,9 @@ def add_lora_autogptq(lora_names):


 def add_lora_transformers(lora_names):
+
+    from peft import PeftModel
+
     prior_set = set(shared.lora_names)
     added_set = set(lora_names) - prior_set
     removed_set = prior_set - set(lora_names)
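
Moving the `peft` import inside `add_lora_transformers` means `peft` is only loaded when a LoRA is actually applied to a Transformers model, so importing `modules.LoRA` no longer requires `peft` to be installed or pays its import cost up front. A generic sketch of the same deferred-import pattern (illustrative, not project code):

```python
def attach_lora(model, lora_dir):
    # Deferred import: the heavy optional dependency is only pulled in when
    # this code path actually runs, so module import stays cheap and does not
    # fail when peft is absent.
    from peft import PeftModel

    return PeftModel.from_pretrained(model, lora_dir)
```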

modules/chat.py (+24 -4)

@@ -1059,7 +1059,12 @@ def handle_start_new_chat_click(state):

     convert_to_markdown.cache_clear()

-    return [history, html, gr.update(choices=histories, value=histories[0][1])]
+    if len(histories) > 0:
+        past_chats_update = gr.update(choices=histories, value=histories[0][1])
+    else:
+        past_chats_update = gr.update(choices=histories)
+
+    return [history, html, past_chats_update]


 def handle_delete_chat_confirm_click(state):
@@ -1110,10 +1115,15 @@ def handle_upload_chat_history(load_chat_history, state):

     convert_to_markdown.cache_clear()

+    if len(histories) > 0:
+        past_chats_update = gr.update(choices=histories, value=histories[0][1])
+    else:
+        past_chats_update = gr.update(choices=histories)
+
     return [
         history,
         html,
-        gr.update(choices=histories, value=histories[0][1])
+        past_chats_update
     ]


@@ -1132,6 +1142,11 @@ def handle_character_menu_change(state):

     convert_to_markdown.cache_clear()

+    if len(histories) > 0:
+        past_chats_update = gr.update(choices=histories, value=histories[0][1])
+    else:
+        past_chats_update = gr.update(choices=histories)
+
     return [
         history,
         html,
@@ -1140,7 +1155,7 @@ def handle_character_menu_change(state):
         picture,
         greeting,
         context,
-        gr.update(choices=histories, value=histories[0][1]),
+        past_chats_update,
     ]


@@ -1151,12 +1166,17 @@ def handle_mode_change(state):

     convert_to_markdown.cache_clear()

+    if len(histories) > 0:
+        past_chats_update = gr.update(choices=histories, value=histories[0][1])
+    else:
+        past_chats_update = gr.update(choices=histories)
+
     return [
         history,
         html,
         gr.update(visible=state['mode'] != 'instruct'),
         gr.update(visible=state['mode'] == 'chat-instruct'),
-        gr.update(choices=histories, value=histories[0][1])
+        past_chats_update
     ]

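
The same empty-`histories` guard now appears in four handlers; it prevents an `IndexError` on `histories[0][1]` when no saved chats exist. A small helper could capture the pattern in one place (a hypothetical refactor sketch, not part of this commit; `make_past_chats_update` is an invented name):

```python
import gradio as gr


def make_past_chats_update(histories):
    """Build the dropdown update for the "Past chats" menu.

    Select the most recent chat when one exists; otherwise only refresh the
    (empty) list of choices so we never index into an empty list.
    """
    if len(histories) > 0:
        return gr.update(choices=histories, value=histories[0][1])

    return gr.update(choices=histories)
```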

modules/exllamav2.py (+18 -16)

@@ -7,6 +7,7 @@
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Config,
     ExLlamaV2Tokenizer
 )
@@ -18,14 +19,6 @@

 try:
     import flash_attn
-except ModuleNotFoundError:
-    logger.warning(
-        'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
-        'to be a lot higher than it could be.\n'
-        'Try installing flash-attention following the instructions here: '
-        'https://github.com/Dao-AILab/flash-attention#installation-and-features'
-    )
-    pass
 except Exception:
     logger.warning('Failed to load flash-attention due to the following error:\n')
     traceback.print_exc()
@@ -54,21 +47,30 @@ def from_pretrained(self, path_to_model):

         model = ExLlamaV2(config)

-        if not shared.args.autosplit:
-            split = None
-            if shared.args.gpu_split:
-                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+        split = None
+        if shared.args.gpu_split:
+            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]

+        if shared.args.enable_tp:
+            model.load_tp(split)
+        elif not shared.args.autosplit:
             model.load(split)

+        # Determine the correct cache type
         if shared.args.cache_8bit:
-            cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_8bit
         elif shared.args.cache_4bit:
-            cache = ExLlamaV2Cache_Q4(model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_Q4
         else:
-            cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache

-        if shared.args.autosplit:
+        # Use TP if specified
+        if shared.args.enable_tp:
+            cache = ExLlamaV2Cache_TP(model, base=cache_type)
+        else:
+            cache = cache_type(model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit and not shared.args.enable_tp:
             model.load_autosplit(cache)

         tokenizer = ExLlamaV2Tokenizer(config)
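
In the new flow, tensor parallelism takes precedence: with `enable_tp` set, weights are loaded via `load_tp(split)` and the chosen cache class is wrapped in `ExLlamaV2Cache_TP`, while `autosplit` keeps the lazy cache plus `load_autosplit` path and only applies when TP is off. A condensed sketch of that decision logic (illustrative only; `args` stands in for `shared.args`, and the returned strings are labels rather than real loader calls):

```python
def describe_exllamav2_load(args):
    """Summarize the loader/cache decisions made in the diff above."""
    split = [float(x) for x in args.gpu_split.split(",")] if args.gpu_split else None

    if args.enable_tp:
        load = ("load_tp", split)       # tensor-parallel load
    elif not args.autosplit:
        load = ("load", split)          # manual gpu-split or single-GPU load
    else:
        load = ("load_autosplit",)      # deferred until the lazy cache exists

    cache = "8bit" if args.cache_8bit else ("Q4" if args.cache_4bit else "FP16")
    wrapper = "ExLlamaV2Cache_TP" if args.enable_tp else "cache_type(lazy=autosplit)"
    return load, cache, wrapper
```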

modules/exllamav2_hf.py (+18 -16)

@@ -9,6 +9,7 @@
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Config
 )
 from torch.nn import CrossEntropyLoss
@@ -20,14 +21,6 @@

 try:
     import flash_attn
-except ModuleNotFoundError:
-    logger.warning(
-        'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
-        'to be a lot higher than it could be.\n'
-        'Try installing flash-attention following the instructions here: '
-        'https://github.com/Dao-AILab/flash-attention#installation-and-features'
-    )
-    pass
 except Exception:
     logger.warning('Failed to load flash-attention due to the following error:\n')
     traceback.print_exc()
@@ -42,21 +35,30 @@ def __init__(self, config: ExLlamaV2Config):

         self.ex_model = ExLlamaV2(config)

-        if not shared.args.autosplit:
-            split = None
-            if shared.args.gpu_split:
-                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+        split = None
+        if shared.args.gpu_split:
+            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]

+        if shared.args.enable_tp:
+            self.ex_model.load_tp(split)
+        elif not shared.args.autosplit:
             self.ex_model.load(split)

+        # Determine the correct cache type
         if shared.args.cache_8bit:
-            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_8bit
         elif shared.args.cache_4bit:
-            self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_Q4
         else:
-            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache

-        if shared.args.autosplit:
+        # Use TP if specified
+        if shared.args.enable_tp:
+            self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
+        else:
+            self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit and not shared.args.enable_tp:
             self.ex_model.load_autosplit(self.ex_cache)

         self.past_seq = None
