Commit 0ed200f

Merge branch 'turboderp-org:master' into master
2 parents: baaa786 + 1a80d38

17 files changed: +369 -632 lines

.github/workflows/build-wheels-release-linux.yml (-347)
This file was deleted.

.github/workflows/build-wheels-release-rocm62.yml (+11 -141)
Large diff not rendered by default.

.github/workflows/build-wheels-release.yml (+29 -76)
Large diff not rendered by default.

examples/multimodal_grounding_qwen.py (+73 -7)

@@ -51,7 +51,7 @@ class Model:
     current_image: Image or None = None
     current_description: str

-    def __init__(self, model_directory):
+    def __init__(self, model_directory, bbox_mode: str):
         self.model_directory = model_directory
         self.config = None
         self.vision_model = None
@@ -61,17 +61,22 @@ def __init__(self, model_directory):
         self.current_image = None
         self.current_emb = None
         self.current_description = ""
+        bbox_funcs = {
+            "qwen2": self.get_grounding_bb_qwen2,
+            "qwen25": self.get_grounding_bb_qwen25,
+        }
+        self.bbox_func = bbox_funcs[bbox_mode]

     def load(self):
         """Load and initialize the things"""
         self.config = ExLlamaV2Config(self.model_directory)
-        self.config.max_seq_len = 16384
+        self.config.max_seq_len = 8192

         self.vision_model = ExLlamaV2VisionTower(self.config)
         self.vision_model.load(progress = True)

         self.model = ExLlamaV2(self.config)
-        self.cache = ExLlamaV2Cache(self.model, lazy = True, max_seq_len = 16384)
+        self.cache = ExLlamaV2Cache(self.model, lazy = True, max_seq_len = 32768)
         self.model.load_autosplit(self.cache, progress = True)
         self.tokenizer = ExLlamaV2Tokenizer(self.config)

@@ -148,14 +153,21 @@ def inference(self, settext_fn, update_fn):
                 lastupdate = time.time()
                 settext_fn(text)
                 update_fn()
+        #
+        # text = \
+        # """And you may find yourself living in a shotgun shack
+        # And you may find yourself in another part of the world
+        # And you may find yourself behind the wheel of a large automobile
+        # And you may find yourself in a beautiful house, with a beautiful wife
+        # And you may ask yourself, "Well, how did I get here?\""""

         settext_fn(text)
         update_fn()
         self.current_description = text
         print("Image description from model:")
         print(text)

-    def get_grounding_bb(self, start, end) -> tuple:
+    def get_grounding_bb_qwen2(self, start, end) -> tuple:
         """
         Prompt the model again and try to extraxt the bounding box of the image details indicated by selected portion
         of the description. We do this by repeating the exact same prompt up to and including the selected text, but
@@ -209,6 +221,55 @@ def get_grounding_bb(self, start, end) -> tuple:

         return a, b

+    def get_grounding_bb_qwen25(self, start, end) -> tuple:
+        """
+        Qwen2.5 works the same way, except the coordinates are no longer normalized and the format is:
+        "(x0,y0,x1,y1)"
+        """
+
+        if start >= end:
+            return None, None
+
+        # Including leading space
+        if start > 0 and self.current_description[start - 1] == " ":
+            start -= 1
+
+        # Repeat the same prompt up to the selection, with grounding tokens added
+        prompt = self.get_prompt()
+        prompt += self.current_description[:start]
+        prompt += "<|object_ref_start|>"
+        prompt += self.current_description[start:end]
+        prompt += "<|object_ref_end|><|box_start|>("
+
+        bb_string, res = self.generator.generate(
+            prompt = prompt,
+            add_bos = True,
+            max_new_tokens = 28,
+            stop_conditions = [self.tokenizer.single_id("<|box_end|>")],
+            gen_settings = ExLlamaV2Sampler.Settings.greedy(),
+            embeddings = [self.current_emb],
+            completion_only = True,
+            return_last_results = True, # debug purposes
+        )
+        bb_string = "(" + bb_string
+
+        print(f"Generation: {bb_string}")
+        pprint.pprint(res, indent = 4)
+
+        # BB string is in the format "(x0,y0,x1,y1)" with integer coordinates
+
+        s = self.current_image.size
+        try:
+            d = tuple(map(int, bb_string.strip("()").split(",")))
+            a = (d[0] / s[0], d[1] / s[1])
+            b = (d[2] / s[0], d[3] / s[1])
+        except:
+            print("No bounding box could be determined")
+            a, b = None, None
+
+        return a, b
+
+

 class GroundingDemo(QMainWindow):

@@ -472,7 +533,7 @@ def on_selection_made(self, pos):

         print(f"Selected span: {start}, {end}")
         print(f"Selected text: {repr(self.model.current_description[start:end])}")
-        a, b = self.model.get_grounding_bb(start, end)
+        a, b = self.model.bbox_func(start, end)
         self.image_label.set_bounding_box(a, b)

@@ -481,9 +542,14 @@ def on_selection_made(self, pos):
 # https://huggingface.co/turboderp/Qwen2-VL-7B-Instruct-exl2

 def main():
-    model_dir = "/mnt/str/models/qwen2-vl-7b-instruct-exl2/6.0bpw"
+
+    # model_dir = "/mnt/str/models/qwen2-vl-7b-instruct-exl2/6.0bpw"
+    # bbox_mode = "qwen25"
+    model_dir = "/mnt/str/models/qwen2.5-vl-7b-instruct-exl2/6.0bpw"
+    bbox_mode = "qwen25"
+
     app = QApplication(sys.argv)
-    model = Model(model_dir)
+    model = Model(model_dir, bbox_mode)
     model.load()
     window = GroundingDemo(model, model_dir)
     window.show()
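
The example now selects its grounding parser at startup via bbox_mode: Qwen2.5-VL returns absolute pixel coordinates in the "(x0,y0,x1,y1)" format handled by get_grounding_bb_qwen25 above, whereas Qwen2-VL returns normalized coordinates. A minimal sketch, not part of the commit, of how both formats reduce to the fractional corners that set_bounding_box() consumes; parse_bbox is a hypothetical helper, and the "qwen2" branch assumes Qwen2-VL's "(x0,y0),(x1,y1)" output on a 0-1000 grid.

# Minimal sketch (not commit code), assuming the Qwen2-VL 0-1000 grid convention.
def parse_bbox(bb_string: str, bbox_mode: str, image_size: tuple) -> tuple:
    w, h = image_size
    try:
        nums = [int(n) for n in bb_string.replace("(", " ").replace(")", " ").replace(",", " ").split()]
        x0, y0, x1, y1 = nums[:4]
        if bbox_mode == "qwen2":
            return (x0 / 1000, y0 / 1000), (x1 / 1000, y1 / 1000)   # normalized grid (assumed)
        return (x0 / w, y0 / h), (x1 / w, y1 / h)                   # "qwen25": absolute pixels
    except (ValueError, IndexError):
        return None, None

# e.g. parse_bbox("(112,264,528,720)", "qwen25", (1024, 768)) -> approx ((0.11, 0.34), (0.52, 0.94))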

exllamav2/architecture.py (+33 -16)

@@ -356,7 +356,7 @@ class Params:

         # Qwen2-VL (2, 2.5)

-        if arch_string == "Qwen2VLForConditionalGeneration":
+        if arch_string in ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]:
             arch_recognized = True
             self.lm.layer_keys += \
                 layer_keys_llama_norms + \
@@ -368,27 +368,44 @@ class Params:
             self.lm.mrope = True
             self.lm.rope_freq_half = True

-            read_config["vision_config"].update({"model_type": "qwen2"})
             self.vt_prefix = "visual."
-            self.vt.keys.update({
-                "fused_qkv": ".attn.qkv",
-                "attn_o": ".attn.proj",
-                "mlp_gate": None,
-                "mlp_up": ".mlp.fc1",
-                "mlp_down": ".mlp.fc2",
-                "norm_1": ".norm1",
-                "norm_2": ".norm2",
-                "layers": "blocks",
-                "patch_conv": "patch_embed.proj",
-            })
-            self.vt.mlp_gate = False
+            if arch_string == "Qwen2VLForConditionalGeneration":
+                read_config["vision_config"].update({"model_type": "qwen2"})
+                self.vt.keys.update({
+                    "fused_qkv": ".attn.qkv",
+                    "attn_o": ".attn.proj",
+                    "mlp_gate": None,
+                    "mlp_up": ".mlp.fc1",
+                    "mlp_down": ".mlp.fc2",
+                    "norm_1": ".norm1",
+                    "norm_2": ".norm2",
+                    "layers": "blocks",
+                    "patch_conv": "patch_embed.proj",
+                })
+                self.vt.mlp_gate = False
+                self.vt.mlp_act_func = "quickgelu"
+                self.vt.norm = "layernorm"
+            elif arch_string == "Qwen2_5_VLForConditionalGeneration":
+                read_config["vision_config"].update({"model_type": "qwen2.5"})
+                self.vt.keys.update({
+                    "fused_qkv": ".attn.qkv",
+                    "attn_o": ".attn.proj",
+                    "mlp_gate": ".mlp.gate_proj",
+                    "mlp_up": ".mlp.up_proj",
+                    "mlp_down": ".mlp.down_proj",
+                    "norm_1": ".norm1",
+                    "norm_2": ".norm2",
+                    "layers": "blocks",
+                    "patch_conv": "patch_embed.proj",
+                })
+                self.vt.mlp_gate = True
+                self.vt.mlp_act_func = "silu"
+                self.vt.norm = "rmsnorm"
             self.vt.mlp_bias = True
             self.vt.attention_bias_qkv = True
             self.vt.attention_bias_o = True
             self.vt.vision_input_norm = False
             self.vt.vision_conv3d = True
-            self.vt.mlp_act_func = "quickgelu"
-            self.vt.norm = "layernorm"

             self.mmp_prefix = "visual.merger."
             self.mmp.keys.update({
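
The split above captures how the two vision towers differ: Qwen2-VL keeps a plain fc1/fc2 MLP with QuickGELU and LayerNorm, while Qwen2.5-VL switches to a gated SiLU MLP (gate_proj/up_proj/down_proj) with RMSNorm. A rough PyTorch sketch of the two MLP shapes for orientation only; module and dimension names here are illustrative, not exllamav2 internals.

import torch
import torch.nn as nn

class QuickGeluMLP(nn.Module):      # "qwen2" keys: mlp_gate = False, fc1/fc2, quickgelu
    def __init__(self, dim, hidden):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden, bias = True)
        self.fc2 = nn.Linear(hidden, dim, bias = True)
    def forward(self, x):
        h = self.fc1(x)
        return self.fc2(h * torch.sigmoid(1.702 * h))   # QuickGELU: x * sigmoid(1.702 * x)

class GatedSiluMLP(nn.Module):      # "qwen25" keys: mlp_gate = True, gate/up/down, silu
    def __init__(self, dim, hidden):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden, bias = True)
        self.up_proj = nn.Linear(dim, hidden, bias = True)
        self.down_proj = nn.Linear(hidden, dim, bias = True)
    def forward(self, x):
        return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))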

exllamav2/attn.py (+32 -12)

@@ -41,11 +41,11 @@
         print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.")

     if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]:
-        from flash_attn import flash_attn_func
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
         has_flash_attn = True

     if [2, 5, 7] <= flash_attn_ver:
-        from flash_attn import flash_attn_func, flash_attn_with_kvcache
+        from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache
         # import flash_attn_2_cuda as flash_attn_cuda

         signature = list(inspect.signature(flash_attn_func).parameters)
@@ -882,7 +882,9 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
             k_states = k_states[:, :, -self.sliding_window:, :]
             v_states = v_states[:, :, -self.sliding_window:, :]

-        if attn_params.is_causal():
+        if self.layer_idx in attn_params.block_diag_layers:
+            attn_mask_lr = attn_params.get_block_diag_mask(q_states.device)
+        elif attn_params.is_causal():
             attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
         else:
             attn_mask_lr = attn_params.get_attn_mask(q_states.device)
@@ -904,7 +906,9 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
             attn_weights = torch.matmul(q_states, k_states)

             attn_weights *= self.scaling
-            if causal:
+            if self.layer_idx in attn_params.block_diag_layers:
+                attn_mask = attn_params.get_block_diag_mask(attn_weights.device)
+            elif causal:
                 attn_mask = attn_params.get_attn_mask(attn_weights.device)

             if cfg.attn_logit_softcapping:
@@ -939,14 +943,30 @@ def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_para
         if has_flash_attn_with_softcap:
             flash_kwargs["softcap"] = cfg.attn_logit_softcapping

-        attn_output = flash_attn_func(
-            q_states,
-            k_states,
-            v_states,
-            causal = causal,
-            softmax_scale = self.scaling,
-            **flash_kwargs
-        )
+        if self.layer_idx in attn_params.block_diag_layers:
+            q_states = q_states.flatten(start_dim = 0, end_dim = 1)
+            k_states = k_states.flatten(start_dim = 0, end_dim = 1)
+            v_states = v_states.flatten(start_dim = 0, end_dim = 1)
+            max_seqlen = attn_params.get_cu_seqlens_max()
+            cu_seqlens = attn_params.get_cu_seqlens(self.device_idx)
+            attn_output = flash_attn_varlen_func(
+                q_states,
+                k_states,
+                v_states,
+                cu_seqlens,
+                cu_seqlens,
+                max_seqlen,
+                max_seqlen
+            )
+        else:
+            attn_output = flash_attn_func(
+                q_states,
+                k_states,
+                v_states,
+                causal = causal,
+                softmax_scale = self.scaling,
+                **flash_kwargs
+            )
         attn_output = attn_output.reshape((batch_size, q_len, self.num_attention_heads * self.head_dim))
         return attn_output
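
For layers listed in block_diag_layers, the flash path packs the batch into one long token sequence and lets flash_attn_varlen_func confine attention to each segment via cumulative sequence boundaries. A small sketch, not commit code, of how those boundaries relate to the call; the segment lengths are made up for illustration.

import torch

# Hypothetical lengths of three packed segments
seq_lens = torch.tensor([3, 5, 4])
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim = 0)              # -> [0, 3, 8, 12]
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())    # same reduction as get_cu_seqlens_max()

# q/k/v are flattened from (batch, seq, heads, head_dim) to (total_tokens, heads, head_dim)
# and passed together with the boundaries, mirroring the call above:
# flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)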

exllamav2/attn_params.py (+34)

@@ -21,6 +21,10 @@ class Params:
     alt_rope_embed_dict: dict | None
     rope_offsets: torch.Tensor | None
     non_causal_attn: bool
+    block_diag_layers: set
+    block_diag_mask: torch.Tensor | None
+    cu_seqlens: torch.Tensor | None
+    cu_seqlens_max: int | None

     def __init__(
         self,
@@ -66,6 +70,11 @@ def __init__(
         self.past_len_tp = None
         self.paged = paged

+        self.block_diag_layers = set()
+        self.block_diag_mask = None
+        self.cu_seqlens = None
+        self.cu_seqlens_max = None
+
     def is_causal(self) -> bool:
         return self.input_mask is None

@@ -164,6 +173,31 @@ def get_rope_offsets(self, device_idx: int) -> torch.Tensor | None:
             self.rope_offsets = safe_move_tensor(self.rope_offsets, device_idx, non_blocking = True)
         return self.rope_offsets

+    def get_cu_seqlens(self, device: int) -> torch.Tensor | None:
+        if self.cu_seqlens is None:
+            return None
+        if self.cu_seqlens.device.index != device:
+            self.cu_seqlens = safe_move_tensor(self.cu_seqlens, device, non_blocking = True)
+        return self.cu_seqlens
+
+    def get_cu_seqlens_max(self) -> torch.Tensor | None:
+        assert self.cu_seqlens is not None
+        if self.cu_seqlens_max is not None:
+            return self.cu_seqlens_max
+        self.cu_seqlens_max = (self.cu_seqlens[1:] - self.cu_seqlens[:-1]).max().item()
+        return self.cu_seqlens_max
+
+    def get_block_diag_mask(self, device: int) -> torch.Tensor | None:
+        if self.block_diag_mask is None:
+            csl = self.get_cu_seqlens(device)
+            if csl is None:
+                return None
+            positions = torch.arange(csl[-1], device = csl.device)
+            labels = torch.searchsorted(csl[1:], positions, right = True)
+            self.block_diag_mask = labels.unsqueeze(0) == labels.unsqueeze(1).repeat(self.batch_size)
+        if self.block_diag_mask.device.index != device:
+            self.block_diag_mask = safe_move_tensor(self.block_diag_mask, device, non_blocking = True)
+        return self.block_diag_mask


 class PagedParams(Params):
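
get_block_diag_mask() builds the equivalent mask for the non-flash paths: torch.searchsorted labels every token position with the segment it falls in, and comparing labels pairwise yields a block-diagonal boolean mask that allows attention only within a segment. A worked example with hypothetical boundaries (the method above additionally repeats the labels across the batch dimension).

import torch

csl = torch.tensor([0, 3, 8, 12])                   # cu_seqlens for segments of 3, 5 and 4 tokens
positions = torch.arange(csl[-1])
labels = torch.searchsorted(csl[1:], positions, right = True)
# labels -> tensor([0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2])
mask = labels.unsqueeze(0) == labels.unsqueeze(1)   # (12, 12), True only inside each block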
