Commit 094af23: construct logits mask in batch operation

1 parent 36875a0

2 files changed (+25, -12)


outlines/processors/base_logits_processor.py (+2, -2)

```diff
@@ -75,10 +75,10 @@ def __call__(
 
         # Guarantee passed as 2D Tensors, then covert back to original (1D or 2D) shape
         if len(torch_logits.shape) == 2:
-            processed_logits = self.process_logits(input_ids.tolist(), torch_logits)
+            processed_logits = self.process_logits(input_ids, torch_logits)
         elif len(torch_logits.shape) == 1:
             processed_logits = self.process_logits(
-                [input_ids.tolist()], torch_logits.unsqueeze(0)
+                input_ids.unsqueeze(0), torch_logits.unsqueeze(0)
             ).squeeze(0)
 
         # return logits as passed array type
```
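The base processor now hands `process_logits` the raw `input_ids` tensor instead of round-tripping it through nested Python lists, which avoids converting the whole batch to host-side lists on every decoding step. A minimal sketch of the shape contract the wrapper enforces (the `call_processor` helper name is hypothetical, not part of the library):

```python
import torch

def call_processor(process_logits, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    """Guarantee process_logits sees 2D tensors, then restore the caller's shape."""
    if logits.dim() == 2:
        # Batched case: pass the tensors straight through, no .tolist() round trip.
        return process_logits(input_ids, logits)
    elif logits.dim() == 1:
        # Single sequence: add a batch dimension of 1, then drop it again.
        return process_logits(input_ids.unsqueeze(0), logits.unsqueeze(0)).squeeze(0)
    raise ValueError(f"unsupported logits shape: {tuple(logits.shape)}")
```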

outlines/processors/structured.py (+23, -10)

```diff
@@ -70,7 +70,7 @@ def __init__(self, tokenizer: "Tokenizer", guide: Guide):
         self._seq_start_idx = None
 
     def process_logits(
-        self, input_ids: List[List[int]], logits: torch.Tensor
+        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
     ) -> torch.Tensor:
         """Use the Guide to bias the logits before sampling the next token.
 
@@ -93,19 +93,32 @@ def process_logits(
 
         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
-            curr_state_key = hash(tuple(gen_ids))
+            curr_state_key = hash(tuple(gen_ids.tolist()))
 
             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))]
-                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1])
+                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state
 
             sequence_states.append(self._guide_states[curr_state_key])
 
         mask = torch.ones_like(logits, dtype=torch.bool)
+
+        allowed_tokens_batch = []
+        batch_indices = []
         for i, guide_state in enumerate(sequence_states):
-            allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
-            mask[i, allowed_tokens] = False
+            allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to(
+                mask.device, non_blocking=True
+            )
+            allowed_tokens_batch.append(allowed_tokens)
+            batch_indices.append(
+                torch.full_like(allowed_tokens, i)
+            )  # Store batch index for each allowed token
+
+        allowed_tokens_concat = torch.cat(allowed_tokens_batch)
+        batch_indices_concat = torch.cat(batch_indices)
+
+        mask[batch_indices_concat, allowed_tokens_concat] = False
         logits.masked_fill_(mask, float("-inf"))
 
         return logits
```
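This hunk is the change the commit message names: instead of one indexed write per sequence (`mask[i, allowed_tokens] = False`), the allowed token ids and their batch row indices are concatenated and the mask is cleared in a single advanced-indexing assignment. A self-contained sketch of the technique, with `allowed_per_row` standing in for the per-state `guide.get_next_instruction(...).tokens` lookups:

```python
import torch

def build_mask(logits: torch.Tensor, allowed_per_row: list[torch.Tensor]) -> torch.Tensor:
    """Build a True-means-forbidden mask with one advanced-indexing write.

    allowed_per_row[i] holds the token ids permitted for batch row i, standing
    in for guide.get_next_instruction(state).tokens in the diff above.
    """
    mask = torch.ones_like(logits, dtype=torch.bool)
    rows, cols = [], []
    for i, allowed in enumerate(allowed_per_row):
        allowed = allowed.to(logits.device, non_blocking=True)
        cols.append(allowed)
        rows.append(torch.full_like(allowed, i))  # row index repeated once per allowed token
    # One fused scatter instead of one indexed write per sequence.
    mask[torch.cat(rows), torch.cat(cols)] = False
    return mask

# Example: a batch of 2 sequences over a 6-token vocabulary.
logits = torch.randn(2, 6)
mask = build_mask(logits, [torch.tensor([0, 3]), torch.tensor([2, 4, 5])])
logits = logits.masked_fill(mask, float("-inf"))
```

One fused indexing write replaces batch_size separate ones, cutting Python-loop and kernel-launch overhead when decoding large batches.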
```diff
@@ -202,7 +215,7 @@ def __init__(self, cfg_str: str, tokenizer: "Tokenizer"):
         super().__init__(tokenizer=tokenizer, guide=cfg_guide)
 
     def process_logits(
-        self, input_ids: List[List[int]], logits: torch.Tensor
+        self, input_ids: torch.LongTensor, logits: torch.Tensor
     ) -> torch.Tensor:
         """Same behavior as GuideLogitsProcessor, but uses rejection sampling"""
         if self._seq_start_idx is None:
@@ -212,11 +225,11 @@ def process_logits(
 
         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
-            curr_state_key = hash(tuple(gen_ids))
+            curr_state_key = hash(tuple(gen_ids.tolist()))
 
             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))]
-                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1])
+                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state
 
             sequence_states.append(self._guide_states[curr_state_key])
```
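Both processors memoize guide states keyed by a hash of the generated suffix, so each step advances the automaton by only the newest token. Because `input_ids` now arrives as a `LongTensor`, and tensors hash by object identity rather than by value, the keys must be built from plain Python ints via `.tolist()` and `.item()`, which is what the repeated small edits above do. A toy sketch of the caching pattern (`ToyGuide` and its transition rule are made-up stand-ins, not the library's `Guide`):

```python
import torch

class ToyGuide:
    """Hypothetical stand-in for outlines' Guide; the transition rule is made up."""
    def get_next_state(self, state: int, token_id: int) -> int:
        return (state * 31 + token_id) % 997

guide = ToyGuide()
guide_states = {hash(()): 0}   # empty generated prefix -> initial state
seq_start_idx = 2              # prompt length; generated ids start after it

def state_for(seq_ids: torch.LongTensor) -> int:
    """Advance the guide by one token, reusing the cached state of the prefix."""
    gen_ids = seq_ids[seq_start_idx:]
    # .tolist()/.item(): tensor elements hash by identity, so keys must be plain ints.
    key = hash(tuple(gen_ids.tolist()))
    if key not in guide_states:
        prev = guide_states[hash(tuple(gen_ids[:-1].tolist()))]
        guide_states[key] = guide.get_next_state(prev, gen_ids[-1].item())
    return guide_states[key]

print(state_for(torch.tensor([5, 7, 11])))  # prompt [5, 7], newest generated token 11
```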
