
Commit 3f19cbf
Merge remote-tracking branch 'turboderp/master'
2 parents 2776691 + b9c025b


47 files changed: +1197 -71 lines

.github/workflows/build-wheels-release-linux.yml

+347
Large diffs are not rendered by default.

.github/workflows/build-wheels-release-rocm62.yml

+347
Large diffs are not rendered by default.

examples/chat.py

+1-1
@@ -188,7 +188,7 @@ def format_prompt(user_prompt, first):
     global system_prompt, prompt_format

     if first:
-        return prompt_format.first_prompt(not system_prompt) \
+        return prompt_format.first_prompt(bool(system_prompt)) \
             .replace("<|system_prompt|>", system_prompt) \
             .replace("<|user_prompt|>", user_prompt)
     else:

examples/chat_prompts.py

+37
@@ -547,6 +547,42 @@ def print_extra_newline(self):
         return True


+class PromptFormat_granite3(PromptFormat):
+    description = "Granite 3"
+
+    def __init__(self):
+        super().__init__()
+        pass
+
+    def default_system_prompt(self):
+        return "You are Granite, developed by IBM. You are a helpful AI assistant."
+
+    def first_prompt(self, sysprompt):
+        r = ""
+        if sysprompt:
+            r += """<|start_of_role|>system<|end_of_role|><|system_prompt|><|end_of_text|>"""
+        r += """<|start_of_role|>user<|end_of_role|><|user_prompt|><|end_of_text|>"""
+        r += """<|start_of_role|>assistant<|end_of_role|>"""
+        return r
+
+    def subs_prompt(self):
+        r = ""
+        r += """<|start_of_role|>user<|end_of_role|><|user_prompt|><|end_of_text|>"""
+        r += """<|start_of_role|>assistant<|end_of_role|>"""
+        return r
+
+    def stop_conditions(self, tokenizer):
+        return [
+            tokenizer.eos_token_id,
+        ]
+
+    def encoding_options(self):
+        return True, False, True
+
+    def print_extra_newline(self):
+        return True
+
+
 class PromptFormat_cohere(PromptFormat):
     description = "Cohere"

@@ -610,4 +646,5 @@ def print_extra_newline(self):
     "cohere": PromptFormat_cohere,
     "phi3": PromptFormat_phi3,
     "granite": PromptFormat_granite,
+    "granite3": PromptFormat_granite3,
 }
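For reference, a minimal sketch of what the new Granite 3 format produces for a first turn. It assumes the examples directory is on sys.path so chat_prompts.py can be imported directly (chat.py normally reaches it through the prompt_formats lookup); the user message is invented.

# Minimal sketch, assuming examples/chat_prompts.py is importable as a module.
from chat_prompts import PromptFormat_granite3

pf = PromptFormat_granite3()
system_prompt = pf.default_system_prompt()

# bool(system_prompt) mirrors the chat.py change above: pass True only when a
# system prompt is actually set, so the system role block is emitted.
prompt = pf.first_prompt(bool(system_prompt)) \
    .replace("<|system_prompt|>", system_prompt) \
    .replace("<|user_prompt|>", "Summarize this merge in one sentence.")

# Result: the system, user and assistant role blocks concatenated, e.g.
# <|start_of_role|>system<|end_of_role|>...<|end_of_text|><|start_of_role|>user<|end_of_role|>...
print(prompt)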
File renamed without changes.
File renamed without changes.

examples/media/test_video_01.png (72.5 KB)
examples/media/test_video_02.png (73.1 KB)
examples/media/test_video_03.png (76.6 KB)
examples/media/test_video_04.png (77.9 KB)
examples/media/test_video_05.png (79.4 KB)
examples/media/test_video_06.png (79.3 KB)
examples/media/test_video_07.png (81.4 KB)
examples/media/test_video_08.png (82 KB)
examples/media/test_video_09.png (82 KB)
examples/media/test_video_10.png (84.5 KB)
examples/media/test_video_11.png (86 KB)
examples/media/test_video_12.png (86.2 KB)
examples/media/test_video_13.png (86.6 KB)
examples/media/test_video_14.png (87 KB)
examples/media/test_video_15.png (87.8 KB)
examples/media/test_video_16.png (87.3 KB)
examples/media/test_video_17.png (87.3 KB)
examples/media/test_video_18.png (87.2 KB)
examples/media/test_video_19.png (89.8 KB)
examples/media/test_video_20.png (90.9 KB)
examples/media/test_video_21.png (90.1 KB)
examples/media/test_video_22.png (91 KB)

examples/multimodal.py

+6-3
@@ -18,6 +18,9 @@
 from PIL import Image
 import requests

+import torch
+torch.set_printoptions(precision = 5, sci_mode = False, linewidth=200)
+
 # Models used:
 #
 # Pixtral:
@@ -39,8 +42,8 @@
 model_directory = "/mnt/str/models/qwen2-vl-7b-instruct-exl2/6.0bpw"

 images = [
-    {"file": "test_image_1.jpg"},
-    {"file": "test_image_2.jpg"},
+    {"file": "media/test_image_1.jpg"},
+    {"file": "media/test_image_2.jpg"},
     # {"url": "https://media.istockphoto.com/id/1212540739/photo/mom-cat-with-kitten.jpg?s=612x612&w=0&k=20&c=RwoWm5-6iY0np7FuKWn8FTSieWxIoO917FF47LfcBKE="},
     # {"url": "https://i.dailymail.co.uk/1s/2023/07/10/21/73050285-12283411-Which_way_should_I_go_One_lady_from_the_US_shared_this_incredibl-a-4_1689019614007.jpg"},
     # {"url": "https://images.fineartamerica.com/images-medium-large-5/metal-household-objects-trevor-clifford-photography.jpg"}
@@ -127,7 +130,7 @@ def get_image(file = None, url = None):
     "<|im_start|>user\n" +
     placeholders +
     instruction +
-    "\n" +
+    "<|im_end|>\n" +
     "<|im_start|>assistant\n"
 )

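The last hunk closes the user turn that was previously left open. A minimal sketch of the corrected ChatML layout, with placeholders and instruction stubbed in (both are variables defined elsewhere in multimodal.py):

# Sketch only: stand-in values for the variables used in multimodal.py.
placeholders = "{{IMAGE_1}}{{IMAGE_2}}"
instruction = "Describe the images."

prompt_tail = (
    "<|im_start|>user\n" +
    placeholders +
    instruction +
    "<|im_end|>\n" +            # end-of-turn token added by this commit
    "<|im_start|>assistant\n"
)
print(prompt_tail)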

examples/multimodal_video.py

+149
@@ -0,0 +1,149 @@
+import sys, os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from exllamav2 import (
+    ExLlamaV2,
+    ExLlamaV2Config,
+    ExLlamaV2Cache,
+    ExLlamaV2Tokenizer,
+    ExLlamaV2VisionTower,
+)
+
+from exllamav2.generator import (
+    ExLlamaV2DynamicGenerator,
+    ExLlamaV2DynamicJob,
+    ExLlamaV2Sampler,
+)
+
+from PIL import Image
+import requests, glob
+
+import torch
+torch.set_printoptions(precision = 5, sci_mode = False, linewidth=200)
+
+# Model used:
+#
+# Qwen2-VL:
+# https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
+# https://huggingface.co/turboderp/Qwen2-VL-7B-Instruct-exl2
+
+streaming = True
+greedy = True
+
+model_directory = "/mnt/str/models/qwen2-vl-7b-instruct-exl2/6.0bpw"
+images_mask = os.path.join(os.path.dirname(os.path.abspath(__file__)), "media/test_video_*.png")
+
+frames = [
+    {"file": f}
+    for f in sorted(glob.glob(images_mask))
+]
+
+instruction = "Describe this video."
+
+# Initialize model
+
+config = ExLlamaV2Config(model_directory)
+config.max_seq_len = 16384  # Pixtral default is 1M
+
+# Load vision model and multimodal projector and initialize preprocessor
+
+vision_model = ExLlamaV2VisionTower(config)
+vision_model.load(progress = True)
+
+# Load EXL2 model
+
+model = ExLlamaV2(config)
+cache = ExLlamaV2Cache(model, lazy = True, max_seq_len = 16384)
+model.load_autosplit(cache, progress = True)
+tokenizer = ExLlamaV2Tokenizer(config)
+
+# Create generator
+
+generator = ExLlamaV2DynamicGenerator(
+    model = model,
+    cache = cache,
+    tokenizer = tokenizer,
+)
+
+# Util function to get a PIL image from a URL or from a file in the script's directory
+
+def get_image(file = None, url = None):
+    assert (file or url) and not (file and url)
+    if file:
+        script_dir = os.path.dirname(os.path.abspath(__file__))
+        file_path = os.path.join(script_dir, file)
+        return Image.open(file_path)
+    elif url:
+        return Image.open(requests.get(url, stream = True).raw)
+
+# Convert video to embeddings. Aliases can be given explicitly with the text_alias argument, but here we
+# use automatically assigned unique identifiers, then concatenate them into a string
+
+video_embedding = vision_model.get_video_embeddings(
+    model = model,
+    tokenizer = tokenizer,
+    video = [get_image(**img_args) for img_args in frames],
+)
+video_embeddings = [video_embedding]
+
+# Define prompt
+
+prompt = (
+    "<|im_start|>system\n" +
+    "You are a helpful assistant.<|im_end|>\n" +
+    "<|im_start|>user\n" +
+    video_embedding.text_alias +
+    # "\n" +
+    instruction +
+    "<|im_end|>\n" +
+    "<|im_start|>assistant\n"
+)
+
+# Generate
+
+if streaming:
+
+    input_ids = tokenizer.encode(
+        prompt,
+        # add_bos = True,
+        encode_special_tokens = True,
+        embeddings = video_embeddings,
+    )
+
+    job = ExLlamaV2DynamicJob(
+        input_ids = input_ids,
+        max_new_tokens = 500,
+        decode_special_tokens = True,
+        stop_conditions = [tokenizer.eos_token_id],
+        gen_settings = ExLlamaV2Sampler.Settings.greedy() if greedy else None,
+        embeddings = video_embeddings,
+    )
+
+    generator.enqueue(job)
+
+    print()
+    print(prompt, end = ""); sys.stdout.flush()
+
+    eos = False
+    while generator.num_remaining_jobs():
+        results = generator.iterate()
+        for result in results:
+            text = result.get("text", "")
+            print(text, end = ""); sys.stdout.flush()
+
+    print()
+
+else:
+
+    output = generator.generate(
+        prompt = prompt,
+        max_new_tokens = 500,
+        add_bos = True,
+        encode_special_tokens = True,
+        decode_special_tokens = True,
+        stop_conditions = [tokenizer.eos_token_id],
+        gen_settings = ExLlamaV2Sampler.Settings.greedy() if greedy else None,
+        embeddings = video_embeddings,
+    )
+
+    print(output)
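As the comment in the script notes, get_video_embeddings assigns a unique text alias automatically, but one can be given explicitly. A hedged sketch of that variant, continuing from the variables defined in the script above (the argument name is taken from that comment; the alias string itself is arbitrary):

# Sketch only, reusing vision_model, model, tokenizer, get_image and frames
# from the script above.
video_embedding = vision_model.get_video_embeddings(
    model = model,
    tokenizer = tokenizer,
    video = [get_image(**img_args) for img_args in frames],
    text_alias = "{{VIDEO_1}}",   # then place "{{VIDEO_1}}" in the prompt instead of video_embedding.text_alias
)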

exllamav2/architecture.py

+34-1
@@ -312,7 +312,7 @@ class Params:
             })
             self.mmp.mlp_gate = False
             self.mmp.mlp_act_func = "gelu"
-            self.mmp.mlp_bias = True
+            self.mmp.mlp_bias = bool(read_config.get("multimodal_projector_bias", True))

         # Yi

@@ -515,6 +515,28 @@ class Params:
             self.lm.parallel_decoder_blocks = True
             self.lm.requires_bos = True

+        # Cohere 2
+
+        if arch_string == "Cohere2ForCausalLM":
+            arch_recognized = True
+            self.lm.layer_keys += \
+                layer_keys_cohere_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_llama_mlp
+            self.lm.expect_keys += \
+                expect_keys_gemma
+            self.lm.keys.update({
+                "norm_eps": "layer_norm_eps",
+                "lm_head": "model.embed_tokens",
+                "norm_1": ".input_layernorm",
+                "norm_2": None,
+            })
+            self.lm.norm = "layernorm"
+            self.lm.rope_style = RopeStyle.GPTJ
+            self.lm.parallel_decoder_blocks = True
+            self.lm.requires_bos = True
+            self.lm.alternating_swa = True
+
         # DBRX

         if arch_string == "DbrxForCausalLM":
@@ -659,6 +681,17 @@ class Params:
             self.lm.expect_keys += \
                 expect_keys_llama

+        # Granite (v3)
+
+        if arch_string == "GraniteForCausalLM":
+            arch_recognized = True
+            self.lm.layer_keys += \
+                layer_keys_llama_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_llama_mlp
+            self.lm.expect_keys += \
+                expect_keys_llama
+
         # Llama (default + fallback)

         if arch_string != "LlamaForCausalLM" and not arch_recognized:
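The arch_string matched here comes from the model's config.json. A minimal sketch of where that value originates, assuming a standard Hugging Face checkpoint layout (the path is hypothetical; exllamav2's own config loader handles this internally):

import json

# Hypothetical model directory; any Granite 3 checkpoint carries the "architectures" field.
with open("/path/to/granite-3-8b-instruct/config.json") as f:
    arch_string = json.load(f)["architectures"][0]

print(arch_string)   # "GraniteForCausalLM" routes into the new Granite branch above;
                     # "Cohere2ForCausalLM" additionally enables alternating sliding-window attention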

exllamav2/attn.py

+3-1
@@ -211,7 +211,9 @@ def __init__(
         if cfg.use_qk_norm:
             self.submodules += [self.q_norm, self.k_norm]

-        if cfg.query_pre_attn_scalar:
+        if cfg.attention_multiplier:
+            self.scaling = cfg.attention_multiplier
+        elif cfg.query_pre_attn_scalar:
             self.scaling = cfg.query_pre_attn_scalar ** (-0.5)
         else:
             self.scaling = 1 / math.sqrt(self.head_dim)
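The new precedence is: an explicit attention_multiplier wins, then query_pre_attn_scalar (applied as an inverse square root), then the usual 1/sqrt(head_dim). A minimal sketch of that selection in isolation (field names as read in exllamav2/config.py; the numeric values are illustrative only):

import math

def attn_scale(attention_multiplier, query_pre_attn_scalar, head_dim):
    if attention_multiplier:                      # e.g. Granite-style configs
        return attention_multiplier
    elif query_pre_attn_scalar:                   # e.g. Gemma-2-style configs
        return query_pre_attn_scalar ** (-0.5)
    else:                                         # default softmax scaling
        return 1 / math.sqrt(head_dim)

print(attn_scale(0.0078125, None, 128))   # explicit multiplier used as-is
print(attn_scale(None, 256, 128))         # 256 ** -0.5 = 0.0625
print(attn_scale(None, None, 128))        # 1 / sqrt(128)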

exllamav2/config.py

+17-6
@@ -115,6 +115,7 @@ class ExLlamaV2Config:
     final_logit_softcapping: float | None
     attn_logit_softcapping: float | None
     sliding_window: int
+    sliding_window_pattern: int
     norm_head: int | None
     l3_rope_factor: float | None
     l3_rope_low_freq_factor: float | None
@@ -125,6 +126,7 @@ class ExLlamaV2Config:
     checkpoint_fused_mlp: bool
     checkpoint_offset_qzeros: bool
     mrope_section: list | None
+    attention_multiplier: float | None

     vision_model_type: str | None
     vision_head_dim: int | None
@@ -288,6 +290,7 @@ def prepare(self, no_tensors: bool = False):
         self.use_qk_norm = read(read_config, bool, ["use_qk_norm"], False)

         self.query_pre_attn_scalar = read(read_config, float, "query_pre_attn_scalar", None)
+        self.attention_multiplier = read(read_config, float, "attention_multiplier", None)

         # MLP params

@@ -313,11 +316,17 @@ def prepare(self, no_tensors: bool = False):
         dim_model_base = read(read_config, int, "dim_model_base", self.hidden_size)
         self.logit_scale /= (self.hidden_size / dim_model_base)

-        self.scale_emb = read(read_config, float, "scale_emb", 1)
+        logit_scaling = read(read_config, float, "logits_scaling", None)  # Granite is backwards
+        if logit_scaling:
+            self.logit_scale = 1.0 / logit_scaling
+
+        self.scale_emb = read(read_config, float, ["scale_emb", "embedding_multiplier"], 1)
+        residual_multiplier = read(read_config, float, "residual_multiplier", None)
         scale_depth = read(read_config, float, "scale_depth", None)
-        if scale_depth is None:
-            self.scale_depth = 1
-        else:
+        self.scale_depth = 1
+        if residual_multiplier:
+            self.scale_depth = residual_multiplier
+        elif scale_depth:
             self.scale_depth = scale_depth / math.sqrt(self.num_hidden_layers)

         self.attn_logit_softcapping = read(read_config, float, "attn_logit_softcapping", None)
@@ -347,6 +356,7 @@ def prepare(self, no_tensors: bool = False):
         self.original_max_seq_len = self.max_seq_len

         self.sliding_window = read(read_config, int, ["sliding_window", "sliding_window_size"], 0, opt_subkey = "text_config")
+        self.sliding_window_pattern = read(read_config, int, ["sliding_window_pattern"], 1)

         rs = read(read_config, dict, "rope_scaling", None)
         if rs:
@@ -476,13 +486,14 @@ def check_keys(archparams, prefix):
         self.vision_num_attention_heads = read(read_config, int, ["vision_config->num_attention_heads"], no_default)
         self.vision_num_key_value_heads = read(read_config, int, ["vision_config->num_key_value_heads"], self.vision_num_attention_heads)
         self.vision_num_key_value_groups = self.vision_num_attention_heads // self.vision_num_key_value_heads
+        self.multimodal_projector_bias = read(read_config, bool, ["multimodal_projector_bias"], True)

         self.vision_hidden_act = read(read_config, str, ["vision_config->hidden_act"], no_default)
-        self.vision_hidden_size = read(read_config, int, ["vision_config->image_size"], no_default)
+        self.vision_hidden_size = read(read_config, int, ["vision_config->hidden_size"], 1024)
         patch_size = read(read_config, int, ["vision_config->patch_size"], no_default)
         self.vision_rope_theta = read(read_config, int, ["vision_config->rope_theta"], no_default)
         self.vision_feature_layer = read(read_config, int, ["vision_feature_layer"], no_default)
-        self.vision_num_layers = 24
+        self.vision_num_layers = read(read_config, int, ["vision_config->num_hidden_layers"], 24)
         self.vision_intermediate_size = read(read_config, int, ["vision_config->intermediate_size"], self.hidden_size)

         image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
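A minimal sketch of how the Granite-style keys read above fold into the existing scale fields, using an invented config dict (key names are from the hunk; the values are illustrative, not taken from any real checkpoint):

# Invented values, shaped like a Granite 3 style config.json fragment.
read_config = {
    "logits_scaling": 16.0,
    "embedding_multiplier": 12.0,
    "residual_multiplier": 0.22,
}

logit_scale = 1.0 / read_config["logits_scaling"]        # "Granite is backwards": the config stores a divisor
scale_emb = read_config.get("embedding_multiplier", 1)   # also accepted under the older "scale_emb" key
scale_depth = read_config.get("residual_multiplier", 1)  # takes precedence over "scale_depth"

print(logit_scale, scale_emb, scale_depth)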

exllamav2/device.py

+1-8
@@ -123,20 +123,13 @@ def prepare_sincos(self):
             self.cos = self.sin
             return

-        base = cfg.rotary_embedding_base
-        alpha = cfg.scale_alpha_value or 1.0
-        scale = cfg.scale_pos_emb or 1.0
-
-        # Alpha scaling for any rope_scaling type
-
-        if alpha != 1.0: base *= alpha ** (cfg.head_dim / (cfg.head_dim - 2))
-
         # RoPE params

         inv_freq, scaling_factor = rope.get_rope_params(device, cfg)

         # Common

+        scale = cfg.scale_pos_emb or 1.0
         t = torch.arange(cfg.max_seq_len, device = device, dtype = torch.float32)
         if scale != 1.0: t /= scale

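The alpha (NTK) adjustment removed from prepare_sincos previously rescaled the RoPE base in place. A minimal sketch of that formula, kept here for reference (the commit presumably moves this logic into rope.get_rope_params, which now returns inv_freq directly):

def ntk_scaled_base(base, alpha, head_dim):
    # NTK-style alpha scaling: base * alpha ** (d / (d - 2)), as in the removed lines
    if alpha and alpha != 1.0:
        base *= alpha ** (head_dim / (head_dim - 2))
    return base

print(ntk_scaled_base(10000.0, 2.0, 128))   # illustrative values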
