@@ -204,3 +204,104 @@ def test_sequential_parse_example(cleanup_lark_import):
204
204
205
205
if i + 1 == len (input_tokens ):
206
206
assert all (tk in next_vocab for tk in ["\n " , "\n de" , " " , " + 1" ])
207
+
208
+
209
+ # TODO: Remove once fsm_union and walk_fsm are implemented in Outlines-Core
210
+ import interegular # noqa
211
+
212
+ from outlines .fsm .parsing import fsm_union , walk_fsm # noqa
213
+
214
+
215
+ def test_outlines_interegular_union_consistency ():
216
+ fsm0 = interegular .parse_pattern (r"abc" ).to_fsm ()
217
+ fsm1 = interegular .parse_pattern (r"WXYZ" ).to_fsm ()
218
+ fsm2 = interegular .parse_pattern (r"12345" ).to_fsm ()
219
+
220
+ interegular_unioned_fsm = fsm0 | fsm1 | fsm2
221
+ outlines_unioned_fsm , _ = fsm_union ([fsm0 , fsm1 , fsm2 ])
222
+
223
+ assert list (outlines_unioned_fsm .strings ()) == list (
224
+ interegular_unioned_fsm .strings ()
225
+ )
226
+
227
+
228
+ def _reconstruct_fsms (fsm , fsms_to_trans_finals ):
229
+ """Reconstruct the original fsms for testing purposes"""
230
+ reconstructed_fsms = []
231
+ for transitions , finals , state_map in fsms_to_trans_finals .values ():
232
+ inv_state_map = {new : orig for orig , news in state_map .items () for new in news }
233
+ states = set (inv_state_map .values ())
234
+ initial = inv_state_map .get (fsm .initial ) or next (
235
+ (orig for orig , news in state_map .items () if fsm .initial in news ), None
236
+ )
237
+ finals = {inv_state_map [s ] for s in finals }
238
+
239
+ transition_map = {}
240
+ alphabet = {}
241
+ for trans_id , (from_state , to_state ) in enumerate (transitions ):
242
+ orig_from , orig_to = inv_state_map [from_state ], inv_state_map [to_state ]
243
+ # Collect symbols associated with the transition
244
+ symbols = {
245
+ symbol
246
+ for trans , dest in fsm .map .get (from_state , {}).items ()
247
+ if dest == to_state
248
+ for symbol in fsm .alphabet .by_transition .get (trans , [])
249
+ }
250
+ if symbols :
251
+ # NOTE: THIS RECONSTRUCTOR DOESNT WORK FOR MORE THAN ONE TRANSITION PER SYMBOL
252
+ assert len (symbols ) == 1
253
+ symbol = list (symbols )[0 ]
254
+ alphabet [symbol ] = trans_id
255
+ transition_map .setdefault (orig_from , {})[trans_id ] = orig_to
256
+
257
+ reconstructed_fsms .append (
258
+ interegular .fsm .FSM (
259
+ alphabet = interegular .fsm .Alphabet (alphabet ),
260
+ states = frozenset (states ),
261
+ initial = initial ,
262
+ finals = frozenset (finals ),
263
+ map = transition_map ,
264
+ __no_validation__ = True ,
265
+ )
266
+ )
267
+ return reconstructed_fsms
268
+
269
+
270
+ def test_fsm_to_trans_finals_reconstruction ():
271
+ """Assert that _fsms_to_trans_finals is correct by reconstructing original fsms"""
272
+ fsm0 = interegular .parse_pattern (r"abc" ).to_fsm ()
273
+ fsm1 = interegular .parse_pattern (r"XYZ" ).to_fsm ()
274
+ fsm2 = interegular .parse_pattern (r"12345" ).to_fsm ()
275
+
276
+ fsm , _fsms_to_trans_finals = fsm_union ([fsm0 , fsm1 , fsm2 ])
277
+
278
+ reconstructed = _reconstruct_fsms (fsm , _fsms_to_trans_finals )
279
+
280
+ # assert reconstruction equivalent
281
+ assert list (fsm0 .strings ()) == list (reconstructed [0 ].strings ())
282
+ assert list (fsm1 .strings ()) == list (reconstructed [1 ].strings ())
283
+ assert list (fsm2 .strings ()) == list (reconstructed [2 ].strings ())
284
+
285
+
286
+ def test_walk_fsm ():
287
+ fsm = interegular .parse_pattern (r"abc*d" ).to_fsm ()
288
+ # convert to BetterFSM
289
+ fsm = fsm_union ([fsm ])[0 ]
290
+
291
+ # if match, produce equivalent number of states, assert state can terminate
292
+ transitions = [fsm .alphabet [letter ] for letter in "abcccd" ]
293
+ accepted_states = walk_fsm (fsm , transitions , fsm .initial , full_match = True )
294
+ assert len (accepted_states ) == len (transitions )
295
+ assert accepted_states [- 1 ] in fsm .finals
296
+
297
+ # if no match, assert empty
298
+ accepted_states = walk_fsm (
299
+ fsm , [fsm .alphabet [letter ] for letter in "b" ], fsm .initial , full_match = True
300
+ )
301
+ assert accepted_states == []
302
+
303
+ # if full_match, but last state not present, assert empty
304
+ accepted_states = walk_fsm (
305
+ fsm , [fsm .alphabet [letter ] for letter in "abc" ], fsm .initial , full_match = True
306
+ )
307
+ assert accepted_states == []
0 commit comments