
Commit d011040

Merge pull request oobabooga#6300 from oobabooga/dev
Merge dev branch
2 parents 498fec2 + 608545d · commit d011040

23 files changed (+123 −93 lines)

css/main.css (+2 −1)

@@ -404,13 +404,14 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
 .message-body h3,
 .message-body h4 {
     color: var(--body-text-color);
+    margin: 20px 0 10px 0;
 }

 .dark .message q {
     color: #f5b031;
 }

-.message q::before, .message q::after {
+.message-body q::before, .message-body q::after {
     content: "";
 }

download-model.py (+7 −3)

@@ -212,11 +212,15 @@ def get_single_file(self, url, output_folder, start_from_scratch=False):
         total_size = int(r.headers.get('content-length', 0))
         block_size = 1024 * 1024  # 1MB

+        filename_str = str(filename)  # Convert PosixPath to string if necessary
+
         tqdm_kwargs = {
             'total': total_size,
-            'unit': 'iB',
+            'unit': 'B',
             'unit_scale': True,
-            'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} {rate_fmt}'
+            'unit_divisor': 1024,
+            'bar_format': '{desc}{percentage:3.0f}%|{bar:50}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
+            'desc': f"{filename_str}: "
         }

         if 'COLAB_GPU' in os.environ:
@@ -233,7 +237,7 @@ def get_single_file(self, url, output_folder, start_from_scratch=False):
                             t.update(len(data))
                             if total_size != 0 and self.progress_bar is not None:
                                 count += len(data)
-                                self.progress_bar(float(count) / float(total_size), f"{filename}")
+                                self.progress_bar(float(count) / float(total_size), f"{filename_str}")

                 break  # Exit loop if successful
             except (RequestException, ConnectionError, Timeout) as e:
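For reference, a minimal standalone sketch of how the new tqdm options behave (the file name and sizes below are made up): with unit='B', unit_scale=True and unit_divisor=1024, byte counts are rendered with binary prefixes, and desc prefixes the bar with the file name instead of the old bare '{l_bar}' format.

from tqdm import tqdm

# Made-up values; only meant to show the effect of the new tqdm kwargs.
total_size = 150 * 1024 * 1024   # pretend content-length of 150 MiB
block_size = 1024 * 1024         # 1 MiB chunks, as in the script
filename_str = "model-00001-of-00002.safetensors"

tqdm_kwargs = {
    'total': total_size,
    'unit': 'B',
    'unit_scale': True,
    'unit_divisor': 1024,
    'bar_format': '{desc}{percentage:3.0f}%|{bar:50}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
    'desc': f"{filename_str}: "
}

with tqdm(**tqdm_kwargs) as t:
    for _ in range(total_size // block_size):
        t.update(block_size)  # ends with a bar like: <filename>: 100%| ... | 150M/150M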

extensions/openai/completions.py (+3 −4)

@@ -319,7 +319,6 @@ def chat_streaming_chunk(content):
         yield {'prompt': prompt}
         return

-    token_count = len(encode(prompt)[0])
     debug_msg({'prompt': prompt, 'generate_params': generate_params})

     if stream:
@@ -330,7 +329,6 @@ def chat_streaming_chunk(content):

     answer = ''
     seen_content = ''
-    completion_token_count = 0

     for a in generator:
         answer = a['internal'][-1][1]
@@ -345,6 +343,7 @@ def chat_streaming_chunk(content):
             chunk = chat_streaming_chunk(new_content)
             yield chunk

+    token_count = len(encode(prompt)[0])
     completion_token_count = len(encode(answer)[0])
     stop_reason = "stop"
     if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
@@ -429,8 +428,6 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
         prompt = decode(prompt)[0]

     prefix = prompt if echo else ''
-    token_count = len(encode(prompt)[0])
-    total_prompt_token_count += token_count

     # generate reply #######################################
     debug_msg({'prompt': prompt, 'generate_params': generate_params})
@@ -440,6 +437,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
     for a in generator:
         answer = a

+    token_count = len(encode(prompt)[0])
+    total_prompt_token_count += token_count
     completion_token_count = len(encode(answer)[0])
     total_completion_token_count += completion_token_count
     stop_reason = "stop"
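A hedged reading of these hunks (the diff itself does not state the motivation): prompt and completion tokens are now counted only after the generator has run, so usage accounting happens once the model and tokenizer are guaranteed to be loaded (generation may lazily reload them after an idle unload, see the modules/models.py hunk further down), and the up-front completion_token_count initialization becomes unnecessary. A rough sketch of the resulting order of operations, with hypothetical generate and encode stand-ins rather than the project's real API:

def stream_with_usage(prompt, generate, encode):
    # Hypothetical stand-ins: 'generate' yields progressively longer answers,
    # 'encode' returns a list of token ids. Only the ordering mirrors the change.
    answer = ''
    for partial in generate(prompt):   # model/tokenizer may be (re)loaded in here
        answer = partial
        yield {'delta': answer}

    # Usage is computed after generation instead of before it.
    prompt_tokens = len(encode(prompt))
    completion_tokens = len(encode(answer))
    yield {'usage': {'prompt_tokens': prompt_tokens,
                     'completion_tokens': completion_tokens,
                     'total_tokens': prompt_tokens + completion_tokens}}

# Toy usage: list(stream_with_usage('Hi', lambda p: iter(['Hello', 'Hello!']), str.split))[-1]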

js/main.js (+7 −1)

@@ -213,6 +213,7 @@ function doSyntaxHighlighting() {
     renderMathInElement(element, {
       delimiters: [
         { left: "$$", right: "$$", display: true },
+        { left: "$", right: "$", display: false },
         { left: "\\(", right: "\\)", display: false },
         { left: "\\[", right: "\\]", display: true },
       ],
@@ -459,7 +460,12 @@ function updateCssProperties() {

  // Adjust scrollTop based on input height change
  if (chatInputHeight !== currentChatInputHeight) {
-    chatContainer.scrollTop += chatInputHeight - currentChatInputHeight;
+    if (!isScrolled && chatInputHeight < currentChatInputHeight) {
+      chatContainer.scrollTop = chatContainer.scrollHeight;
+    } else {
+      chatContainer.scrollTop += chatInputHeight - currentChatInputHeight;
+    }
+
    currentChatInputHeight = chatInputHeight;
  }
 }

modules/chat.py (+10 −10)

@@ -26,8 +26,7 @@
 from modules.text_generation import (
     generate_reply,
     get_encoded_length,
-    get_max_prompt_length,
-    stop_everything_event
+    get_max_prompt_length
 )
 from modules.utils import delete_file, get_available_characters, save_file

@@ -93,8 +92,16 @@ def generate_chat_prompt(user_input, state, **kwargs):
         chat_template_str = replace_character_names(chat_template_str, state['name1'], state['name2'])

     instruction_template = jinja_env.from_string(state['instruction_template_str'])
-    instruct_renderer = partial(instruction_template.render, add_generation_prompt=False)
     chat_template = jinja_env.from_string(chat_template_str)
+
+    instruct_renderer = partial(
+        instruction_template.render,
+        builtin_tools=None,
+        tools=None,
+        tools_in_user_message=False,
+        add_generation_prompt=False
+    )
+
     chat_renderer = partial(
         chat_template.render,
         add_generation_prompt=False,
@@ -1036,13 +1043,6 @@ def handle_remove_last_click(state):
     return [history, html, last_input]


-def handle_stop_click(state):
-    stop_everything_event()
-    html = redraw_html(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
-
-    return html
-
-
 def handle_unique_id_select(state):
     history = load_history(state['unique_id'], state['character_menu'], state['mode'])
     html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
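For illustration only, a small jinja2 sketch of what the expanded instruct_renderer pins down (the template string is a made-up stub, not the project's real instruction template): newer chat templates, such as Llama 3.1's, reference tool-related variables, and passing builtin_tools=None, tools=None and tools_in_user_message=False gives them explicit defaults at render time.

from functools import partial

from jinja2 import Environment

# Made-up stub template; real templates come from the model metadata.
template_str = (
    "{% for message in messages %}"
    "<|{{ message['role'] }}|>{{ message['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>{% endif %}"
)

jinja_env = Environment()
instruction_template = jinja_env.from_string(template_str)

# Mirrors the shape of the new instruct_renderer.
instruct_renderer = partial(
    instruction_template.render,
    builtin_tools=None,
    tools=None,
    tools_in_user_message=False,
    add_generation_prompt=False
)

print(instruct_renderer(messages=[{'role': 'user', 'content': 'Hi there'}]))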

modules/html_generator.py (+25 −2)

@@ -72,6 +72,14 @@ def replace_blockquote(m):
 @functools.lru_cache(maxsize=None)
 def convert_to_markdown(string):

+    # Make \[ \] LaTeX equations inline
+    pattern = r'^\s*\\\[\s*\n([\s\S]*?)\n\s*\\\]\s*$'
+    replacement = r'\\[ \1 \\]'
+    string = re.sub(pattern, replacement, string, flags=re.MULTILINE)
+
+    # Escape backslashes
+    string = string.replace('\\', '\\\\')
+
     # Quote to <q></q>
     string = replace_quotes(string)

@@ -95,12 +103,27 @@ def convert_to_markdown(string):

     result = ''
     is_code = False
+    is_latex = False
     for line in string.split('\n'):
-        if line.lstrip(' ').startswith('```'):
+        stripped_line = line.strip()
+
+        if stripped_line.startswith('```'):
             is_code = not is_code
+        elif stripped_line.startswith('$$'):
+            is_latex = not is_latex
+        elif stripped_line.endswith('$$'):
+            is_latex = False
+        elif stripped_line.startswith('\\\\['):
+            is_latex = True
+        elif stripped_line.startswith('\\\\]'):
+            is_latex = False
+        elif stripped_line.endswith('\\\\]'):
+            is_latex = False

         result += line
-        if is_code or line.startswith('|'):  # Don't add an extra \n for tables or code
+
+        # Don't add an extra \n for tables, code, or LaTeX
+        if is_code or is_latex or line.startswith('|'):
             result += '\n'
         else:
             result += '\n\n'
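A quick demonstration of the new preprocessing step (the sample string is made up): a multi-line \[ ... \] display block is collapsed onto a single line before the backslashes are escaped, so the later line-by-line pass can treat it as one LaTeX unit.

import re

# Same pattern and replacement as the hunk above, applied to a made-up sample.
pattern = r'^\s*\\\[\s*\n([\s\S]*?)\n\s*\\\]\s*$'
replacement = r'\\[ \1 \\]'

sample = "Some text.\n\\[\nE = mc^2\n\\]\nMore text."
print(re.sub(pattern, replacement, sample, flags=re.MULTILINE))
# The equation now sits on one line: \[ E = mc^2 \]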

modules/logits.py (+2 −2)

@@ -13,8 +13,8 @@


 def get_next_logits(*args, **kwargs):
-    if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
-        shared.model, shared.tokenizer = load_model(shared.previous_model_name)
+    if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
+        shared.model, shared.tokenizer = load_model(shared.model_name)

     needs_lock = not args[2]  # use_samplers
     if needs_lock:

modules/models.py (+5 −4)

@@ -368,14 +368,15 @@ def clear_torch_cache():
         torch.cuda.empty_cache()


-def unload_model():
+def unload_model(keep_model_name=False):
     shared.model = shared.tokenizer = None
-    shared.previous_model_name = shared.model_name
-    shared.model_name = 'None'
     shared.lora_names = []
     shared.model_dirty_from_training = False
     clear_torch_cache()

+    if not keep_model_name:
+        shared.model_name = 'None'
+

 def reload_model():
     unload_model()
@@ -393,7 +394,7 @@ def unload_model_if_idle():
             if time.time() - last_generation_time > shared.args.idle_timeout * 60:
                 if shared.model is not None:
                     logger.info("Unloading the model for inactivity.")
-                    unload_model()
+                    unload_model(keep_model_name=True)
        finally:
            shared.generation_lock.release()
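Taken together with the modules/logits.py and modules/text_generation.py hunks, the apparent intent is that an idle unload drops the weights but keeps shared.model_name, so the next request can transparently reload the same model, while an explicit unload still resets the name. A simplified, hypothetical sketch of that cycle (not the project's actual control flow):

# Hypothetical sketch of the idle unload/reload cycle; the real code lives in
# modules/models.py and modules/text_generation.py.
class Shared:
    model = object()           # pretend loaded model
    tokenizer = object()
    model_name = 'my-model'    # kept across an idle unload

shared = Shared()

def unload_model(keep_model_name=False):
    shared.model = shared.tokenizer = None
    if not keep_model_name:
        shared.model_name = 'None'

def load_model(name):
    return object(), object()  # stand-in for the real loader

def generate_reply(idle_timeout=30):
    # Lazily reload the model that was unloaded for inactivity.
    if idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
        shared.model, shared.tokenizer = load_model(shared.model_name)

unload_model(keep_model_name=True)   # what unload_model_if_idle() now does
generate_reply()                     # the model comes back under the same name
assert shared.model is not None and shared.model_name == 'my-model'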

modules/shared.py (−3)

@@ -13,7 +13,6 @@
 model = None
 tokenizer = None
 model_name = 'None'
-previous_model_name = 'None'
 is_seq2seq = False
 model_dirty_from_training = False
 lora_names = []
@@ -44,8 +43,6 @@
     'negative_prompt': '',
     'seed': -1,
     'truncation_length': 2048,
-    'truncation_length_min': 0,
-    'truncation_length_max': 200000,
     'max_tokens_second': 0,
     'max_updates_second': 0,
     'prompt_lookup_num_tokens': 0,

modules/text_generation.py (+2 −2)

@@ -32,8 +32,8 @@


 def generate_reply(*args, **kwargs):
-    if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
-        shared.model, shared.tokenizer = load_model(shared.previous_model_name)
+    if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
+        shared.model, shared.tokenizer = load_model(shared.model_name)

     shared.generation_lock.acquire()
     try:

modules/training.py (+1 −1)

@@ -165,7 +165,7 @@ def create_ui():
                     stride_length = gr.Slider(label='Stride', minimum=0, maximum=32768, value=512, step=256, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')

                 with gr.Column():
-                    max_length = gr.Slider(label='max_length', minimum=0, maximum=shared.settings['truncation_length_max'], value=0, step=256, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
+                    max_length = gr.Number(label='max_length', precision=0, step=256, value=0, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')

             with gr.Row():
                 start_current_evaluation = gr.Button("Evaluate loaded model", interactive=not mu)

modules/ui_chat.py (+3 −2)

@@ -7,6 +7,7 @@

 from modules import chat, shared, ui, utils
 from modules.html_generator import chat_html_wrapper
+from modules.text_generation import stop_everything_event
 from modules.utils import gradio

 inputs = ('Chat input', 'interface_state')
@@ -221,8 +222,8 @@ def create_event_handlers():
         chat.handle_remove_last_click, gradio('interface_state'), gradio('history', 'display', 'textbox'), show_progress=False)

     shared.gradio['Stop'].click(
-        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
-        chat.handle_stop_click, gradio('interface_state'), gradio('display'), show_progress=False)
+        stop_everything_event, None, None, queue=False).then(
+        chat.redraw_html, gradio(reload_arr), gradio('display'), show_progress=False)

     if not shared.args.multi_user:
         shared.gradio['unique_id'].select(

modules/ui_model_menu.py (+3 −3)

@@ -93,19 +93,19 @@ def create_ui():

             shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)
             shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=256, value=shared.args.n_gpu_layers, info='Must be set to more than 0 for your GPU to be used.')
-            shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=shared.settings['truncation_length_max'], step=256, label="n_ctx", value=shared.args.n_ctx, info='Context length. Try lowering this if you run out of memory while loading the model.')
+            shared.gradio['n_ctx'] = gr.Number(label="n_ctx", precision=0, step=256, value=shared.args.n_ctx, info='Context length. Try lowering this if you run out of memory while loading the model.')
             shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
             shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, step=1, value=shared.args.n_batch)
             shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
             shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
             shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None")
             shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=shared.args.groupsize if shared.args.groupsize > 0 else "None")
             shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
-            shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=0, maximum=shared.settings['truncation_length_max'], step=256, info='Context length. Try lowering this if you run out of memory while loading the model.', value=shared.args.max_seq_len)
+            shared.gradio['max_seq_len'] = gr.Number(label='max_seq_len', precision=0, step=256, value=shared.args.max_seq_len, info='Context length. Try lowering this if you run out of memory while loading the model.')
             with gr.Blocks():
                 shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.')
                 shared.gradio['rope_freq_base'] = gr.Number(label='rope_freq_base', value=shared.args.rope_freq_base, precision=0, info='Positional embeddings frequency base for NTK RoPE scaling. Related to alpha_value by rope_freq_base = 10000 * alpha_value ^ (64 / 63). 0 = from model.')
-                shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=0, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
+                shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')

             shared.gradio['autogptq_info'] = gr.Markdown('ExLlamav2_HF is recommended over AutoGPTQ for models derived from Llama.')

modules/ui_parameters.py (+1 −1)

@@ -89,7 +89,7 @@ def create_ui(default_preset):
                     shared.gradio['sampler_priority'] = gr.Textbox(value=generate_params['sampler_priority'], lines=12, label='Sampler priority', info='Parameter names separated by new lines or commas.')

                 with gr.Column():
-                    shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
+                    shared.gradio['truncation_length'] = gr.Number(precision=0, step=256, value=get_truncation_length(), label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
                     shared.gradio['prompt_lookup_num_tokens'] = gr.Slider(value=shared.settings['prompt_lookup_num_tokens'], minimum=0, maximum=10, step=1, label='prompt_lookup_num_tokens', info='Activates Prompt Lookup Decoding.')
                     shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum tokens/second', info='To make text readable in real time.')
                     shared.gradio['max_updates_second'] = gr.Slider(value=shared.settings['max_updates_second'], minimum=0, maximum=24, step=1, label='Maximum UI updates/second', info='Set this if you experience lag in the UI during streaming.')
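For reference, a minimal Gradio sketch of the widget swap used in modules/training.py, modules/ui_model_menu.py and here (placeholder values, assuming the Gradio 4.x API the repo targets): the length fields become free-form gr.Number inputs instead of gr.Slider widgets capped at truncation_length_max, which lines up with truncation_length_min/truncation_length_max being dropped from modules/shared.py.

import gradio as gr

# Placeholder values; only illustrates the Slider -> Number swap.
with gr.Blocks() as demo:
    # Before: gr.Slider(minimum=0, maximum=shared.settings['truncation_length_max'], step=256, ...)
    truncation_length = gr.Number(
        value=2048,
        precision=0,   # integer-only field
        step=256,
        label='Truncate the prompt up to this length'
    )

# demo.launch()  # not needed for the sketch; shown only for completeness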
