Skip to content

Commit 6b825a8

Browse files
committedOct 14, 2024
test fsm_union and walk_fsm
1 parent 8c16102 commit 6b825a8

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed
 

Diff for: ‎tests/fsm/test_parsing.py

+101
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,104 @@ def test_sequential_parse_example(cleanup_lark_import):
204204

205205
if i + 1 == len(input_tokens):
206206
assert all(tk in next_vocab for tk in ["\n", "\nde", " ", " + 1"])
207+
208+
209+
# TODO: Remove once fsm_union and walk_fsm are implemented in Outlines-Core
210+
import interegular # noqa
211+
212+
from outlines.fsm.parsing import fsm_union, walk_fsm # noqa
213+
214+
215+
def test_outlines_interegular_union_consistency():
216+
fsm0 = interegular.parse_pattern(r"abc").to_fsm()
217+
fsm1 = interegular.parse_pattern(r"WXYZ").to_fsm()
218+
fsm2 = interegular.parse_pattern(r"12345").to_fsm()
219+
220+
interegular_unioned_fsm = fsm0 | fsm1 | fsm2
221+
outlines_unioned_fsm, _ = fsm_union([fsm0, fsm1, fsm2])
222+
223+
assert list(outlines_unioned_fsm.strings()) == list(
224+
interegular_unioned_fsm.strings()
225+
)
226+
227+
228+
def _reconstruct_fsms(fsm, fsms_to_trans_finals):
229+
"""Reconstruct the original fsms for testing purposes"""
230+
reconstructed_fsms = []
231+
for transitions, finals, state_map in fsms_to_trans_finals.values():
232+
inv_state_map = {new: orig for orig, news in state_map.items() for new in news}
233+
states = set(inv_state_map.values())
234+
initial = inv_state_map.get(fsm.initial) or next(
235+
(orig for orig, news in state_map.items() if fsm.initial in news), None
236+
)
237+
finals = {inv_state_map[s] for s in finals}
238+
239+
transition_map = {}
240+
alphabet = {}
241+
for trans_id, (from_state, to_state) in enumerate(transitions):
242+
orig_from, orig_to = inv_state_map[from_state], inv_state_map[to_state]
243+
# Collect symbols associated with the transition
244+
symbols = {
245+
symbol
246+
for trans, dest in fsm.map.get(from_state, {}).items()
247+
if dest == to_state
248+
for symbol in fsm.alphabet.by_transition.get(trans, [])
249+
}
250+
if symbols:
251+
# NOTE: THIS RECONSTRUCTOR DOESNT WORK FOR MORE THAN ONE TRANSITION PER SYMBOL
252+
assert len(symbols) == 1
253+
symbol = list(symbols)[0]
254+
alphabet[symbol] = trans_id
255+
transition_map.setdefault(orig_from, {})[trans_id] = orig_to
256+
257+
reconstructed_fsms.append(
258+
interegular.fsm.FSM(
259+
alphabet=interegular.fsm.Alphabet(alphabet),
260+
states=frozenset(states),
261+
initial=initial,
262+
finals=frozenset(finals),
263+
map=transition_map,
264+
__no_validation__=True,
265+
)
266+
)
267+
return reconstructed_fsms
268+
269+
270+
def test_fsm_to_trans_finals_reconstruction():
271+
"""Assert that _fsms_to_trans_finals is correct by reconstructing original fsms"""
272+
fsm0 = interegular.parse_pattern(r"abc").to_fsm()
273+
fsm1 = interegular.parse_pattern(r"XYZ").to_fsm()
274+
fsm2 = interegular.parse_pattern(r"12345").to_fsm()
275+
276+
fsm, _fsms_to_trans_finals = fsm_union([fsm0, fsm1, fsm2])
277+
278+
reconstructed = _reconstruct_fsms(fsm, _fsms_to_trans_finals)
279+
280+
# assert reconstruction equivalent
281+
assert list(fsm0.strings()) == list(reconstructed[0].strings())
282+
assert list(fsm1.strings()) == list(reconstructed[1].strings())
283+
assert list(fsm2.strings()) == list(reconstructed[2].strings())
284+
285+
286+
def test_walk_fsm():
287+
fsm = interegular.parse_pattern(r"abc*d").to_fsm()
288+
# convert to BetterFSM
289+
fsm = fsm_union([fsm])[0]
290+
291+
# if match, produce equivalent number of states, assert state can terminate
292+
transitions = [fsm.alphabet[letter] for letter in "abcccd"]
293+
accepted_states = walk_fsm(fsm, transitions, fsm.initial, full_match=True)
294+
assert len(accepted_states) == len(transitions)
295+
assert accepted_states[-1] in fsm.finals
296+
297+
# if no match, assert empty
298+
accepted_states = walk_fsm(
299+
fsm, [fsm.alphabet[letter] for letter in "b"], fsm.initial, full_match=True
300+
)
301+
assert accepted_states == []
302+
303+
# if full_match, but last state not present, assert empty
304+
accepted_states = walk_fsm(
305+
fsm, [fsm.alphabet[letter] for letter in "abc"], fsm.initial, full_match=True
306+
)
307+
assert accepted_states == []

0 commit comments

Comments
 (0)