@@ -273,7 +273,8 @@ def test_prepare_prompt(batch_size):
                     "unsupported for encoder/ "
                     "decoder models")
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
-def test_prepare_decode(batch_size):
+@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
+def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
     '''
     Test the ability of the encoder/decoder model runner subclass to
     produce decode-phase model inputs & attention metadata.
@@ -288,6 +289,7 @@ def test_prepare_decode(batch_size):
     Arguments:
 
     * batch_size
+    * multiple_seqs_per_seq_group
     * backend_name: The attention backend under test
     * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
     '''
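
For readers unfamiliar with the pattern: stacking a second @pytest.mark.parametrize decorator, as in the hunk above, makes pytest run the Cartesian product of both parameter sets, so every existing batch size is now exercised both with and without multiple sequences per sequence group. A minimal, self-contained sketch of the same mechanism (test name and assertion are illustrative, not part of this PR):

import pytest

@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
def test_cross_product(batch_size, multiple_seqs_per_seq_group):
    # pytest generates four cases: (1, True), (1, False), (2, True), (2, False)
    assert batch_size in (1, 2)
    assert isinstance(multiple_seqs_per_seq_group, bool)
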
@@ -305,29 +307,40 @@ def test_prepare_decode(batch_size):
     seq_lens: List[int] = []
     encoder_seq_lens: List[int] = []
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    block_tables = {0: [1]}
+    block_tables = {
+        0: [1],
+        1: [3]
+    } if multiple_seqs_per_seq_group else {
+        0: [1]
+    }
     cross_block_table = [2]
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_lens.append(seq_len)
         seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_lens.append(encoder_seq_len)
         encoder_seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
-            seq_data={0: seq_data},
+            seq_data={
+                0: seq_data,
+                1: seq_data
+            } if multiple_seqs_per_seq_group else {0: seq_data},
             sampling_params=SamplingParams(temperature=0),
             block_tables=block_tables,
             encoder_seq_data=encoder_seq_data,
             cross_block_table=cross_block_table,
         )
         assert seq_group_metadata.token_chunk_size == 1
         seq_group_metadata_list.append(seq_group_metadata)
+        seq_lens.extend(
+            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        encoder_seq_lens.extend(
+            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
 
     # Build
     # * Decoder model inputs
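
To see what the new bookkeeping produces: each SequenceGroupMetadata now carries len(seq_data) sequences (two when multiple_seqs_per_seq_group is True), and the per-sequence length lists grow by that many entries per group. A small worked example of the expansion in plain Python (the helper name is illustrative, not vLLM API):

def expand_group_lens(group_lens, seqs_per_group):
    # Mirrors seq_lens.extend([seq_len] * len(seq_group_metadata.seq_data)):
    # one entry per sequence, repeated for every group.
    expanded = []
    for group_len in group_lens:
        expanded.extend([group_len] * seqs_per_group)
    return expanded

# Two groups of lengths 3 and 5, two sequences per group:
assert expand_group_lens([3, 5], 2) == [3, 3, 5, 5]
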
@@ -398,19 +411,24 @@ def test_prepare_decode(batch_size):
 
     # Verify block tables are correct for prompts
     # - Decoder self-attention
-    expected = torch.tensor(
-        [block_tables[0] for _ in range(len(seq_group_metadata_list))],
-        dtype=torch.int32,
-        device=model_runner.device)
+    flattened_block_tables = [
+        block_table for block_table in block_tables.values()
+    ]
+    expected = torch.tensor(flattened_block_tables *
+                            len(seq_group_metadata_list),
+                            dtype=torch.int32,
+                            device=model_runner.device)
     assert torch.equal(
         attn_metadata.block_tables,
         expected,
     )
     # - Encoder/decoder cross-attention
-    expected = torch.tensor(
-        [cross_block_table for _ in range(len(seq_group_metadata_list))],
-        dtype=torch.int32,
-        device=model_runner.device)
+    expected = torch.tensor([
+        cross_block_table for seq_group_metadata in seq_group_metadata_list
+        for _ in range(len(seq_group_metadata.seq_data))
+    ],
+                            dtype=torch.int32,
+                            device=model_runner.device)
     assert torch.equal(
         attn_metadata.cross_block_tables,
         expected,
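
The reworked expectations lean on Python list repetition: flattened_block_tables * len(seq_group_metadata_list) repeats the per-group block tables once per group, yielding one row per sequence in group-major order. A dependency-free sketch of the resulting layout, using the fixture values from the test:

block_tables = {0: [1], 1: [3]}  # multiple_seqs_per_seq_group=True
flattened_block_tables = list(block_tables.values())  # [[1], [3]]

num_groups = 2  # e.g. two sequence groups in the batch
expected_rows = flattened_block_tables * num_groups
# One row per sequence: [[1], [3], [1], [3]]
assert expected_rows == [[1], [3], [1], [3]]
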
@@ -474,7 +492,8 @@ def test_prepare_decode(batch_size):
 
 
 @pytest.mark.parametrize("batch_size", list(range(1, 257)))
-def test_prepare_decode_cuda_graph(batch_size):
+@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
+def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     """
     Tests that for encoder-decoder models with CUDA Graph capture and replay
     enabled, the tensors used during the decode phase are correctly padded
@@ -489,32 +508,45 @@ def test_prepare_decode_cuda_graph(batch_size):
         enable_chunked_prefill=False,
         enforce_eager=False,
     )
-
+    block_tables = {
+        0: [1],
+        1: [3]
+    } if multiple_seqs_per_seq_group else {
+        0: [1]
+    }
     seq_lens: List[int] = []
     encoder_seq_lens: List[int] = []
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    block_tables = {0: [1]}
+
     cross_block_table = [2]
+    expanded_batch_size = 0
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_lens.append(seq_len)
         seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_lens.append(encoder_seq_len)
         encoder_seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
-            seq_data={0: seq_data},
+            seq_data={
+                0: seq_data,
+                1: seq_data
+            } if multiple_seqs_per_seq_group else {0: seq_data},
             sampling_params=SamplingParams(temperature=0),
             block_tables=block_tables,
             encoder_seq_data=encoder_seq_data,
             cross_block_table=cross_block_table,
         )
         assert seq_group_metadata.token_chunk_size == 1
+        seq_lens.extend(
+            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        encoder_seq_lens.extend(
+            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        expanded_batch_size = expanded_batch_size + len(
+            seq_group_metadata.seq_data)
         seq_group_metadata_list.append(seq_group_metadata)
 
     model_input = model_runner.prepare_model_input(seq_group_metadata_list)
@@ -530,8 +562,8 @@ def test_prepare_decode_cuda_graph(batch_size):
     # With CUDA Graph capture and replay enabled, the decoder and encoder
     # input sequences will be padded. Create the expected padded tensors
     # accordingly.
-    graph_batch_size = _get_graph_batch_size(batch_size)
-    cuda_graph_pad_size = graph_batch_size - batch_size
+    graph_batch_size = _get_graph_batch_size(expanded_batch_size)
+    cuda_graph_pad_size = graph_batch_size - expanded_batch_size
     padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
     padded_encoder_seq_lens = encoder_seq_lens + list(
         itertools.repeat(1, cuda_graph_pad_size))
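
The padding math above assumes _get_graph_batch_size rounds the expanded batch size up to the nearest captured CUDA Graph batch size; the leftover slots become dummy sequences of length 1. A sketch of that rounding under the alignment scheme vLLM used around this change (1, 2, 4, then multiples of 8); treat the constant and cutoffs as assumptions, not a verbatim copy:

_BATCH_SIZE_ALIGNMENT = 8  # assumed capture alignment

def get_graph_batch_size(batch_size: int) -> int:
    # Round up to the nearest captured graph batch size.
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT) * _BATCH_SIZE_ALIGNMENT

# An expanded batch of 10 sequences pads to 16, so the test would expect
# cuda_graph_pad_size = 16 - 10 = 6 padding entries in padded_seq_lens.
assert get_graph_batch_size(10) == 16
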
@@ -560,10 +592,13 @@ def test_prepare_decode_cuda_graph(batch_size):
 
     # Verify block tables are correct for prompts
     # - Decoder self-attention. Pad the block tables as expected.
-    expected = [block_tables[0] for _ in range(batch_size)]
-    expected.extend([[] for _ in range(cuda_graph_pad_size)])
+    flattened_block_tables = [
+        block_table for _ in range(len(seq_group_metadata_list))
+        for block_table in block_tables.values()
+    ]
+    flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)])
     expected = make_tensor_with_pad(
-        expected,
+        flattened_block_tables,
         max_len=64,
         pad=0,
         dtype=torch.int32,
@@ -575,7 +610,10 @@ def test_prepare_decode_cuda_graph(batch_size):
     )
     # - Encoder/decoder cross-attention. Pad the cross-attention block tables
     # as expected.
-    expected = [cross_block_table for _ in range(len(seq_group_metadata_list))]
+    expected = [
+        cross_block_table for seq_group_metadata in seq_group_metadata_list
+        for _ in range(len(seq_group_metadata.seq_data))
+    ]
     expected.extend([[] for _ in range(cuda_graph_pad_size)])
     expected = make_tensor_with_pad(
         expected,