|
1 | 1 | import contextlib
|
| 2 | +import dataclasses |
2 | 3 | import time
|
3 | 4 | from typing import Dict, List, Optional, Tuple, Set, Union
|
4 | 5 |
|
@@ -521,45 +522,27 @@ def prepare_input_tensors(
|
521 | 522 | metadata_dict = {
|
522 | 523 | "input_tokens": input_tokens,
|
523 | 524 | "input_positions": input_positions,
|
524 |
| - "is_prompt": input_metadata.is_prompt, |
525 |
| - "slot_mapping": input_metadata.slot_mapping, |
526 |
| - "prompt_lens": input_metadata.prompt_lens, |
527 |
| - "max_seq_len": input_metadata.max_seq_len, |
528 |
| - "start_loc": input_metadata.start_loc, |
529 |
| - "max_context_len": input_metadata.max_context_len, |
530 |
| - "context_lens": input_metadata.context_lens, |
531 |
| - "block_tables": input_metadata.block_tables, |
532 |
| - "use_cuda_graph": input_metadata.use_cuda_graph, |
533 |
| - "kv_cache_dtype": input_metadata.kv_cache_dtype, |
534 | 525 | "selected_token_indices":
|
535 | 526 | sampling_metadata.selected_token_indices,
|
536 | 527 | "lora_requests": lora_requests,
|
537 | 528 | "lora_mapping": lora_mapping,
|
538 | 529 | }
|
| 530 | + metadata_dict.update(dataclasses.asdict(input_metadata)) |
539 | 531 | broadcast_tensor_dict(metadata_dict, src=0)
|
540 | 532 | else:
|
541 | 533 | metadata_dict = broadcast_tensor_dict(src=0)
|
542 |
| - input_tokens = metadata_dict["input_tokens"] |
543 |
| - input_positions = metadata_dict["input_positions"] |
544 |
| - lora_mapping = metadata_dict["lora_mapping"] |
545 |
| - lora_requests = metadata_dict["lora_requests"] |
546 |
| - input_metadata = InputMetadata( |
547 |
| - is_prompt=metadata_dict["is_prompt"], |
548 |
| - slot_mapping=metadata_dict["slot_mapping"], |
549 |
| - prompt_lens=metadata_dict["prompt_lens"], |
550 |
| - max_seq_len=metadata_dict["max_seq_len"], |
551 |
| - start_loc=metadata_dict["start_loc"], |
552 |
| - max_context_len=metadata_dict["max_context_len"], |
553 |
| - context_lens=metadata_dict["context_lens"], |
554 |
| - block_tables=metadata_dict["block_tables"], |
555 |
| - use_cuda_graph=metadata_dict["use_cuda_graph"], |
556 |
| - kv_cache_dtype=metadata_dict["kv_cache_dtype"], |
557 |
| - ) |
| 534 | + input_tokens = metadata_dict.pop("input_tokens") |
| 535 | + input_positions = metadata_dict.pop("input_positions") |
| 536 | + selected_token_indices = metadata_dict.pop( |
| 537 | + "selected_token_indices") |
| 538 | + lora_mapping = metadata_dict.pop("lora_mapping") |
| 539 | + lora_requests = metadata_dict.pop("lora_requests") |
| 540 | + input_metadata = InputMetadata(**metadata_dict) |
558 | 541 | sampling_metadata = SamplingMetadata(
|
559 | 542 | seq_groups=None,
|
560 | 543 | seq_data=None,
|
561 | 544 | prompt_lens=None,
|
562 |
| - selected_token_indices=metadata_dict["selected_token_indices"], |
| 545 | + selected_token_indices=selected_token_indices, |
563 | 546 | categorized_sample_indices=None,
|
564 | 547 | generators=None,
|
565 | 548 | perform_sampling=False,
|
|
0 commit comments