
Commit dd46229

Merge pull request oobabooga#5530 from oobabooga/dev
Merge dev branch
2 parents 771c592 + af0bbf5 commit dd46229

22 files changed, +310 -182 lines changed

download-model.py

+22 -13
@@ -26,13 +26,16 @@
 
 class ModelDownloader:
     def __init__(self, max_retries=5):
-        self.session = requests.Session()
-        if max_retries:
-            self.session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
-            self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
+        self.max_retries = max_retries
+
+    def get_session(self):
+        session = requests.Session()
+        if self.max_retries:
+            session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=self.max_retries))
+            session.mount('https://huggingface.co', HTTPAdapter(max_retries=self.max_retries))
 
         if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
-            self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
+            session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
 
         try:
             from huggingface_hub import get_token
@@ -41,7 +44,9 @@ def __init__(self, max_retries=5):
             token = os.getenv("HF_TOKEN")
 
         if token is not None:
-            self.session.headers = {'authorization': f'Bearer {token}'}
+            session.headers = {'authorization': f'Bearer {token}'}
+
+        return session
 
     def sanitize_model_and_branch_names(self, model, branch):
         if model[-1] == '/':
@@ -65,6 +70,7 @@ def sanitize_model_and_branch_names(self, model, branch):
         return model, branch
 
     def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
+        session = self.get_session()
         page = f"/api/models/{model}/tree/{branch}"
         cursor = b""
 
@@ -78,7 +84,7 @@ def get_download_links_from_huggingface(self, model, branch, text_only=False, sp
         is_lora = False
         while True:
             url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
-            r = self.session.get(url, timeout=10)
+            r = session.get(url, timeout=10)
             r.raise_for_status()
             content = r.content
 
@@ -156,9 +162,8 @@ def get_download_links_from_huggingface(self, model, branch, text_only=False, sp
         is_llamacpp = has_gguf and specific_file is not None
         return links, sha256, is_lora, is_llamacpp
 
-    def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, base_folder=None):
-        if base_folder is None:
-            base_folder = 'models' if not is_lora else 'loras'
+    def get_output_folder(self, model, branch, is_lora, is_llamacpp=False):
+        base_folder = 'models' if not is_lora else 'loras'
 
         # If the model is of type GGUF, save directly in the base_folder
         if is_llamacpp:
@@ -172,14 +177,15 @@ def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, base_fold
         return output_folder
 
     def get_single_file(self, url, output_folder, start_from_scratch=False):
+        session = self.get_session()
         filename = Path(url.rsplit('/', 1)[1])
         output_path = output_folder / filename
         headers = {}
         mode = 'wb'
         if output_path.exists() and not start_from_scratch:
 
             # Check if the file has already been downloaded completely
-            r = self.session.get(url, stream=True, timeout=10)
+            r = session.get(url, stream=True, timeout=10)
             total_size = int(r.headers.get('content-length', 0))
             if output_path.stat().st_size >= total_size:
                 return
@@ -188,7 +194,7 @@ def get_single_file(self, url, output_folder, start_from_scratch=False):
             headers = {'Range': f'bytes={output_path.stat().st_size}-'}
             mode = 'ab'
 
-        with self.session.get(url, stream=True, headers=headers, timeout=10) as r:
+        with session.get(url, stream=True, headers=headers, timeout=10) as r:
             r.raise_for_status() # Do not continue the download if the request was unsuccessful
             total_size = int(r.headers.get('content-length', 0))
             block_size = 1024 * 1024 # 1MB
@@ -303,7 +309,10 @@ def check_model_files(self, model, branch, links, sha256, output_folder):
     links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only, specific_file=specific_file)
 
     # Get the output folder
-    output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, base_folder=args.output)
+    if args.output:
+        output_folder = Path(args.output)
+    else:
+        output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp)
 
     if args.check:
         # Check previously downloaded files
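
The net effect of this diff is that the downloader no longer keeps one shared requests.Session: __init__ only stores max_retries, every download path calls get_session() for a fresh session, and --output is now resolved outside get_output_folder(). A minimal sketch of the per-call session pattern under the same assumptions (retry adapters for the two Hugging Face hosts, optional HF_TOKEN auth); the Downloader class and its fetch_size() helper below are illustrative, not part of the repository:

    import os

    import requests
    from requests.adapters import HTTPAdapter


    class Downloader:
        def __init__(self, max_retries=5):
            # Store configuration only; sessions are created on demand.
            self.max_retries = max_retries

        def get_session(self):
            session = requests.Session()
            if self.max_retries:
                # Retry transient failures against the two Hugging Face hosts.
                session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=self.max_retries))
                session.mount('https://huggingface.co', HTTPAdapter(max_retries=self.max_retries))

            token = os.getenv('HF_TOKEN')
            if token is not None:
                session.headers = {'authorization': f'Bearer {token}'}

            return session

        def fetch_size(self, url):
            # Each caller (e.g. each download thread) gets its own session and connection pool.
            with self.get_session().get(url, stream=True, timeout=10) as r:
                r.raise_for_status()
                return int(r.headers.get('content-length', 0))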

instruction-templates/Mistral.yaml

+1 -1
@@ -4,7 +4,7 @@ instruction_template: |-
           {{- message['content'] -}}
       {%- else -%}
           {%- if message['role'] == 'user' -%}
-              {{-' [INST] ' + message['content'].rstrip() + ' [/INST] '-}}
+              {{-'[INST] ' + message['content'].rstrip() + ' [/INST]'-}}
           {%- else -%}
               {{-'' + message['content'] + '</s>' -}}
           {%- endif -%}
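
The only change here is the whitespace around the [INST] markers. To see what the rendered prompt looks like, the template can be exercised directly with Jinja2 (a sketch, assuming the jinja2 package; the trimmed template string and the example messages are made up for illustration):

    from jinja2 import Template

    # Trimmed-down version of the instruction_template above.
    template_str = (
        "{%- for message in messages -%}"
        "{%- if message['role'] == 'user' -%}"
        "{{- '[INST] ' + message['content'].rstrip() + ' [/INST]' -}}"
        "{%- else -%}"
        "{{- message['content'] + '</s>' -}}"
        "{%- endif -%}"
        "{%- endfor -%}"
    )

    messages = [
        {'role': 'user', 'content': 'Hello '},
        {'role': 'assistant', 'content': 'Hi there!'},
    ]

    print(Template(template_str).render(messages=messages))
    # -> [INST] Hello [/INST]Hi there!</s>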

modules/chat.py

+36 -32
@@ -166,53 +166,54 @@ def make_prompt(messages):
         prompt = remove_extra_bos(prompt)
         return prompt
 
-    # Handle truncation
-    max_length = get_max_prompt_length(state)
     prompt = make_prompt(messages)
-    encoded_length = get_encoded_length(prompt)
 
-    while len(messages) > 0 and encoded_length > max_length:
+    # Handle truncation
+    if shared.tokenizer is not None:
+        max_length = get_max_prompt_length(state)
+        encoded_length = get_encoded_length(prompt)
+        while len(messages) > 0 and encoded_length > max_length:
 
-        # Remove old message, save system message
-        if len(messages) > 2 and messages[0]['role'] == 'system':
-            messages.pop(1)
+            # Remove old message, save system message
+            if len(messages) > 2 and messages[0]['role'] == 'system':
+                messages.pop(1)
 
-        # Remove old message when no system message is present
-        elif len(messages) > 1 and messages[0]['role'] != 'system':
-            messages.pop(0)
+            # Remove old message when no system message is present
+            elif len(messages) > 1 and messages[0]['role'] != 'system':
+                messages.pop(0)
 
-        # Resort to truncating the user input
-        else:
+            # Resort to truncating the user input
+            else:
+
+                user_message = messages[-1]['content']
+
+                # Bisect the truncation point
+                left, right = 0, len(user_message) - 1
 
-            user_message = messages[-1]['content']
+                while right - left > 1:
+                    mid = (left + right) // 2
 
-            # Bisect the truncation point
-            left, right = 0, len(user_message) - 1
+                    messages[-1]['content'] = user_message[mid:]
+                    prompt = make_prompt(messages)
+                    encoded_length = get_encoded_length(prompt)
 
-            while right - left > 1:
-                mid = (left + right) // 2
+                    if encoded_length <= max_length:
+                        right = mid
+                    else:
+                        left = mid
 
-                messages[-1]['content'] = user_message[mid:]
+                messages[-1]['content'] = user_message[right:]
                 prompt = make_prompt(messages)
                 encoded_length = get_encoded_length(prompt)
-
-                if encoded_length <= max_length:
-                    right = mid
+                if encoded_length > max_length:
+                    logger.error(f"Failed to build the chat prompt. The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n")
+                    raise ValueError
                 else:
-                    left = mid
+                    logger.warning(f"The input has been truncated. Context length: {state['truncation_length']}, max_new_tokens: {state['max_new_tokens']}, available context length: {max_length}.")
+                    break
 
-            messages[-1]['content'] = user_message[right:]
             prompt = make_prompt(messages)
             encoded_length = get_encoded_length(prompt)
-            if encoded_length > max_length:
-                logger.error(f"Failed to build the chat prompt. The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n")
-                raise ValueError
-            else:
-                logger.warning(f"The input has been truncated. Context length: {state['truncation_length']}, max_new_tokens: {state['max_new_tokens']}, available context length: {max_length}.")
-                break
-
-        prompt = make_prompt(messages)
-        encoded_length = get_encoded_length(prompt)
 
     if also_return_rows:
         return prompt, [message['content'] for message in messages]
@@ -690,6 +691,9 @@ def load_character(character, name1, name2):
 
 
 def load_instruction_template(template):
+    if template == 'None':
+        return ''
+
    for filepath in [Path(f'instruction-templates/{template}.yaml'), Path('instruction-templates/Alpaca.yaml')]:
        if filepath.exists():
            break
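
Two behavioural changes show up in this diff: prompt truncation now only runs when shared.tokenizer is loaded, and load_instruction_template() short-circuits for the 'None' template. The truncation itself bisects on a character index of the last user message, so the prompt fits the context budget after O(log n) re-encodings instead of trimming token by token. A standalone sketch of that idea (the build_prompt and encoded_length callables stand in for make_prompt()/get_encoded_length() and are hypothetical here):

    def truncate_to_fit(user_message, build_prompt, encoded_length, max_length):
        """Drop the oldest part of user_message until the built prompt fits.

        build_prompt(text) -> str and encoded_length(str) -> int are injected,
        mirroring the roles of make_prompt()/get_encoded_length() in modules/chat.py.
        """
        left, right = 0, len(user_message) - 1

        while right - left > 1:
            mid = (left + right) // 2
            if encoded_length(build_prompt(user_message[mid:])) <= max_length:
                right = mid  # fits: try keeping more text
            else:
                left = mid   # still too long: cut deeper

        return user_message[right:]


    # Toy usage: 1 token per character, 16-token budget, no chat wrapper.
    print(truncate_to_fit("x" * 100, lambda t: t, len, 16))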

modules/exllamav2.py

+12 -9
@@ -51,18 +51,21 @@ def from_pretrained(self, path_to_model):
 
         model = ExLlamaV2(config)
 
-        split = None
-        if shared.args.gpu_split:
-            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
-
-        model.load(split)
-
-        tokenizer = ExLlamaV2Tokenizer(config)
         if shared.args.cache_8bit:
-            cache = ExLlamaV2Cache_8bit(model)
+            cache = ExLlamaV2Cache_8bit(model, lazy=True)
         else:
-            cache = ExLlamaV2Cache(model)
+            cache = ExLlamaV2Cache(model, lazy=True)
 
+        if shared.args.autosplit:
+            model.load_autosplit(cache)
+        else:
+            split = None
+            if shared.args.gpu_split:
+                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+
+            model.load(split)
+
+        tokenizer = ExLlamaV2Tokenizer(config)
         generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
 
         result = self()
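
With this change the ExLlamaV2 cache is created with lazy=True before the weights are loaded, and the new autosplit option (shared.args.autosplit) hands GPU placement to load_autosplit() instead of a manual gpu_split list. A condensed sketch of that loading sequence, assuming the exllamav2 package as imported by this module (the function name and arguments below are illustrative, not webui API):

    from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Cache_8bit, ExLlamaV2Config, ExLlamaV2Tokenizer


    def load_exllamav2(model_dir, gpu_split=None, autosplit=False, cache_8bit=False):
        config = ExLlamaV2Config()
        config.model_dir = model_dir
        config.prepare()

        model = ExLlamaV2(config)

        # A lazy cache can be constructed before the model weights are loaded.
        cache_cls = ExLlamaV2Cache_8bit if cache_8bit else ExLlamaV2Cache
        cache = cache_cls(model, lazy=True)

        if autosplit:
            # Let exllamav2 place layers across the available GPUs as it loads them.
            model.load_autosplit(cache)
        else:
            # Manual per-GPU allocation, e.g. gpu_split="20,24" -> [20.0, 24.0].
            split = [float(alloc) for alloc in gpu_split.split(",")] if gpu_split else None
            model.load(split)

        return model, ExLlamaV2Tokenizer(config), cache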

modules/exllamav2_hf.py

+12 -8
@@ -37,18 +37,22 @@ def __init__(self, config: ExLlamaV2Config):
         super().__init__(PretrainedConfig())
         self.ex_config = config
         self.ex_model = ExLlamaV2(config)
-        split = None
-        if shared.args.gpu_split:
-            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
-
-        self.ex_model.load(split)
-        self.generation_config = GenerationConfig()
         self.loras = None
+        self.generation_config = GenerationConfig()
 
         if shared.args.cache_8bit:
-            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model)
+            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=True)
         else:
-            self.ex_cache = ExLlamaV2Cache(self.ex_model)
+            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=True)
+
+        if shared.args.autosplit:
+            self.ex_model.load_autosplit(self.ex_cache)
+        else:
+            split = None
+            if shared.args.gpu_split:
+                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+
+            self.ex_model.load(split)
 
         self.past_seq = None
         if shared.args.cfg_cache:

modules/loaders.py

+2 -0
@@ -78,6 +78,7 @@
         'no_flash_attn',
         'num_experts_per_token',
         'cache_8bit',
+        'autosplit',
         'alpha_value',
         'compress_pos_emb',
         'trust_remote_code',
@@ -89,6 +90,7 @@
         'no_flash_attn',
         'num_experts_per_token',
         'cache_8bit',
+        'autosplit',
         'alpha_value',
         'compress_pos_emb',
         'exllamav2_info',

modules/models.py

+1 -1
@@ -257,7 +257,7 @@ def llamacpp_HF_loader(model_name):
     path = Path(f'{shared.args.model_dir}/{model_name}')
 
     # Check if a HF tokenizer is available for the model
-    if all((path / file).exists() for file in ['tokenizer.model', 'tokenizer_config.json']):
+    if all((path / file).exists() for file in ['tokenizer_config.json']):
         logger.info(f'Using tokenizer from: \"{path}\"')
     else:
         logger.error("Could not load the model because a tokenizer in Transformers format was not found.")
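
The relaxed check reflects that newer Hugging Face tokenizers may ship only tokenizer.json/tokenizer_config.json without a sentencepiece tokenizer.model. A minimal sketch of the same gate (the model folder below is hypothetical):

    from pathlib import Path

    path = Path('models/some-llamacpp-model')  # hypothetical model folder

    required = ['tokenizer_config.json']       # 'tokenizer.model' is no longer mandatory
    if all((path / file).exists() for file in required):
        print(f'Using tokenizer from: "{path}"')
    else:
        print('No Transformers-format tokenizer found.')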

modules/models_settings.py

+47 -18
@@ -153,6 +153,8 @@ def infer_loader(model_name, model_settings):
         loader = 'ExLlamav2_HF'
     elif (path_to_model / 'quant_config.json').exists() or re.match(r'.*-awq', model_name.lower()):
         loader = 'AutoAWQ'
+    elif len(list(path_to_model.glob('*.gguf'))) > 0 and path_to_model.is_dir() and (path_to_model / 'tokenizer_config.json').exists():
+        loader = 'llamacpp_HF'
     elif len(list(path_to_model.glob('*.gguf'))) > 0:
         loader = 'llama.cpp'
     elif re.match(r'.*\.gguf', model_name.lower()):
@@ -225,7 +227,7 @@ def apply_model_settings_to_state(model, state):
         loader = model_settings.pop('loader')
 
         # If the user is using an alternative loader for the same model type, let them keep using it
-        if not (loader == 'ExLlamav2_HF' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlamav2', 'AutoGPTQ']) and not (loader == 'llama.cpp' and state['loader'] in ['llamacpp_HF', 'ctransformers']):
+        if not (loader == 'ExLlamav2_HF' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlamav2', 'AutoGPTQ']) and not (loader == 'llama.cpp' and state['loader'] in ['ctransformers']):
             state['loader'] = loader
 
     for k in model_settings:
@@ -243,27 +245,54 @@ def save_model_settings(model, state):
     Save the settings for this model to models/config-user.yaml
     '''
     if model == 'None':
-        yield ("Not saving the settings because no model is loaded.")
+        yield ("Not saving the settings because no model is selected in the menu.")
         return
 
-    with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
-        if p.exists():
-            user_config = yaml.safe_load(open(p, 'r').read())
-        else:
-            user_config = {}
+    user_config = shared.load_user_config()
+    model_regex = model + '$' # For exact matches
+    if model_regex not in user_config:
+        user_config[model_regex] = {}
+
+    for k in ui.list_model_elements():
+        if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
+            user_config[model_regex][k] = state[k]
 
-        model_regex = model + '$' # For exact matches
-        if model_regex not in user_config:
-            user_config[model_regex] = {}
+    shared.user_config = user_config
 
-        for k in ui.list_model_elements():
-            if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
-                user_config[model_regex][k] = state[k]
+    output = yaml.dump(user_config, sort_keys=False)
+    p = Path(f'{shared.args.model_dir}/config-user.yaml')
+    with open(p, 'w') as f:
+        f.write(output)
 
-        shared.user_config = user_config
+    yield (f"Settings for `{model}` saved to `{p}`.")
 
-        output = yaml.dump(user_config, sort_keys=False)
-        with open(p, 'w') as f:
-            f.write(output)
 
-        yield (f"Settings for `{model}` saved to `{p}`.")
+def save_instruction_template(model, template):
+    '''
+    Similar to the function above, but it saves only the instruction template.
+    '''
+    if model == 'None':
+        yield ("Not saving the template because no model is selected in the menu.")
+        return
+
+    user_config = shared.load_user_config()
+    model_regex = model + '$' # For exact matches
+    if model_regex not in user_config:
+        user_config[model_regex] = {}
+
+    if template == 'None':
+        user_config[model_regex].pop('instruction_template', None)
+    else:
+        user_config[model_regex]['instruction_template'] = template
+
+    shared.user_config = user_config
+
+    output = yaml.dump(user_config, sort_keys=False)
+    p = Path(f'{shared.args.model_dir}/config-user.yaml')
+    with open(p, 'w') as f:
+        f.write(output)
+
+    if template == 'None':
+        yield (f"Instruction template for `{model}` unset in `{p}`, as the value for template was `{template}`.")
+    else:
+        yield (f"Instruction template for `{model}` saved to `{p}` as `{template}`.")
