
Commit ca9aecf

Merge branch 'refs/heads/dev'
2 parents: 106a9d1 + b3e07ee


59 files changed: +1191, -290 lines

eval/humaneval.py (+14, -1)

@@ -5,7 +5,7 @@
 from exllamav2 import model_init
 from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8
 from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
-import argparse, contextlib
+import argparse, contextlib, subprocess
 import util

 # Args
@@ -20,6 +20,7 @@
 parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
 parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ")
 parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
+parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
 model_init.add_args(parser)
 args = parser.parse_args()

@@ -52,6 +53,13 @@
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
         "Sure! Here is how you might implement the function:\n\n```python\n{{problem}} ",
         " "
+    ),
+    "gemma": (
+        "<bos><start_of_turn>user\n"
+        "Complete the following Python function:\n\n{{problem}}<|eot_id|>"
+        "<start_of_turn>model\n"
+        "```python\n{{problem}} ",
+        " "
     )
 }

@@ -192,3 +200,8 @@
 print(f" -- Saving: {args.output}")
 write_jsonl(args.output, samples)

+# Optionally launch eval script
+
+if args.eval:
+    subprocess.run(["evaluate_functional_correctness", args.output])

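Each prompt_formats entry above is a pair of strings, the first being a template with a {{problem}} placeholder. As a rough illustration of how such a template might expand for a single HumanEval task (the helper and the sample stub below are hypothetical, not code from this commit):

# Hypothetical illustration: substitute a HumanEval function stub into the
# "gemma" template added above. Not the script's actual expansion code.
def build_prompt(template, problem):
    return template.replace("{{problem}}", problem)

gemma_template = (
    "<bos><start_of_turn>user\n"
    "Complete the following Python function:\n\n{{problem}}<|eot_id|>"
    "<start_of_turn>model\n"
    "```python\n{{problem}} "
)

stub = 'def add(a, b):\n    """Return the sum of a and b."""\n'   # hypothetical task stub
prompt = build_prompt(gemma_template, stub)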
examples/chat.py (+6, -3)

@@ -61,7 +61,7 @@

 parser.add_argument("-ngram", "--ngram_decoding", action = "store_true", help = "Use n-gram speculative decoding")

-parser.add_argument("-pt", "--print_timings", action = "store_true", help = "Output timings after each prompt")
+parser.add_argument("-pt", "--print_timings", action = "store_true", help = "Output timings/stats after each prompt")
 parser.add_argument("-amnesia", "--amnesia", action = "store_true", help = "Forget context after every response")

 # Arrrgs
@@ -235,7 +235,9 @@ def get_tokenized_context(max_len):

 # Stop conditions

-generator.set_stop_conditions(prompt_format.stop_conditions(tokenizer))
+sc = prompt_format.stop_conditions(tokenizer)
+sc = [x for x in sc if x]
+generator.set_stop_conditions(sc)

 # ANSI color codes

@@ -393,8 +395,9 @@ def get_tokenized_context(max_len):
         else:
             sd_stats = ""

+        ctx_tokens = active_context.shape[-1]
         print()
-        print(col_sysprompt + f"(Response: {response_tokens} tokens, {speed:.2f} tokens/second{sd_stats})" + col_default)
+        print(col_sysprompt + f"(Context: {ctx_tokens} tokens, response: {response_tokens} tokens, {speed:.2f} tokens/second{sd_stats})" + col_default)

     # Optionally forget context after each response


examples/chat_prompts.py (+1)

@@ -229,6 +229,7 @@ def subs_prompt(self):
     def stop_conditions(self, tokenizer):
         return \
             [tokenizer.eos_token_id,
+             tokenizer.single_id("<|im_end|>"),
              """<|im_end|>"""]

     def encoding_options(self):

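Together with the chat.py change above, stop conditions may now mix token IDs and literal strings, and empty entries are filtered out before being handed to the generator. A minimal sketch of that pattern, assuming an already-constructed tokenizer and generator as in chat.py (illustration only, not code from this commit):

# Collect stop conditions as token IDs and/or strings, then drop falsy
# entries (e.g. a lookup that returned None) before applying them.
sc = [
    tokenizer.eos_token_id,                # token ID
    tokenizer.single_id("<|im_end|>"),     # ID looked up from a literal token
    "<|im_end|>",                          # plain string stop condition
]
sc = [x for x in sc if x]                  # filter out None / empty entries
generator.set_stop_conditions(sc)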
examples/dynamic_gen.py (+2)

@@ -136,6 +136,7 @@ def main():
     if use_draft_model:

         draft_config = ExLlamaV2Config(draft_model_dir)
+        draft_config.arch_compat_overrides()
         draft_model = ExLlamaV2(draft_config)

         draft_cache = ExLlamaV2Cache(
@@ -155,6 +156,7 @@ def main():
     # 2048, which will also be the limit of the chunk size for prefill used by the dynamic generator.

     config = ExLlamaV2Config(model_dir)
+    config.arch_compat_overrides()
     config.max_input_len = max_chunk_size
     config.max_attention_size = max_chunk_size ** 2
     model = ExLlamaV2(config)

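This file and every inference example below pick up the same one-line addition: config.arch_compat_overrides() is called on each freshly constructed ExLlamaV2Config (including the draft model's config where one is used) before the model is built. A minimal sketch of the shared load sequence, assuming the usual exllamav2 imports and a placeholder model path:

# Shared load pattern across the updated examples (placeholder path; the
# import line is assumed to match the example scripts).
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache

config = ExLlamaV2Config("/path/to/model-exl2")   # placeholder model directory
config.arch_compat_overrides()                    # apply architecture compatibility overrides
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)        # lazily allocated cache
model.load_autosplit(cache, progress = True)      # load and split across available GPUs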
examples/inference.py (+1)

@@ -7,6 +7,7 @@

 model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
 model.load_autosplit(cache, progress = True)

examples/inference_async.py (+1)

@@ -9,6 +9,7 @@
 async def main():
     model_dir = "/mnt/str/models/llama3-8b-exl2/4.0bpw"
     config = ExLlamaV2Config(model_dir)
+    config.arch_compat_overrides()
     model = ExLlamaV2(config)
     cache = ExLlamaV2Cache(model, lazy = True)
     model.load_autosplit(cache, progress = True)

examples/inference_banned_strings.py (+1)

@@ -9,6 +9,7 @@

 model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/6.0bpw/"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, lazy = True)
 model.load_autosplit(cache, progress = True)

examples/inference_cfg.py (+1)

@@ -8,6 +8,7 @@

 model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
 model.load_autosplit(cache, progress = True)

examples/inference_dedup.py (+1)

@@ -8,6 +8,7 @@

 model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, max_seq_len = 8192, lazy = True)
 model.load_autosplit(cache, progress = True)

examples/inference_json.py (+1)

@@ -13,6 +13,7 @@

 model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
 model.load_autosplit(cache, progress = True)

examples/inference_lora.py (+1)

@@ -7,6 +7,7 @@

 model_dir = "/mnt/str/models/llama2-7b-exl2/5.0bpw"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
 model.load_autosplit(cache, progress = True)

examples/inference_speculative.py (+2)

@@ -12,12 +12,14 @@

 draft_model_dir = "/mnt/str/models/qwen2-1.5b-instruct-exl2/4.0bpw"
 draft_config = ExLlamaV2Config(draft_model_dir)
+draft_config.arch_compat_overrides()
 draft_model = ExLlamaV2(draft_config)
 draft_cache = ExLlamaV2Cache(draft_model, max_seq_len = total_cache_tokens, lazy = True)
 draft_model.load_autosplit(draft_cache, progress = True)

 model_dir = "/mnt/str/models/qwen2-72b-instruct-exl2/6.0bpw"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, max_seq_len = total_cache_tokens, lazy = True)
 model.load_autosplit(cache, progress = True)

examples/inference_stream.py (+1)

@@ -8,6 +8,7 @@

 model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw"
 config = ExLlamaV2Config(model_dir)
+config.arch_compat_overrides()
 model = ExLlamaV2(config)
 cache = ExLlamaV2Cache(model, lazy = True)
 model.load_autosplit(cache, progress = True)

examples/util.py (+8, -1)

@@ -29,6 +29,12 @@ def format_prompt(prompt_format, sp, p):
             f"{p}<|im_end|>\n"
             f"<|im_start|>assistant\n"
         )
+    elif prompt_format == "gemma":
+        return (
+            f"<bos><start_of_turn>user\n"
+            f"{p}<end_of_turn>\n"
+            f"<start_of_turn>model\n"
+        )

 def get_stop_conditions(prompt_format, tokenizer):
     if prompt_format == "llama":
@@ -37,7 +43,8 @@ def get_stop_conditions(prompt_format, tokenizer):
         return [tokenizer.single_id("<|eot_id|>")]
     elif prompt_format == "granite":
         return [tokenizer.eos_token_id, "\n\nQuestion:"]
-
+    elif prompt_format == "gemma":
+        return [tokenizer.eos_token_id, "<end_of_turn>"]

 # Cached dataset loader

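A minimal usage sketch for the new "gemma" branches above, assuming an already-loaded tokenizer as in the example scripts (not part of the commit):

# Build a Gemma-formatted prompt and matching stop conditions with the helpers
# extended above. The gemma template has no system-prompt slot, so sp is unused.
from util import format_prompt, get_stop_conditions

prompt = format_prompt("gemma", sp = None, p = "Why is the sky blue?")
stop_conditions = get_stop_conditions("gemma", tokenizer)   # [tokenizer.eos_token_id, "<end_of_turn>"]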
exllamav2/architecture.py (+108)

@@ -9,6 +9,12 @@
                        ["ln_2"]]
 layer_keys_yi_norms = [["ln1", "input_layernorm"],
                        ["ln2", "post_attention_layernorm"]]
+layer_keys_gemma2_norms = [["input_layernorm"],
+                           ["post_attention_layernorm"],
+                           ["pre_feedforward_layernorm"],
+                           ["post_feedforward_layernorm"]]
+layer_keys_internlm2_norms = [["attention_norm"],
+                              ["ffn_norm"]]
 layer_keys_llama_attn = [["self_attn.q_proj"],
                          ["self_attn.k_proj"],
                          ["self_attn.v_proj"],
@@ -17,6 +23,10 @@
                          ["self_attn.c_attn", "self_attn.k_proj"],
                          ["self_attn.c_attn", "self_attn.v_proj"],
                          ["self_attn.o_proj"]]
+layer_keys_internlm2_attn = [["self_attn.wqkv", "self_attn.q_proj"],
+                             ["self_attn.wqkv", "self_attn.k_proj"],
+                             ["self_attn.wqkv", "self_attn.v_proj"],
+                             ["self_attn.o_proj"]]
 layer_keys_dbrx_attn = [["self_attn.Wqkv", "self_attn.q_proj"],
                         ["self_attn.Wqkv", "self_attn.k_proj"],
                         ["self_attn.Wqkv", "self_attn.v_proj"],
@@ -28,6 +38,9 @@
 layer_keys_llama_mlp = [["mlp.down_proj"],
                         ["mlp.gate_proj"],
                         ["mlp.up_proj"]]
+layer_keys_internlm2_mlp = [["feed_forward.w1"],
+                            ["feed_forward.w2"],
+                            ["feed_forward.w3"]]
 layer_keys_phi3_mlp = [["mlp.down_proj"],
                        ["mlp.gate_up_proj", "mlp.gate_proj"],
                        ["mlp.gate_up_proj", "mlp.up_proj"]]
@@ -76,6 +89,10 @@
                ("$h.", "model.layers."),
                ("$wte.", "model.embed_tokens."),
                ("$wpe.", "model.wpe.")]
+internlm2_keymap = [("$output.", "lm_head."),
+                    ("$model.tok_embeddings.", "model.embed_tokens."),
+                    (".attention.", ".self_attn."),
+                    (".wo.", ".o_proj.")]

 class RopeStyle(Enum):
     NONE = 0
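For context, keymaps such as internlm2_keymap translate tensor names in the original checkpoint to the internal llama-style names. A rough sketch of how such a mapping could be applied (the helper is illustrative, not the loader's actual code; the "$" prefix is assumed to anchor a pattern at the start of the key, consistent with the other keymaps in this file):

# Illustrative only: apply a keymap like internlm2_keymap to a tensor name.
def remap_key(key, keymap):
    for pattern, replacement in keymap:
        if pattern.startswith("$"):
            prefix = pattern[1:]
            if key.startswith(prefix):                 # anchored match at start of key
                key = replacement + key[len(prefix):]
        else:
            key = key.replace(pattern, replacement)    # plain substring replacement
    return key

internlm2_keymap = [("$output.", "lm_head."),
                    ("$model.tok_embeddings.", "model.embed_tokens."),
                    (".attention.", ".self_attn."),
                    (".wo.", ".o_proj.")]

# e.g. "model.layers.0.attention.wo.weight" -> "model.layers.0.self_attn.o_proj.weight"
print(remap_key("model.layers.0.attention.wo.weight", internlm2_keymap))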
@@ -100,6 +117,18 @@ def __init__(self, arch_string, read_config):
         self.orig_weights_transposed = False
         self.logit_scale_basedim = False

+        self.norm_key_1_post = None
+        self.norm_key_2_post = None
+
+        self.swa = False
+        self.alternating_swa = False
+
+        self.eager_attn_only = False
+        self.clamp_hidden_states = False
+        self.residual_stream_fp32 = False
+
+        self.fused_qkv_altpack = False
+
         # Mistral

         if arch_string == "MistralForCausalLM":
@@ -305,6 +334,45 @@ def __init__(self, arch_string, read_config):
             self.mqa = False
             self.scale_attn_weights = False

+        # Gemma2
+
+        if arch_string == "Gemma2ForCausalLM":
+            arch_recognized = True
+            self.layer_keys += \
+                layer_keys_gemma2_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_llama_mlp
+            self.expect_keys += \
+                expect_keys_gemma
+            self.norm_eps_key = "rms_norm_eps"
+            self.attention_bias_qkv = False
+            self.attention_bias_o = False
+            self.mlp_bias = False
+            self.mlp_gate = True
+            self.mlp_key_gate = ".mlp.gate_proj"
+            self.mlp_key_up = ".mlp.up_proj"
+            self.mlp_key_down = ".mlp.down_proj"
+            self.mlp_act_func = "gelu"
+            self.is_moe = False
+            self.norm = "rmsnorm"
+            self.lm_head_key = "model.embed_tokens"
+            self.normalize_embeddings = True
+            self.norm_key_1 = ".input_layernorm"
+            self.norm_key_1_post = ".post_attention_layernorm"
+            self.norm_key_2 = ".pre_feedforward_layernorm"
+            self.norm_key_2_post = ".post_feedforward_layernorm"
+            self.norm_constant_bias = 1
+            self.parallel_decoder_blocks = False
+            self.requires_bos = True
+            self.rope_style = RopeStyle.NEOX
+            self.keymap = None
+            self.fused_qkv_key = None
+            self.mqa = False
+            self.scale_attn_weights = False
+            self.pre_post_layernorm = True
+            self.alternating_swa = True
+            self.residual_stream_fp32 = True
+
         # StarCoder2

         if arch_string == "Starcoder2ForCausalLM":
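The four norm keys together with pre_post_layernorm and residual_stream_fp32 above describe Gemma 2's sandwich-norm decoder block. A rough sketch of the block ordering these keys imply (illustrative only, not exllamav2's actual forward pass; attn, mlp and the norm arguments are assumed callables):

# Sketch of the ordering implied by norm_key_1/_1_post/_2/_2_post: each
# sublayer is wrapped in a pre-norm and a post-norm before the residual add.
def gemma2_block(x, attn, mlp, input_layernorm, post_attention_layernorm,
                 pre_feedforward_layernorm, post_feedforward_layernorm):
    h = post_attention_layernorm(attn(input_layernorm(x)))
    x = x + h                                                   # attention residual
    h = post_feedforward_layernorm(mlp(pre_feedforward_layernorm(x)))
    return x + h                                                # feed-forward residual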
@@ -586,6 +654,41 @@ def __init__(self, arch_string, read_config):
             self.scale_attn_weights = False
             self.logit_scale_basedim = True

+        # InternLM2
+
+        if arch_string == "InternLM2ForCausalLM":
+            arch_recognized = True
+            self.layer_keys += \
+                layer_keys_internlm2_norms + \
+                layer_keys_internlm2_attn + \
+                layer_keys_internlm2_mlp
+            self.expect_keys += \
+                expect_keys_llama
+            self.norm_eps_key = "rms_norm_eps"
+            self.attention_bias_qkv = False
+            self.attention_bias_o = False
+            self.mlp_bias = False
+            self.mlp_gate = True
+            self.mlp_key_gate = ".feed_forward.w1"
+            self.mlp_key_up = ".feed_forward.w3"
+            self.mlp_key_down = ".feed_forward.w2"
+            self.mlp_act_func = "silu"
+            self.is_moe = False
+            self.norm = "rmsnorm"
+            self.lm_head_key = "lm_head"
+            self.normalize_embeddings = False
+            self.norm_key_1 = ".attention_norm"
+            self.norm_key_2 = ".ffn_norm"
+            self.norm_constant_bias = 0
+            self.parallel_decoder_blocks = False
+            self.requires_bos = False
+            self.rope_style = RopeStyle.NEOX
+            self.keymap = internlm2_keymap
+            self.fused_qkv_key = "wqkv"
+            self.fused_qkv_altpack = True
+            self.mqa = False
+            self.scale_attn_weights = False
+
         # Llama (default + fallback)

         if arch_string != "LlamaForCausalLM" and not arch_recognized:
@@ -637,6 +740,11 @@ def __init__(self, arch_string, read_config):
             self.expect_keys.remove(["lm_head"])
             self.lm_head_key = "model.embed_tokens"

+        # Sanity checks
+
+        if self.residual_stream_fp32:
+            assert self.norm_key_1_post and self.norm_key_2_post, \
+                "FP32 residual stream only implemented for arch with post layernorms"

     def make_fused_mlp(self):
