Logits processors: Update inplace, with batch operation #1192

Merged · 2 commits · Oct 7, 2024
Changes from all commits
outlines/processors/base_logits_processor.py (4 changes: 2 additions & 2 deletions)

@@ -75,10 +75,10 @@ def __call__(

         # Guarantee passed as 2D Tensors, then convert back to original (1D or 2D) shape
         if len(torch_logits.shape) == 2:
-            processed_logits = self.process_logits(input_ids.tolist(), torch_logits)
+            processed_logits = self.process_logits(input_ids, torch_logits)
         elif len(torch_logits.shape) == 1:
             processed_logits = self.process_logits(
-                [input_ids.tolist()], torch_logits.unsqueeze(0)
+                input_ids.unsqueeze(0), torch_logits.unsqueeze(0)
             ).squeeze(0)

         # return logits as passed array type
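The change above means `process_logits` now receives `input_ids` as a tensor rather than as nested Python lists, so `__call__` only normalizes shape instead of converting the tensor to a list on every decoding step. A minimal standalone sketch of that shape-normalization pattern (the function name and free-standing form are illustrative, not the library's API):

```python
import torch

def ensure_2d_call(process_logits, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    """Promote 1D inputs to a batch of one, process, then restore the original shape."""
    if logits.dim() == 2:
        return process_logits(input_ids, logits)
    # 1D case: add a batch dimension, process, then drop it again
    return process_logits(input_ids.unsqueeze(0), logits.unsqueeze(0)).squeeze(0)
```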
outlines/processors/structured.py (38 changes: 26 additions & 12 deletions)

@@ -70,7 +70,7 @@ def __init__(self, tokenizer: "Tokenizer", guide: Guide):
         self._seq_start_idx = None

     def process_logits(
-        self, input_ids: List[List[int]], logits: torch.Tensor
+        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
     ) -> torch.Tensor:
         """Use the Guide to bias the logits before sampling the next token.
@@ -93,21 +93,35 @@ def process_logits(

         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
-            curr_state_key = hash(tuple(gen_ids))
+            curr_state_key = hash(tuple(gen_ids.tolist()))

             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))]
-                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1])
+                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state

             sequence_states.append(self._guide_states[curr_state_key])

-        mask = torch.full_like(logits, -math.inf)
+        mask = torch.ones_like(logits, dtype=torch.bool)

+        allowed_tokens_batch = []
+        batch_indices = []
         for i, guide_state in enumerate(sequence_states):
-            allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
-            mask[i, allowed_tokens] = logits[i, allowed_tokens]
+            allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to(
+                mask.device, non_blocking=True
+            )
+            allowed_tokens_batch.append(allowed_tokens)
+            batch_indices.append(
+                torch.full_like(allowed_tokens, i)
+            )  # Store batch index for each allowed token

-        return mask
+        allowed_tokens_concat = torch.cat(allowed_tokens_batch)
+        batch_indices_concat = torch.cat(batch_indices)
+
+        mask[batch_indices_concat, allowed_tokens_concat] = False
+        logits.masked_fill_(mask, float("-inf"))
+
+        return logits

     def copy(self) -> "GuideLogitsProcessor":
         """Return a copy of the logits processor."""
@@ -201,7 +215,7 @@ def __init__(self, cfg_str: str, tokenizer: "Tokenizer"):
         super().__init__(tokenizer=tokenizer, guide=cfg_guide)

     def process_logits(
-        self, input_ids: List[List[int]], logits: torch.Tensor
+        self, input_ids: torch.LongTensor, logits: torch.Tensor
     ) -> torch.Tensor:
         """Same behavior as GuideLogitsProcessor, but uses rejection sampling"""
         if self._seq_start_idx is None:
@@ -211,11 +225,11 @@

         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
-            curr_state_key = hash(tuple(gen_ids))
+            curr_state_key = hash(tuple(gen_ids.tolist()))

             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))]
-                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1])
+                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state

             sequence_states.append(self._guide_states[curr_state_key])
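Both processors repeat the same tensor-to-Python conversions before hashing, and these matter for correctness, not just style: `torch.Tensor.__hash__` is identity-based, so hashing a tuple of tensor elements would yield different keys for equal token sequences. A minimal sketch of the caching pattern (the toy `get_next_state` and the seeded initial state are assumptions for illustration):

```python
import torch

# Toy stand-in for the guide's transition function
def get_next_state(prev_state: int, token_id: int) -> int:
    return prev_state + token_id

_guide_states = {hash(tuple()): 0}  # seeded with the guide's initial state

def state_for(gen_ids: torch.Tensor) -> int:
    """Look up (or lazily compute) the guide state for a prefix of generated ids."""
    key = hash(tuple(gen_ids.tolist()))  # plain ints hash by value; tensors hash by identity
    if key not in _guide_states:
        prev = _guide_states[hash(tuple(gen_ids[:-1].tolist()))]
        _guide_states[key] = get_next_state(prev, gen_ids[-1].item())
    return _guide_states[key]

print(state_for(torch.tensor([2])))     # 2
print(state_for(torch.tensor([2, 5])))  # 7
```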