@@ -70,7 +70,7 @@ def __init__(self, tokenizer: "Tokenizer", guide: Guide):
         self._seq_start_idx = None

     def process_logits(
-        self, input_ids: List[List[int]], logits: torch.Tensor
+        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
     ) -> torch.Tensor:
         """Use the Guide to bias the logits before sampling the next token.
@@ -93,19 +93,32 @@ def process_logits(

         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
-            curr_state_key = hash(tuple(gen_ids))
+            curr_state_key = hash(tuple(gen_ids.tolist()))

             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))]
-                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1])
+                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state

             sequence_states.append(self._guide_states[curr_state_key])

         mask = torch.ones_like(logits, dtype=torch.bool)
+
+        allowed_tokens_batch = []
+        batch_indices = []
         for i, guide_state in enumerate(sequence_states):
-            allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
-            mask[i, allowed_tokens] = False
+            allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to(
+                mask.device, non_blocking=True
+            )
+            allowed_tokens_batch.append(allowed_tokens)
+            batch_indices.append(
+                torch.full_like(allowed_tokens, i)
+            )  # Store batch index for each allowed token
+
+        allowed_tokens_concat = torch.cat(allowed_tokens_batch)
+        batch_indices_concat = torch.cat(batch_indices)
+
+        mask[batch_indices_concat, allowed_tokens_concat] = False
         logits.masked_fill_(mask, float("-inf"))

         return logits
@@ -202,7 +215,7 @@ def __init__(self, cfg_str: str, tokenizer: "Tokenizer"):
         super().__init__(tokenizer=tokenizer, guide=cfg_guide)

     def process_logits(
-        self, input_ids: List[List[int]], logits: torch.Tensor
+        self, input_ids: torch.LongTensor, logits: torch.Tensor
     ) -> torch.Tensor:
         """Same behavior as GuideLogitsProcessor, but uses rejection sampling"""
         if self._seq_start_idx is None:
@@ -212,11 +225,11 @@ def process_logits(

         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
-            curr_state_key = hash(tuple(gen_ids))
+            curr_state_key = hash(tuple(gen_ids.tolist()))

             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))]
-                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1])
+                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state

             sequence_states.append(self._guide_states[curr_state_key])