@@ -1,8 +1,8 @@
-from typing import Iterator, List, Optional, Union
+from typing import Iterator, List, Optional, Tuple, Union
 
 import torch
 
-from outlines.fsm.fsm import FSMState
+from outlines.fsm.fsm import FSM, FSMState
 from outlines.generate.generator import sequence_generator
 
 
@@ -21,6 +21,53 @@ def __init__(
         self.device = device
         self.num_samples = sampler.samples
 
+    def align_prompt_tokens(
+        self,
+        prompt_token_ids: torch.Tensor,
+        attention_masks: torch.Tensor,
+        fsms: List[FSM],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Implement token alignment for each FSM. Return the updated token_ids and attention_masks."""
+        aligned_prompts, aligned_masks = zip(
+            *[
+                fsm.align_prompt_tokens(prompt, mask)
+                for prompt, mask, fsm in zip(prompt_token_ids, attention_masks, fsms)
+            ]
+        )
+        # We have to pad some of the prompts if they are not all of the same length after this operation
+        max_length_aligned_prompt = max(prompt.shape[0] for prompt in aligned_prompts)
+        padded_aligned_prompts = [
+            torch.cat(
+                [
+                    torch.full(
+                        (max_length_aligned_prompt - prompt.shape[0],),
+                        0,
+                        device=prompt_token_ids.device,
+                        dtype=prompt.dtype,
+                    ),
+                    prompt,
+                ]
+            )
+            for prompt in aligned_prompts
+        ]
+        padded_aligned_masks = [
+            torch.cat(
+                [
+                    torch.full(
+                        (max_length_aligned_prompt - mask.shape[0],),
+                        0,
+                        device=prompt_token_ids.device,
+                        dtype=mask.dtype,
+                    ),
+                    mask,
+                ]
+            )
+            for mask in aligned_masks
+        ]
+        aligned_prompt_token_ids = torch.stack(padded_aligned_prompts)
+        aligned_attention_masks = torch.stack(padded_aligned_masks)
+        return aligned_prompt_token_ids, aligned_attention_masks
+
     def get_generated_token_ids(
         self,
         prompt_token_ids: torch.Tensor,
@@ -189,49 +236,15 @@ def __call__(
         num_samples = self.num_samples
         batch_size = len(prompts)
 
-        fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
-        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
-
         prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
         attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
 
-        # Token alignment may shorten some of the prompts by removing tokens at their end.
-        # We have to pad some of the prompts if they are not all of the same length after this operation
-        aligned_prompts, aligned_masks = zip(
-            *[
-                fsm.align_prompt_tokens(prompt, mask)
-                for prompt, mask, fsm in zip(prompt_token_ids, attention_masks, fsms)
-            ]
+        fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
+        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
+
+        aligned_prompt_token_ids, aligned_attention_masks = self.align_prompt_tokens(
+            prompt_token_ids, attention_masks, fsms
         )
-        max_length_aligned_prompt = max(prompt.shape[0] for prompt in aligned_prompts)
-        padded_aligned_prompts = [
-            torch.cat(
-                [
-                    torch.full(
-                        (max_length_aligned_prompt - prompt.shape[0],),
-                        0,
-                        dtype=prompt.dtype,
-                    ),
-                    prompt,
-                ]
-            )
-            for prompt in aligned_prompts
-        ]
-        padded_aligned_masks = [
-            torch.cat(
-                [
-                    torch.full(
-                        (max_length_aligned_prompt - mask.shape[0],),
-                        0,
-                        dtype=mask.dtype,
-                    ),
-                    mask,
-                ]
-            )
-            for mask in aligned_masks
-        ]
-        aligned_prompt_token_ids = torch.stack(padded_aligned_prompts)
-        aligned_attention_masks = torch.stack(padded_aligned_masks)
 
         weights = torch.zeros(
             (batch_size * num_samples), dtype=torch.float, device=self.device
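
For reference, a minimal standalone sketch (not part of the commit) of the left-padding scheme the new align_prompt_tokens helper applies after per-sequence token alignment: prompts that were shortened are left-padded with 0, and their attention masks are padded with 0 so the model ignores those positions. The example tensors are hypothetical; only the padding logic mirrors the code above.

import torch

# Hypothetical outputs of per-sequence alignment: one prompt kept 4 tokens, the other only 2.
aligned_prompts = [torch.tensor([5, 6, 7, 8]), torch.tensor([9, 10])]
aligned_masks = [torch.ones(4, dtype=torch.int64), torch.ones(2, dtype=torch.int64)]

max_len = max(p.shape[0] for p in aligned_prompts)

# Left-pad token ids with 0 up to max_len, then stack into a single batch tensor.
padded_prompts = torch.stack(
    [torch.cat([torch.full((max_len - p.shape[0],), 0, dtype=p.dtype), p]) for p in aligned_prompts]
)
# Pad the attention masks with 0 so the padded positions are masked out.
padded_masks = torch.stack(
    [torch.cat([torch.full((max_len - m.shape[0],), 0, dtype=m.dtype), m]) for m in aligned_masks]
)

print(padded_prompts)  # rows: [5, 6, 7, 8] and [0, 0, 9, 10]
print(padded_masks)    # rows: [1, 1, 1, 1] and [0, 0, 1, 1]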