import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter
from functools import lru_cache
from lmformatenforcer.integrations.exllamav2 import build_token_enforcer_tokenizer_data
from lmformatenforcer import TokenEnforcer, CharacterLevelParser
from typing import List

# Temporary wrapper for lm-format-enforcer, until the integration in LMFE itself is updated

# Building the token enforcer's tokenizer data is relatively expensive, so cache it
# per tokenizer instance instead of rebuilding it for every filter
@lru_cache(10)
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
    return build_token_enforcer_tokenizer_data(tokenizer)

class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
    """ExLlamaV2 filter that constrains sampling to the tokens permitted by an
    lm-format-enforcer CharacterLevelParser (e.g. a JSON schema parser)."""

    token_sequence: List[int]

    def __init__(
        self,
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
        character_level_parser: CharacterLevelParser,
    ):
        super().__init__(model, tokenizer)
        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
        self.token_sequence = []

    def begin(self, prefix_str: str) -> None:
        # Reset the tracked sequence at the start of each generation
        self.token_sequence = []

    def feed(self, token) -> None:
        # The sampled token arrives as a (1, 1) tensor; record its scalar ID
        self.token_sequence.append(int(token[0][0]))

    def next(self):
        # Ask the enforcer which tokens may legally follow the sequence so far;
        # the second return value is empty since this filter never forces an end
        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
        return sorted(allowed_tokens), []

    def use_background_worker(self):
        # Allow the generator to evaluate this filter on a background worker thread
        return True
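

# Illustrative usage sketch (not part of the original file): constrain generated
# output to a JSON schema via lmformatenforcer's JsonSchemaParser. The model
# directory is a placeholder, and the dynamic generator's `filters` argument is
# assumed to match the current exllamav2 API.

if __name__ == "__main__":

    from exllamav2 import ExLlamaV2Config, ExLlamaV2Cache
    from exllamav2.generator import ExLlamaV2DynamicGenerator
    from lmformatenforcer import JsonSchemaParser

    config = ExLlamaV2Config("/path/to/model")  # placeholder path
    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy = True)
    model.load_autosplit(cache)
    tokenizer = ExLlamaV2Tokenizer(config)
    generator = ExLlamaV2DynamicGenerator(model = model, cache = cache, tokenizer = tokenizer)

    # Only allow output matching this (example) schema
    schema = {"type": "object", "properties": {"answer": {"type": "string"}}}
    json_filter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, JsonSchemaParser(schema))

    output = generator.generate(
        prompt = "Answer in JSON: What is the capital of France?\n",
        max_new_tokens = 100,
        filters = [json_filter],
    )
    print(output)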