Skip to content

Commit 5f7bb58

Browse files
authored
Fix typical acceptance sampler with correct recovered token ids (#8562)
1 parent b05f5c9 commit 5f7bb58

File tree

2 files changed

+17
-28
lines changed

2 files changed

+17
-28
lines changed

tests/samplers/test_typical_acceptance_sampler.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str):
365365
# Next only keep the first 2 draft tokens same as the zero temperature
366366
# tokens. For the remaining 3 choose some other tokens. In the
367367
# response we will expect the first 2 tokens to be the same as the
368-
# draft tokens and the rest as -1
368+
# draft tokens, the next to be the recovered token, and the rest as -1
369369
draft_token_ids_to_replace = get_draft_token_ids(
370370
batch_size, k, vocab_size, zero_temperature_token_ids)
371371
draft_token_ids = torch.cat(
@@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str):
378378
assert output_token_ids.shape[0] == batch_size
379379
assert output_token_ids.shape[1] == (k + 1)
380380
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
381+
assert torch.all(
382+
output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
381383
assert torch.all(output_token_ids[:, -3:] == -1)
382384

383385

@@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
443445
@pytest.mark.parametrize("seed", list(range(10)))
444446
@pytest.mark.parametrize("device", CUDA_DEVICES)
445447
@torch.inference_mode()
446-
def test_replacement_token_ids(seed: int, device: str):
448+
def test_get_recovered_token_ids(seed: int, device: str):
447449
"""
448450
Test the TypicalAcceptanceSampler's method for generating
449451
replacement token IDs.
450452
451-
This test verifies that the `_replacement_token_ids` method of the
453+
This test verifies that the `_get_recovered_token_ids` method of the
452454
TypicalAcceptanceSampler correctly identifies the token IDs to be used
453-
as replacements based on the target probability distribution.
455+
as recovered token IDs based on the target probability distribution.
454456
Specifically, it ensures that the method correctly identifies the
455457
tokens with the highest probability for each sequence in the batch.
456458
"""
@@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str):
462464
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
463465
typical_acceptance_sampler.init_gpu_tensors(device=device)
464466
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
465-
expected_replacement_tokens = -torch.ones(
466-
(batch_size, k), dtype=torch.long)
467-
expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :],
468-
dim=1)
467+
expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
469468
actual_replacement_tokens = (
470-
typical_acceptance_sampler._replacement_token_ids(target_probs))
469+
typical_acceptance_sampler._get_recovered_token_ids(target_probs))
471470
assert torch.all(expected_replacement_tokens == actual_replacement_tokens)

vllm/model_executor/layers/typical_acceptance_sampler.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def forward(
8080
target_probs = target_with_bonus_probs[:, :-1]
8181
accepted = self._evaluate_accepted_tokens(target_probs,
8282
draft_token_ids)
83-
recovered_token_ids = self._replacement_token_ids(target_probs)
83+
recovered_token_ids = self._get_recovered_token_ids(target_probs)
8484
output_token_ids = self._create_output(accepted, recovered_token_ids,
8585
draft_token_ids,
8686
bonus_token_ids)
@@ -148,16 +148,10 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
148148
accepted_mask = candidates_prob > threshold
149149
return accepted_mask
150150

151-
def _replacement_token_ids(self, target_probs):
151+
def _get_recovered_token_ids(self, target_probs):
152152
"""
153-
Generate one replacement token ID for each sequence based on target
154-
probabilities. The replacement token is used as the fallback option
155-
if typical acceptance sampling does not accept any draft tokens for
156-
that particular sequence.
157-
158-
This method computes the token IDs to be replaced by selecting the
159-
token with the highest probability for each sequence in the first
160-
position. The rest of the output is filled with -1.
153+
The recovered token ids are used to fill the first unmatched token
153+
154+
with the highest-probability token from the target distribution.
161155
162156
Parameters
163157
----------
@@ -168,13 +162,9 @@ def _replacement_token_ids(self, target_probs):
168162
Returns
169163
-------
170164
torch.Tensor
171-
A tensor of shape (batch_size, k) with the replacement
172-
token IDs. Only the first column is set, and the rest of the
173-
columns are filled with -1.
165+
A tensor of shape (batch_size, k) with the recovered token
166+
ids which are selected from target probs.
174167
"""
175-
max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
176-
output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
177-
dtype=self.token_id_dtype,
178-
device=target_probs.device)
179-
output[:, 0] = max_indices
180-
return output
168+
max_indices = torch.argmax(target_probs, dim=-1)
169+
170+
return max_indices

0 commit comments

Comments
 (0)