Skip to content

Commit 46eff43

Browse files
committed
Merge branch 'refs/heads/dev'
2 parents 1e18e80 + 228ba34 commit 46eff43

File tree

4 files changed

+33
-8
lines changed

4 files changed

+33
-8
lines changed

exllamav2/attn.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
840840

841841
# SDPA
842842

843-
if has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa and not cfg.attn_logit_softcapping:
843+
if has_lower_right_sdpa and not cfg.no_sdpa and not cfg.attn_logit_softcapping:
844844

845845
k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
846846
v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
@@ -849,7 +849,10 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
849849
k_states = k_states[:, :, -self.sliding_window:, :]
850850
v_states = v_states[:, :, -self.sliding_window:, :]
851851

852-
attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
852+
if attn_params.is_causal():
853+
attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
854+
else:
855+
attn_mask_lr = attn_params.get_attn_mask(q_states.device)
853856
attn_output = F.scaled_dot_product_attention(
854857
q_states,
855858
k_states,

exllamav2/generator/sampler.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -400,10 +400,17 @@ def prep_logit_filter(lf):
400400

401401
pass_tokens = None
402402
end_tokens = None
403-
for f in filters:
404403

404+
pts = []
405+
ets = []
406+
for f in filters:
405407
pt, et = f.get_next()
406-
if len(filters) > 1 and not isinstance(pt, set):
408+
if pt is not None:
409+
pts.append(pt)
410+
ets.append(et)
411+
412+
for pt, et in zip(pts, ets):
413+
if len(pts) > 1 and not isinstance(pt, set):
407414
pt, et = set(pt), set(et)
408415

409416
if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt
@@ -425,7 +432,7 @@ def prep_logit_filter(lf):
425432
if filter_prefer_eos and tokenizer.eos_token_id in pass_tokens:
426433
pass_tokens_list = [tokenizer.eos_token_id]
427434
logit_filter = prep_logit_filter(logit_filter)
428-
ext_c.logit_filter_exclusive(logit_filter, pass_tokens_list)
435+
ext_c.logit_filter_exclusive(logit_filter, [pass_tokens_list])
429436
else:
430437
logit_filter = prep_logit_filter(logit_filter)
431438
if isinstance(pass_tokens, set):

exllamav2/tokenizer/tokenizer.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@
88
ExLlamaV2TokenizerSPM,
99
ExLlamaV2TokenizerHF
1010
)
11+
import threading
12+
13+
14+
lock = threading.RLock()
15+
def synchronized_init(func):
16+
def wrapper(*args, **kwargs):
17+
with lock:
18+
return func(*args, **kwargs)
19+
return wrapper
20+
1121

1222
class ExLlamaV2Tokenizer:
1323

@@ -20,7 +30,6 @@ def __init__(self, children = None, leaf = None):
2030
self.children = children if children is not None else {}
2131
self.leaf = leaf if leaf is not None else []
2232

23-
2433
config: ExLlamaV2Config
2534

2635
tokenizer_model: ExLlamaV2TokenizerBase
@@ -567,8 +576,8 @@ def num_tokens(self, text):
567576

568577
# Get ordinals of single-byte tokens
569578

579+
@synchronized_init
570580
def get_id_to_ord_list(self):
571-
572581
if self.id_to_ord is not None: return self.id_to_ord
573582

574583
self.id_to_ord = []
@@ -594,6 +603,7 @@ def get_id_to_ord_list(self):
594603

595604
# Copy vocabulary from model
596605

606+
@synchronized_init
597607
def get_id_to_piece_list(self, include_special_tokens = False):
598608

599609
if include_special_tokens:
@@ -633,6 +643,7 @@ def get_id_to_piece_list(self, include_special_tokens = False):
633643
return self.id_to_piece
634644

635645

646+
@synchronized_init
636647
def get_piece_to_id_dict(self):
637648

638649
if self.piece_to_id is not None: return self.piece_to_id
@@ -644,6 +655,7 @@ def get_piece_to_id_dict(self):
644655

645656
# Create dictionary mapping prefixes to token IDs
646657

658+
@synchronized_init
647659
def get_prefix_to_ids_dict(self):
648660

649661
if self.prefix_to_ids is not None: return self.prefix_to_ids
@@ -671,6 +683,7 @@ def get_prefix_to_ids_dict(self):
671683

672684
# Create dictionary mapping each ID to any IDs that it prefixes
673685

686+
@synchronized_init
674687
def get_prefix_id_to_ids_dict(self):
675688

676689
if self.prefix_id_to_ids is not None: return self.prefix_id_to_ids
@@ -712,6 +725,7 @@ def _make_trie(self, ci):
712725
return trie
713726

714727

728+
@synchronized_init
715729
def get_char_trie(self):
716730

717731
if self.char_trie is not None: return self.char_trie
@@ -720,6 +734,7 @@ def get_char_trie(self):
720734
return self.char_trie
721735

722736

737+
@synchronized_init
723738
def get_char_trie_ci(self):
724739

725740
if self.char_trie_ci is not None: return self.char_trie_ci

exllamav2/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.2.1"
1+
__version__ = "0.2.2"

0 commit comments

Comments
 (0)