Skip to content

Commit 8c16102

Browse files
committed
Update RegexGuide to conform with outlines-core
1 parent d7569ef commit 8c16102

File tree

7 files changed

+28
-24
lines changed

7 files changed

+28
-24
lines changed

benchmarks/bench_json_schema.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -77,4 +77,4 @@ def time_json_schema_to_regex(self, schema_name):
7777
@cache_disabled()
7878
def time_json_schema_to_fsm(self, schema_name):
7979
regex = build_regex_from_schema(self.schema)
80-
RegexGuide(regex, self.tokenizer)
80+
RegexGuide.from_regex(regex, self.tokenizer)

benchmarks/bench_regex_guide.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ def setup(self, pattern_name):
2525

2626
@cache_disabled()
2727
def time_regex_to_guide(self, pattern_name):
28-
RegexGuide(self.pattern, self.tokenizer)
28+
RegexGuide.from_regex(self.pattern, self.tokenizer)
2929

3030

3131
class MemoryRegexGuideBenchmark:
@@ -37,4 +37,4 @@ def setup(self, pattern_name):
3737

3838
@cache_disabled()
3939
def peakmem_regex_to_guide(self, pattern_name):
40-
RegexGuide(self.pattern, self.tokenizer)
40+
RegexGuide.from_regex(self.pattern, self.tokenizer)

outlines/fsm/guide.py

+15-11
Original file line number | Diff line number | Diff line change
@@ -74,8 +74,8 @@ def copy(self):
7474

7575

7676
@cache()
77-
def create_states_mapping(regex_string, tokenizer):
78-
return uncached_create_states_mapping(regex_string, tokenizer)
77+
def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs):
78+
return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs)
7979

8080

8181
class RegexGuide(CoreRegexGuide):
@@ -84,15 +84,19 @@ class RegexGuide(CoreRegexGuide):
8484
CoreRegexGuide with outlines cache
8585
"""
8686

87-
def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
88-
(
89-
self.states_to_token_maps,
90-
self.empty_token_ids,
91-
fsm_finals,
92-
) = create_states_mapping(regex_string, tokenizer)
93-
self.eos_token_id = tokenizer.eos_token_id
94-
self.final_states = fsm_finals | {-1}
95-
self._cache_state_to_token_tensor()
87+
@classmethod
88+
def from_regex(
89+
cls,
90+
regex_string: str,
91+
tokenizer,
92+
**kwargs,
93+
):
94+
return super().from_regex(
95+
regex_string,
96+
tokenizer,
97+
_create_states_mapping=cached_create_states_mapping,
98+
**kwargs,
99+
)
96100

97101

98102
CFGState = collections.namedtuple("CFGState", ["parser_state", "prev_token"])

outlines/processors/structured.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,7 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
149149
tokenizer
150150
An Outlines tokenizer
151151
"""
152-
guide = RegexGuide(regex_string, tokenizer)
152+
guide = RegexGuide.from_regex(regex_string, tokenizer)
153153
super().__init__(tokenizer=tokenizer, guide=guide)
154154

155155

tests/fsm/test_guide.py

+5-5
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def convert_token_to_string(self, token):
4343
regex_str = "[1-9]"
4444

4545
with pytest.raises(ValueError, match="The vocabulary"):
46-
RegexGuide(regex_str, MockTokenizer())
46+
RegexGuide.from_regex(regex_str, MockTokenizer())
4747

4848

4949
def test_regex():
@@ -57,7 +57,7 @@ def convert_token_to_string(self, token):
5757

5858
regex_str = "[1-9]"
5959
tokenizer = MockTokenizer()
60-
fsm = RegexGuide(regex_str, tokenizer)
60+
fsm = RegexGuide.from_regex(regex_str, tokenizer)
6161

6262
assert fsm.states_to_token_maps == {0: {1: 1}}
6363

@@ -98,7 +98,7 @@ def convert_token_to_string(self, token):
9898

9999
regex_str = "[😁-😎]"
100100
tokenizer = MockTokenizer()
101-
fsm = RegexGuide(regex_str, tokenizer)
101+
fsm = RegexGuide.from_regex(regex_str, tokenizer)
102102

103103
assert fsm.states_to_token_maps == {
104104
0: {5: 1, 4: 2},
@@ -145,7 +145,7 @@ def convert_token_to_string(self, token):
145145

146146
regex_str = " [😁-😎]"
147147
tokenizer = MockTokenizer()
148-
fsm = RegexGuide(regex_str, tokenizer)
148+
fsm = RegexGuide.from_regex(regex_str, tokenizer)
149149

150150
assert fsm.states_to_token_maps == {
151151
0: {5: 1, 10: 2},
@@ -180,7 +180,7 @@ def convert_token_to_string(self, token):
180180

181181
regex_str = r"`\n(\.\n)?`\n"
182182
tokenizer = MockTokenizer()
183-
fsm = RegexGuide(regex_str, tokenizer)
183+
fsm = RegexGuide.from_regex(regex_str, tokenizer)
184184

185185
state = fsm.get_next_state(state=4, token_id=103)
186186
assert state == 5

tests/generate/test_integration_llamacpp.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -278,7 +278,7 @@ def test_RegexGuide_caching(model, temp_cache_dir):
278278
import llama_cpp
279279

280280
import outlines.caching
281-
from outlines.fsm.guide import create_states_mapping
281+
from outlines.fsm.guide import cached_create_states_mapping
282282

283283
assert outlines.caching._caching_enabled
284284

@@ -291,7 +291,7 @@ def test_RegexGuide_caching(model, temp_cache_dir):
291291
_ = cache.stats(enable=True)
292292
assert cache.statistics
293293

294-
assert create_states_mapping.__memory__ is cache
294+
assert cached_create_states_mapping.__memory__ is cache
295295

296296
generator = generate.regex(model, regex, sampler=samplers.greedy())
297297
assert cache.stats() == (0, 1)

tests/generate/test_integration_transformers.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -494,7 +494,7 @@ def test_transformers_use_existing_model_and_tokenizer():
494494

495495
def test_RegexGuide_caching(temp_cache_dir):
496496
import outlines.caching
497-
from outlines.fsm.guide import create_states_mapping
497+
from outlines.fsm.guide import cached_create_states_mapping
498498

499499
assert outlines.caching._caching_enabled
500500

@@ -507,7 +507,7 @@ def test_RegexGuide_caching(temp_cache_dir):
507507
_ = cache.stats(enable=True)
508508
assert cache.statistics
509509

510-
assert create_states_mapping.__memory__ is cache
510+
assert cached_create_states_mapping.__memory__ is cache
511511

512512
model = models.transformers(
513513
"hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM", device="cpu"

0 commit comments

Comments (0)