
Commit dac0d76

committed Oct 10, 2024
update RegexGuide to conform with outlines-core
1 parent: d7569ef

5 files changed: +25 −20 lines
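In short, every call site moves from direct construction to the new classmethod. A minimal before/after fragment, using only names that appear in the diffs below:

    # Before: RegexGuide built the FSM in __init__
    guide = RegexGuide(regex_string, tokenizer)

    # After: construction goes through the classmethod defined by
    # outlines-core's CoreRegexGuide
    guide = RegexGuide.from_regex(regex_string, tokenizer)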
 

Diff for: benchmarks/bench_json_schema.py (+1 −1)

@@ -77,4 +77,4 @@ def time_json_schema_to_regex(self, schema_name):
     @cache_disabled()
     def time_json_schema_to_fsm(self, schema_name):
         regex = build_regex_from_schema(self.schema)
-        RegexGuide(regex, self.tokenizer)
+        RegexGuide.from_regex(regex, self.tokenizer)
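For context, this benchmark exercises a two-step flow: convert a JSON schema into a regular expression, then compile that regex into a guide. The same steps in isolation (the import paths are assumptions inferred from the benchmark's usage; only the call sites appear in the diff above):

    # Assumed import locations, not shown in this diff.
    from outlines.fsm.json_schema import build_regex_from_schema
    from outlines.fsm.guide import RegexGuide

    regex = build_regex_from_schema(schema)          # JSON schema -> regex string
    guide = RegexGuide.from_regex(regex, tokenizer)  # regex -> token-level FSM guide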

Diff for: benchmarks/bench_regex_guide.py (+2 −2)

@@ -25,7 +25,7 @@ def setup(self, pattern_name):
 
     @cache_disabled()
     def time_regex_to_guide(self, pattern_name):
-        RegexGuide(self.pattern, self.tokenizer)
+        RegexGuide.from_regex(self.pattern, self.tokenizer)
 
 
 class MemoryRegexGuideBenchmark:
@@ -37,4 +37,4 @@ def setup(self, pattern_name):
 
     @cache_disabled()
     def peakmem_regex_to_guide(self, pattern_name):
-        RegexGuide(self.pattern, self.tokenizer)
+        RegexGuide.from_regex(self.pattern, self.tokenizer)

Diff for: outlines/fsm/guide.py (+16 −11)

@@ -74,8 +74,8 @@ def copy(self):
 
 
 @cache()
-def create_states_mapping(regex_string, tokenizer):
-    return uncached_create_states_mapping(regex_string, tokenizer)
+def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs):
+    return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs)
 
 
 class RegexGuide(CoreRegexGuide):
@@ -84,15 +84,20 @@ class RegexGuide(CoreRegexGuide):
     CoreRegexGuide with outlines cache
     """
 
-    def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
-        (
-            self.states_to_token_maps,
-            self.empty_token_ids,
-            fsm_finals,
-        ) = create_states_mapping(regex_string, tokenizer)
-        self.eos_token_id = tokenizer.eos_token_id
-        self.final_states = fsm_finals | {-1}
-        self._cache_state_to_token_tensor()
+    @classmethod
+    def from_regex(
+        cls,
+        regex_string: str,
+        tokenizer,
+        _create_states_mapping=cached_create_states_mapping,
+        **kwargs,
+    ):
+        return super().from_regex(
+            regex_string,
+            tokenizer,
+            _create_states_mapping=_create_states_mapping,
+            **kwargs,
+        )
 
 
 CFGState = collections.namedtuple("CFGState", ["parser_state", "prev_token"])
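The design point of this change: rather than overriding __init__ and unpacking the states mapping itself, the subclass now injects outlines' cached factory into the core classmethod via the _create_states_mapping keyword, so caching stays a pluggable detail layered on top of outlines-core. A rough sketch of how the injected factory is consumed, with a toy stand-in for CoreRegexGuide (the real class lives in outlines-core and does more; the tuple shape mirrors what the old __init__ unpacked above):

    # Toy stand-in, NOT the real outlines-core class.
    class ToyCoreRegexGuide:
        @classmethod
        def from_regex(cls, regex_string, tokenizer, _create_states_mapping=None, **kwargs):
            guide = cls.__new__(cls)
            # Delegate the expensive FSM construction to whichever factory
            # the caller supplied -- here, outlines' cached one.
            (
                guide.states_to_token_maps,
                guide.empty_token_ids,
                fsm_finals,
            ) = _create_states_mapping(regex_string, tokenizer)
            guide.eos_token_id = tokenizer.eos_token_id
            guide.final_states = fsm_finals | {-1}
            return guide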

Diff for: outlines/processors/structured.py (+1 −1)

@@ -149,7 +149,7 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
         tokenizer
             An Outlines tokenizer
         """
-        guide = RegexGuide(regex_string, tokenizer)
+        guide = RegexGuide.from_regex(regex_string, tokenizer)
         super().__init__(tokenizer=tokenizer, guide=guide)
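For downstream callers nothing changes here: the processor's constructor signature is untouched, and only its internal guide construction moved to the classmethod. A hedged usage sketch (assuming the class in this module is RegexLogitsProcessor and that an Outlines-adapted tokenizer is at hand):

    from outlines.processors.structured import RegexLogitsProcessor

    # Same public API as before this commit; the from_regex call is internal.
    processor = RegexLogitsProcessor(r"[1-9]", tokenizer)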

Diff for: tests/fsm/test_guide.py (+5 −5)

@@ -43,7 +43,7 @@ def convert_token_to_string(self, token):
     regex_str = "[1-9]"
 
     with pytest.raises(ValueError, match="The vocabulary"):
-        RegexGuide(regex_str, MockTokenizer())
+        RegexGuide.from_regex(regex_str, MockTokenizer())
 
 
 def test_regex():
@@ -57,7 +57,7 @@ def convert_token_to_string(self, token):
 
     regex_str = "[1-9]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {0: {1: 1}}
 
@@ -98,7 +98,7 @@ def convert_token_to_string(self, token):
 
     regex_str = "[😁-😎]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {
         0: {5: 1, 4: 2},
@@ -145,7 +145,7 @@ def convert_token_to_string(self, token):
 
     regex_str = " [😁-😎]"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {
         0: {5: 1, 10: 2},
@@ -180,7 +180,7 @@ def convert_token_to_string(self, token):
 
     regex_str = r"`\n(\.\n)?`\n"
     tokenizer = MockTokenizer()
-    fsm = RegexGuide(regex_str, tokenizer)
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
     state = fsm.get_next_state(state=4, token_id=103)
     assert state == 5
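The tests above also double as a usage reference for the new entry point. In miniature (the tests' MockTokenizer stands in for a real vocabulary, and the exact state and token ids depend on it):

    guide = RegexGuide.from_regex("[1-9]", tokenizer)

    # Mapping of FSM state -> {token_id: next_state}, e.g. {0: {1: 1}}
    print(guide.states_to_token_maps)

    # Step the FSM one token at a time during generation
    next_state = guide.get_next_state(state=0, token_id=1)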
