
Commit 9310c40

sroy745 authored and sumitd2 committed
[Bugfix] [Encoder-Decoder] Bugfix for encoder specific metadata construction during decode of encoder-decoder models. (vllm-project#8545)
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent 66cadfd commit 9310c40
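
In short: before this fix, `_prepare_encoder_model_input_tensors` appended one encoder sequence length and one cross-attention block table per sequence *group*, while the decoder batch is flattened to one entry per *sequence*. With several sequences in a group (e.g. n > 1 sampling), the encoder-side metadata came up short during decode. A minimal sketch of the corrected expansion, using simplified stand-in structures rather than vLLM's actual `SequenceGroupMetadata`:

    # Hypothetical, simplified stand-in for the fixed loop; the real code
    # iterates over vLLM SequenceGroupMetadata objects.
    from typing import List, Optional, Tuple


    def expand_encoder_metadata(
        groups: List[Tuple[int, int, Optional[List[int]]]],
        # each tuple: (num_seqs_in_group, encoder_seq_len, cross_block_table)
    ) -> Tuple[List[int], List[List[int]]]:
        encoder_seq_lens: List[int] = []
        cross_block_tables: List[List[int]] = []
        for num_seqs, encoder_len, cross_block_table in groups:
            # One entry per *sequence*; the bug was one entry per *group*.
            for _ in range(num_seqs):
                encoder_seq_lens.append(encoder_len)
                cross_block_tables.append(
                    [] if cross_block_table is None else cross_block_table)
        return encoder_seq_lens, cross_block_tables


    # A group holding 2 decode sequences now contributes 2 entries,
    # matching the flattened decoder batch:
    assert expand_encoder_metadata([(2, 5, [2])]) == ([5, 5], [[2], [2]])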

2 files changed: 69 additions & 31 deletions


tests/worker/test_encoder_decoder_model_runner.py

Lines changed: 63 additions & 25 deletions
@@ -273,7 +273,8 @@ def test_prepare_prompt(batch_size):
                     "unsupported for encoder/ "
                     "decoder models")
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
-def test_prepare_decode(batch_size):
+@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
+def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
     '''
     Test the ability of the encoder/decoder model runner subclass to
     produce decode-phase model inputs & attention metadata.
@@ -288,6 +289,7 @@ def test_prepare_decode(batch_size):
     Arguments:

     * batch_size
+    * multiple_seqs_per_seq_group
     * backend_name: The attention backend under test
     * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
     '''
@@ -305,29 +307,40 @@ def test_prepare_decode(batch_size):
     seq_lens: List[int] = []
     encoder_seq_lens: List[int] = []
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    block_tables = {0: [1]}
+    block_tables = {
+        0: [1],
+        1: [3]
+    } if multiple_seqs_per_seq_group else {
+        0: [1]
+    }
     cross_block_table = [2]
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_lens.append(seq_len)
         seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_lens.append(encoder_seq_len)
         encoder_seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
-            seq_data={0: seq_data},
+            seq_data={
+                0: seq_data,
+                1: seq_data
+            } if multiple_seqs_per_seq_group else {0: seq_data},
             sampling_params=SamplingParams(temperature=0),
             block_tables=block_tables,
             encoder_seq_data=encoder_seq_data,
             cross_block_table=cross_block_table,
         )
         assert seq_group_metadata.token_chunk_size == 1
         seq_group_metadata_list.append(seq_group_metadata)
+        seq_lens.extend(
+            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        encoder_seq_lens.extend(
+            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])

     # Build
     # * Decoder model inputs
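
To see what the new bookkeeping produces: with multiple_seqs_per_seq_group=True each group holds two sequences sharing one seq_len and one encoder_seq_len, so every iteration extends the per-sequence lists by two entries. A small worked example; the block_size and batch_size values here are illustrative assumptions, not the test's parametrized values:

    # Worked example of the extend() bookkeeping above.
    block_size = 16  # assumed for illustration
    seq_lens: list = []
    encoder_seq_lens: list = []
    for i in range(2):  # two sequence groups
        seq_len = i % (block_size - 1) + 1
        encoder_seq_len = (i + 1) % (block_size - 1) + 1
        num_seqs = 2  # multiple_seqs_per_seq_group=True -> two seqs per group
        seq_lens.extend([seq_len] * num_seqs)
        encoder_seq_lens.extend([encoder_seq_len] * num_seqs)
    assert seq_lens == [1, 1, 2, 2]
    assert encoder_seq_lens == [2, 2, 3, 3]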
@@ -398,19 +411,24 @@ def test_prepare_decode(batch_size):

     # Verify block tables are correct for prompts
     # - Decoder self-attention
-    expected = torch.tensor(
-        [block_tables[0] for _ in range(len(seq_group_metadata_list))],
-        dtype=torch.int32,
-        device=model_runner.device)
+    flattened_block_tables = [
+        block_table for block_table in block_tables.values()
+    ]
+    expected = torch.tensor(flattened_block_tables *
+                            len(seq_group_metadata_list),
+                            dtype=torch.int32,
+                            device=model_runner.device)
     assert torch.equal(
         attn_metadata.block_tables,
         expected,
     )
     # - Encoder/decoder cross-attention
-    expected = torch.tensor(
-        [cross_block_table for _ in range(len(seq_group_metadata_list))],
-        dtype=torch.int32,
-        device=model_runner.device)
+    expected = torch.tensor([
+        cross_block_table for seq_group_metadata in seq_group_metadata_list
+        for _ in range(len(seq_group_metadata.seq_data))
+    ],
+                            dtype=torch.int32,
+                            device=model_runner.device)
     assert torch.equal(
         attn_metadata.cross_block_tables,
         expected,
@@ -474,7 +492,8 @@ def test_prepare_decode(batch_size):


 @pytest.mark.parametrize("batch_size", list(range(1, 257)))
-def test_prepare_decode_cuda_graph(batch_size):
+@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
+def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     """
     Tests that for encoder-decoder models with CUDA Graph capture and replay
     enabled, the tensors used during the decode phase are correctly padded
@@ -489,32 +508,45 @@ def test_prepare_decode_cuda_graph(batch_size):
         enable_chunked_prefill=False,
         enforce_eager=False,
     )
-
+    block_tables = {
+        0: [1],
+        1: [3]
+    } if multiple_seqs_per_seq_group else {
+        0: [1]
+    }
     seq_lens: List[int] = []
     encoder_seq_lens: List[int] = []
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    block_tables = {0: [1]}
+
     cross_block_table = [2]
+    expanded_batch_size = 0
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_lens.append(seq_len)
         seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_lens.append(encoder_seq_len)
         encoder_seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
-            seq_data={0: seq_data},
+            seq_data={
+                0: seq_data,
+                1: seq_data
+            } if multiple_seqs_per_seq_group else {0: seq_data},
             sampling_params=SamplingParams(temperature=0),
             block_tables=block_tables,
             encoder_seq_data=encoder_seq_data,
             cross_block_table=cross_block_table,
         )
         assert seq_group_metadata.token_chunk_size == 1
+        seq_lens.extend(
+            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        encoder_seq_lens.extend(
+            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        expanded_batch_size = expanded_batch_size + len(
+            seq_group_metadata.seq_data)
         seq_group_metadata_list.append(seq_group_metadata)

     model_input = model_runner.prepare_model_input(seq_group_metadata_list)
@@ -530,8 +562,8 @@ def test_prepare_decode_cuda_graph(batch_size):
     # With CUDA Graph capture and replay enabled, the decoder and encoder
     # input sequences will be padded. Create the expected padded tensors
     # accordingly.
-    graph_batch_size = _get_graph_batch_size(batch_size)
-    cuda_graph_pad_size = graph_batch_size - batch_size
+    graph_batch_size = _get_graph_batch_size(expanded_batch_size)
+    cuda_graph_pad_size = graph_batch_size - expanded_batch_size
     padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
     padded_encoder_seq_lens = encoder_seq_lens + list(
         itertools.repeat(1, cuda_graph_pad_size))
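
For intuition on the padding math: expanded_batch_size counts sequences rather than groups. The sketch below assumes _get_graph_batch_size rounds up to 1, 2, 4, or the next multiple of 8, which matches its behavior around this commit; see vllm/worker/model_runner.py for the authoritative definition:

    # Rough re-implementation for illustration; assumes the rounding rule
    # described above rather than importing vLLM's _get_graph_batch_size.
    def get_graph_batch_size(batch_size: int) -> int:
        if batch_size <= 2:
            return batch_size
        if batch_size <= 4:
            return 4
        return (batch_size + 7) // 8 * 8


    # 3 groups x 2 sequences per group (multiple_seqs_per_seq_group=True):
    expanded_batch_size = 3 * 2                                   # 6 seqs
    graph_batch_size = get_graph_batch_size(expanded_batch_size)  # 8
    cuda_graph_pad_size = graph_batch_size - expanded_batch_size  # 2 pad slots
    assert (graph_batch_size, cuda_graph_pad_size) == (8, 2)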
@@ -560,10 +592,13 @@ def test_prepare_decode_cuda_graph(batch_size):

     # Verify block tables are correct for prompts
     # - Decoder self-attention. Pad the block tables as expected.
-    expected = [block_tables[0] for _ in range(batch_size)]
-    expected.extend([[] for _ in range(cuda_graph_pad_size)])
+    flattened_block_tables = [
+        block_table for _ in range(len(seq_group_metadata_list))
+        for block_table in block_tables.values()
+    ]
+    flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)])
     expected = make_tensor_with_pad(
-        expected,
+        flattened_block_tables,
         max_len=64,
         pad=0,
         dtype=torch.int32,
@@ -575,7 +610,10 @@ def test_prepare_decode_cuda_graph(batch_size):
     )
     # - Encoder/decoder cross-attention. Pad the cross-attention block tables
     # as expected.
-    expected = [cross_block_table for _ in range(len(seq_group_metadata_list))]
+    expected = [
+        cross_block_table for seq_group_metadata in seq_group_metadata_list
+        for _ in range(len(seq_group_metadata.seq_data))
+    ]
     expected.extend([[] for _ in range(cuda_graph_pad_size)])
     expected = make_tensor_with_pad(
         expected,
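
The rows appended for CUDA Graph pad slots are empty lists, which make_tensor_with_pad fills out to all-zero rows of length max_len. A pure-Python analog of that padding, for illustration only and not vLLM's implementation:

    # Pads each row to max_len with the pad value, mirroring the shape of
    # the tensor make_tensor_with_pad builds (illustrative analog only).
    from typing import List


    def pad_rows(rows: List[List[int]],
                 max_len: int,
                 pad: int = 0) -> List[List[int]]:
        return [row + [pad] * (max_len - len(row)) for row in rows]


    assert pad_rows([[2], []], max_len=4) == [[2, 0, 0, 0], [0, 0, 0, 0]]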

vllm/worker/enc_dec_model_runner.py

Lines changed: 6 additions & 6 deletions
@@ -435,18 +435,18 @@ def _prepare_encoder_model_input_tensors(
             encoder_input_tokens_tensor = self._empty_long_tensor()
             encoder_input_positions_tensor = self._empty_long_tensor()
             cross_slot_mapping_tensor = self._empty_long_tensor()
-
         # Extract cross-attention block tables &
         # seq len from each sequence group metadata.
         # Cross-attention block tables are empty
         # during vLLM memory profiling.
         cross_block_tables = []
         for seq_group_metadata in seq_group_metadata_list:
-            encoder_seq_lens.append(
-                seq_group_metadata.encoder_seq_data.get_len())
-            cross_block_table = seq_group_metadata.cross_block_table
-            cross_block_tables.append([] if (
-                cross_block_table is None) else cross_block_table)
+            for _ in range(len(seq_group_metadata.seq_data)):
+                encoder_seq_lens.append(
+                    seq_group_metadata.encoder_seq_data.get_len())
+                cross_block_table = seq_group_metadata.cross_block_table
+                cross_block_tables.append([] if (
+                    cross_block_table is None) else cross_block_table)

         if (model_input.attn_metadata is not None
                 and model_input.attn_metadata.use_cuda_graph):
