
Commit abfc4f3

[Misc] Use dataclass for InputMetadata (#3452)
Co-authored-by: youkaichao <youkaichao@126.com>
1 parent 6b78837 · commit abfc4f3

3 files changed: +24 -63 lines changed

setup.py

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 import io
 import os
 import re
-import shutil
 import subprocess
 import warnings
 from pathlib import Path

vllm/model_executor/input_metadata.py

Lines changed: 14 additions & 35 deletions
@@ -1,8 +1,10 @@
+from dataclasses import dataclass
 from typing import Optional
 
 import torch
 
 
+@dataclass
 class InputMetadata:
     """Metadata for input sequences. Used in PagedAttention.
 
@@ -15,40 +17,17 @@ class InputMetadata:
         kv_cache_dtype: Data type to store kv cache.
     """
 
-    def __init__(
-        self,
-        is_prompt: bool,
-        slot_mapping: torch.Tensor,
-        prompt_lens: Optional[torch.Tensor],
-        max_seq_len: Optional[int],
-        start_loc: Optional[torch.Tensor],
-        max_context_len: Optional[int],
-        context_lens: Optional[torch.Tensor],
-        block_tables: Optional[torch.Tensor],
-        use_cuda_graph: bool,
-        kv_cache_dtype: str,
-    ) -> None:
-        self.is_prompt = is_prompt
-        self.prompt_lens = prompt_lens
-        self.max_seq_len = max_seq_len
-        self.start_loc = start_loc
-        self.max_context_len = max_context_len
-        self.slot_mapping = slot_mapping
-        self.context_lens = context_lens
-        self.block_tables = block_tables
-        self.use_cuda_graph = use_cuda_graph
-        self.kv_cache_dtype = kv_cache_dtype
+    is_prompt: bool
+    slot_mapping: torch.Tensor
+    prompt_lens: Optional[torch.Tensor]
+    max_seq_len: Optional[int]
+    start_loc: Optional[torch.Tensor]
+    max_context_len: Optional[int]
+    context_lens: Optional[torch.Tensor]
+    block_tables: Optional[torch.Tensor]
+    use_cuda_graph: bool
+    kv_cache_dtype: str
 
-        # Set during the execution of the first attention op.
-        # FIXME(woosuk): This is a hack.
+    def __post_init__(self):
+        # will not appear in the __repr__ and __init__
         self.attn_bias = None
-
-    def __repr__(self) -> str:
-        return ("InputMetadata("
-                f"is_prompt={self.is_prompt}, "
-                f"max_context_len={self.max_context_len}, "
-                f"slot_mapping={self.slot_mapping}, "
-                f"context_lens={self.context_lens}, "
-                f"block_tables={self.block_tables}, "
-                f"use_cuda_graph={self.use_cuda_graph}, "
-                f"kv_cache_dtype={self.kv_cache_dtype})")

vllm/worker/model_runner.py

Lines changed: 10 additions & 27 deletions
@@ -1,4 +1,5 @@
 import contextlib
+import dataclasses
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union
 
@@ -521,45 +522,27 @@ def prepare_input_tensors(
             metadata_dict = {
                 "input_tokens": input_tokens,
                 "input_positions": input_positions,
-                "is_prompt": input_metadata.is_prompt,
-                "slot_mapping": input_metadata.slot_mapping,
-                "prompt_lens": input_metadata.prompt_lens,
-                "max_seq_len": input_metadata.max_seq_len,
-                "start_loc": input_metadata.start_loc,
-                "max_context_len": input_metadata.max_context_len,
-                "context_lens": input_metadata.context_lens,
-                "block_tables": input_metadata.block_tables,
-                "use_cuda_graph": input_metadata.use_cuda_graph,
-                "kv_cache_dtype": input_metadata.kv_cache_dtype,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
                 "lora_requests": lora_requests,
                 "lora_mapping": lora_mapping,
             }
+            metadata_dict.update(dataclasses.asdict(input_metadata))
             broadcast_tensor_dict(metadata_dict, src=0)
         else:
             metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict["input_tokens"]
-            input_positions = metadata_dict["input_positions"]
-            lora_mapping = metadata_dict["lora_mapping"]
-            lora_requests = metadata_dict["lora_requests"]
-            input_metadata = InputMetadata(
-                is_prompt=metadata_dict["is_prompt"],
-                slot_mapping=metadata_dict["slot_mapping"],
-                prompt_lens=metadata_dict["prompt_lens"],
-                max_seq_len=metadata_dict["max_seq_len"],
-                start_loc=metadata_dict["start_loc"],
-                max_context_len=metadata_dict["max_context_len"],
-                context_lens=metadata_dict["context_lens"],
-                block_tables=metadata_dict["block_tables"],
-                use_cuda_graph=metadata_dict["use_cuda_graph"],
-                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
-            )
+            input_tokens = metadata_dict.pop("input_tokens")
+            input_positions = metadata_dict.pop("input_positions")
+            selected_token_indices = metadata_dict.pop(
+                "selected_token_indices")
+            lora_mapping = metadata_dict.pop("lora_mapping")
+            lora_requests = metadata_dict.pop("lora_requests")
+            input_metadata = InputMetadata(**metadata_dict)
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 seq_data=None,
                 prompt_lens=None,
-                selected_token_indices=metadata_dict["selected_token_indices"],
+                selected_token_indices=selected_token_indices,
                 categorized_sample_indices=None,
                 generators=None,
                 perform_sampling=False,
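
The worker-side rebuild works because dataclasses.asdict(input_metadata) flattens exactly the declared fields into the broadcast dict; once the non-field keys are popped off, the remaining keys match the field names one-to-one, so InputMetadata(**metadata_dict) reconstructs the object. A sketch of that round trip under stated assumptions: Meta is a hypothetical stand-in for InputMetadata, and the stub broadcast_tensor_dict below only mimics vLLM's helper, which really broadcasts the dict from the driver rank to the workers:

import dataclasses
from dataclasses import dataclass


@dataclass
class Meta:  # hypothetical stand-in for InputMetadata
    is_prompt: bool
    kv_cache_dtype: str


def broadcast_tensor_dict(d=None, src=0):
    # Stub for illustration: the real vLLM helper sends the dict from
    # rank `src` to all workers; here we just hand it back unchanged.
    return d


# Driver side: non-dataclass entries first, then merge the fields in.
payload = {"input_tokens": [1, 2, 3]}
payload.update(dataclasses.asdict(Meta(is_prompt=True, kv_cache_dtype="auto")))
received = broadcast_tensor_dict(payload, src=0)

# Worker side: pop every key that is not a Meta field; what remains maps
# one-to-one onto the field names, so ** unpacking rebuilds the instance.
input_tokens = received.pop("input_tokens")
meta = Meta(**received)
print(meta)  # Meta(is_prompt=True, kv_cache_dtype='auto')

Note that attn_bias, set in __post_init__, is not a field, so asdict never serializes it and each worker re-derives it as None. Also, asdict deep-copies values it does not recurse into, so the tensor fields are copied on the driver before the broadcast.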
