@@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str):
365
365
# Next only keep the first 2 draft tokens same as the zero temperature
366
366
# tokens. For the remaining 3 choose some other tokens. In the
367
367
# response we will expect the first 2 tokens to be the same as the
368
- # draft tokens and the rest as -1
368
+ # draft tokens, the next token to be the recovered token, and the rest as -1
369
369
draft_token_ids_to_replace = get_draft_token_ids (
370
370
batch_size , k , vocab_size , zero_temperature_token_ids )
371
371
draft_token_ids = torch .cat (
@@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str):
378
378
assert output_token_ids .shape [0 ] == batch_size
379
379
assert output_token_ids .shape [1 ] == (k + 1 )
380
380
assert torch .all (output_token_ids [:, :2 ] == draft_token_ids [:, :2 ])
381
+ assert torch .all (
382
+ output_token_ids [:, 2 ] == target_with_bonus_probs .argmax (- 1 )[:, 2 ])
381
383
assert torch .all (output_token_ids [:, - 3 :] == - 1 )
382
384
383
385
@@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
443
445
@pytest .mark .parametrize ("seed" , list (range (10 )))
444
446
@pytest .mark .parametrize ("device" , CUDA_DEVICES )
445
447
@torch .inference_mode ()
446
- def test_replacement_token_ids (seed : int , device : str ):
448
+ def test_get_recovered_token_ids (seed : int , device : str ):
447
449
"""
448
450
Test the TypicalAcceptanceSampler's method for generating
449
451
recovered token IDs.
450
452
451
- This test verifies that the `_replacement_token_ids ` method of the
453
+ This test verifies that the `_get_recovered_token_ids ` method of the
452
454
TypicalAcceptanceSampler correctly identifies the token IDs to be used
453
- as replacements based on the target probability distribution.
455
+ as recovered token IDs based on the target probability distribution.
454
456
Specifically, it ensures that the method correctly identifies the
455
457
tokens with the highest probability at each position of each sequence in the batch.
456
458
"""
@@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str):
462
464
typical_acceptance_sampler = get_acceptance_sampler (strict_mode = True )
463
465
typical_acceptance_sampler .init_gpu_tensors (device = device )
464
466
target_probs = torch .rand (batch_size , k , vocab_size , dtype = torch .float32 )
465
- expected_replacement_tokens = - torch .ones (
466
- (batch_size , k ), dtype = torch .long )
467
- expected_replacement_tokens [:, 0 ] = torch .argmax (target_probs [:, 0 , :],
468
- dim = 1 )
467
+ expected_replacement_tokens = torch .argmax (target_probs , dim = - 1 )
469
468
actual_replacement_tokens = (
470
- typical_acceptance_sampler ._replacement_token_ids (target_probs ))
469
+ typical_acceptance_sampler ._get_recovered_token_ids (target_probs ))
471
470
assert torch .all (expected_replacement_tokens == actual_replacement_tokens )
0 commit comments