
Commit cc97560

Authored and committed by Mu Huai
feat:add engine v1 tracing
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
1 parent: 7ea6cb2

File tree: 6 files changed, +150 -72 lines


vllm/tracing.py
Lines changed: 3 additions & 0 deletions

@@ -118,6 +118,9 @@ class SpanAttributes:
     # forward, block/sync across workers, cpu-gpu sync time and sampling time.
     GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = (
         "gen_ai.latency.time_in_model_execute")
+    GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = "gen_ai.latency.time_in_model_prefill"
+    GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
+    GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = "gen_ai.latency.time_in_model_inference"


 def contains_trace_headers(headers: Mapping[str, str]) -> bool:
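
The three new constants extend the existing gen_ai.latency.* attribute names with prefill, decode, and whole-inference intervals. As a rough illustration of how such attribute names end up on a span (plain OpenTelemetry, outside vLLM; the tracer name, span name, and timing values below are made up, and the opentelemetry-sdk package is assumed to be installed):

    # Illustrative only: sets the attribute names added above on a plain
    # OpenTelemetry span. Timing values are made up for the example.
    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider

    trace.set_tracer_provider(TracerProvider())
    tracer = trace.get_tracer("example")

    prefill_s, decode_s = 0.12, 0.83  # seconds, illustrative numbers

    with tracer.start_as_current_span("llm_request") as span:
        span.set_attribute("gen_ai.latency.time_in_model_prefill", prefill_s)
        span.set_attribute("gen_ai.latency.time_in_model_decode", decode_s)
        span.set_attribute("gen_ai.latency.time_in_model_inference",
                           prefill_s + decode_s)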

vllm/v1/core/sched/scheduler.py
Lines changed: 3 additions & 1 deletion

@@ -749,7 +749,9 @@ def update_from_output(
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                         stop_reason=request.stop_reason,
-                        events=request.take_events()))
+                        events=request.take_events()),
+                    trace_headers=request.trace_headers
+                )
             else:
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors

vllm/v1/engine/__init__.py
Lines changed: 21 additions & 17 deletions

@@ -3,7 +3,7 @@
 import enum
 import time
 from collections.abc import Sequence
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Mapping

 import msgspec

@@ -39,10 +39,10 @@ def __str__(self):


 class EngineCoreRequest(
-        msgspec.Struct,
-        array_like=True,  # type: ignore[call-arg]
-        omit_defaults=True,  # type: ignore[call-arg]
-        gc=False):  # type: ignore[call-arg]
+    msgspec.Struct,
+    array_like=True,  # type: ignore[call-arg]
+    omit_defaults=True,  # type: ignore[call-arg]
+    gc=False):  # type: ignore[call-arg]

     # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
     # but this object is currently not playing well with msgspec
@@ -64,6 +64,8 @@ class EngineCoreRequest(
     # a wave finished notification is received.
     current_wave: int = 0

+    trace_headers: Optional[Mapping[str, str]] = None
+

 class EngineCoreEventType(enum.IntEnum):
     """The type of engine core request event."""
@@ -91,10 +93,10 @@ def new_event(cls,


 class EngineCoreOutput(
-        msgspec.Struct,
-        array_like=True,  # type: ignore[call-arg]
-        omit_defaults=True,  # type: ignore[call-arg]
-        gc=False):  # type: ignore[call-arg]
+    msgspec.Struct,
+    array_like=True,  # type: ignore[call-arg]
+    omit_defaults=True,  # type: ignore[call-arg]
+    gc=False):  # type: ignore[call-arg]

     request_id: str
     new_token_ids: list[int]
@@ -106,15 +108,17 @@ class EngineCoreOutput(
     stop_reason: Union[int, str, None] = None
     events: Optional[list[EngineCoreEvent]] = None

+    trace_headers: Optional[Mapping[str, str]] = None
+
     @property
     def finished(self) -> bool:
         return self.finish_reason is not None


 class UtilityOutput(
-        msgspec.Struct,
-        array_like=True,  # type: ignore[call-arg]
-        gc=False):  # type: ignore[call-arg]
+    msgspec.Struct,
+    array_like=True,  # type: ignore[call-arg]
+    gc=False):  # type: ignore[call-arg]

     call_id: int

@@ -124,12 +128,12 @@ class UtilityOutput(


 class EngineCoreOutputs(
-        msgspec.Struct,
-        array_like=True,  # type: ignore[call-arg]
-        omit_defaults=True,  # type: ignore[call-arg]
-        gc=False):  # type: ignore[call-arg]
+    msgspec.Struct,
+    array_like=True,  # type: ignore[call-arg]
+    omit_defaults=True,  # type: ignore[call-arg]
+    gc=False):  # type: ignore[call-arg]

-    #NOTE(Nick): We could consider ways to make this more compact,
+    # NOTE(Nick): We could consider ways to make this more compact,
     # e.g. columnwise layout

     engine_index: int = 0
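
Both EngineCoreRequest and EngineCoreOutput gain an optional trace_headers field so the caller's trace context can travel with the request through the engine core. A minimal, standalone sketch (not vLLM's real structs) of how such an optional trailing field round-trips through msgspec with the same struct options; dict[str, str] stands in for Mapping[str, str], and the traceparent value is the W3C specification's example string:

    # Standalone sketch, not vLLM's actual classes.
    from typing import Optional

    import msgspec


    class MiniRequest(msgspec.Struct,
                      array_like=True,
                      omit_defaults=True,
                      gc=False):
        request_id: str
        trace_headers: Optional[dict[str, str]] = None


    req = MiniRequest(
        request_id="req-0",
        trace_headers={
            "traceparent":
            "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
        })
    buf = msgspec.msgpack.encode(req)
    decoded = msgspec.msgpack.decode(buf, type=MiniRequest)
    assert decoded.trace_headers == req.trace_headers

    # Older messages encoded without the trailing field still decode,
    # because omit_defaults + array_like let it fall back to None.
    old_buf = msgspec.msgpack.encode(MiniRequest(request_id="req-1"))
    assert msgspec.msgpack.decode(old_buf, type=MiniRequest).trace_headers is None

Appending the new optional field at the end of the struct keeps the array_like wire format backward compatible, which is presumably why trace_headers is added after the existing fields rather than in the middle.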

vllm/v1/engine/output_processor.py
Lines changed: 118 additions & 51 deletions

@@ -15,6 +15,9 @@
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
                                    RequestStateStats)
+from vllm.config import ObservabilityConfig
+from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
+                          init_tracer)


 class RequestOutputCollector:
@@ -64,28 +67,27 @@ def get_nowait(self) -> Optional[RequestOutput]:

 @dataclass
 class OutputProcessorOutput:
-
     request_outputs: list[RequestOutput]
     reqs_to_abort: list[str]


 class RequestState:

     def __init__(
-        self,
-        request_id: str,
-        parent_req: Optional[ParentRequest],
-        request_index: int,
-        lora_name: Optional[str],
-        output_kind: RequestOutputKind,
-        prompt: Optional[str],
-        prompt_token_ids: list[int],
-        logprobs_processor: LogprobsProcessor,
-        detokenizer: IncrementalDetokenizer,
-        max_tokens_param: Optional[int],
-        arrival_time: float,
-        queue: Optional[RequestOutputCollector],
-        log_stats: bool,
+            self,
+            request_id: str,
+            parent_req: Optional[ParentRequest],
+            request_index: int,
+            lora_name: Optional[str],
+            output_kind: RequestOutputKind,
+            prompt: Optional[str],
+            prompt_token_ids: list[int],
+            logprobs_processor: LogprobsProcessor,
+            detokenizer: IncrementalDetokenizer,
+            max_tokens_param: Optional[int],
+            arrival_time: float,
+            queue: Optional[RequestOutputCollector],
+            log_stats: bool,
     ):
         self.request_id = request_id
         self.parent_req = parent_req
@@ -106,14 +108,14 @@ def __init__(

     @classmethod
     def from_new_request(
-        cls,
-        tokenizer: AnyTokenizer,
-        request: EngineCoreRequest,
-        prompt: Optional[str],
-        parent_req: Optional[ParentRequest],
-        request_index: int,
-        queue: Optional[RequestOutputCollector],
-        log_stats: bool,
+            cls,
+            tokenizer: AnyTokenizer,
+            request: EngineCoreRequest,
+            prompt: Optional[str],
+            parent_req: Optional[ParentRequest],
+            request_index: int,
+            queue: Optional[RequestOutputCollector],
+            log_stats: bool,
     ) -> "RequestState":
         if not request.sampling_params.detokenize:
             tokenizer = None
@@ -142,10 +144,10 @@ def from_new_request(
         )

     def make_request_output(
-        self,
-        new_token_ids: list[int],
-        finish_reason: Optional[FinishReason],
-        stop_reason: Union[int, str, None],
+            self,
+            new_token_ids: list[int],
+            finish_reason: Optional[FinishReason],
+            stop_reason: Union[int, str, None],
     ) -> Optional[RequestOutput]:

         finished = finish_reason is not None
@@ -170,10 +172,10 @@ def make_request_output(
         return self._new_request_output(request_id, outputs, finished)

     def _new_request_output(
-        self,
-        request_id: str,
-        outputs: list[CompletionOutput],
-        finished: bool,
+            self,
+            request_id: str,
+            outputs: list[CompletionOutput],
+            finished: bool,
     ) -> RequestOutput:

         if self.output_kind == RequestOutputKind.DELTA:
@@ -192,10 +194,10 @@ def _new_request_output(
         )

     def _new_completion_output(
-        self,
-        token_ids: list[int],
-        finish_reason: Optional[FinishReason],
-        stop_reason: Union[int, str, None],
+            self,
+            token_ids: list[int],
+            finish_reason: Optional[FinishReason],
+            stop_reason: Union[int, str, None],
     ) -> CompletionOutput:

         finished = finish_reason is not None
@@ -225,15 +227,26 @@ class OutputProcessor:
     """Process EngineCoreOutputs into RequestOutputs."""

     def __init__(
-        self,
-        tokenizer: TokenizerGroup,
-        log_stats: bool,
+            self,
+            tokenizer: TokenizerGroup,
+            log_stats: bool,
+            observability_config: Optional[ObservabilityConfig] = None
     ):
         self.log_stats = log_stats
         self.tokenizer = tokenizer
         self.request_states: dict[str, RequestState] = {}
         self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates()
+        self.observability_config = observability_config
+
+        self.tracer = None
+        if self.observability_config is not None and self.observability_config.otlp_traces_endpoint:
+            self.tracer = init_tracer(
+                "vllm.llm_engine",
+                self.observability_config.otlp_traces_endpoint)
+
+    def is_tracing_enabled(self) -> bool:
+        return self.tracer is not None

     def get_num_unfinished_requests(self):
         return len(self.request_states)
@@ -249,8 +262,8 @@ def propagate_error(self, e: Exception):
             state.queue.put(e)

     def abort_requests(
-        self,
-        request_ids: Iterable[str],
+            self,
+            request_ids: Iterable[str],
     ) -> list[str]:
         request_ids_to_abort = []
         for request_id in request_ids:
@@ -266,12 +279,12 @@ def abort_requests(
         return request_ids_to_abort

     def add_request(
-        self,
-        request: EngineCoreRequest,
-        prompt: Optional[str],
-        parent_req: Optional[ParentRequest] = None,
-        request_index: int = 0,
-        queue: Optional[RequestOutputCollector] = None,
+            self,
+            request: EngineCoreRequest,
+            prompt: Optional[str],
+            parent_req: Optional[ParentRequest] = None,
+            request_index: int = 0,
+            queue: Optional[RequestOutputCollector] = None,
     ) -> None:
         request_id = request.request_id
         if request_id in self.request_states:
@@ -291,10 +304,10 @@ def add_request(
             self.parent_requests[parent_req.request_id] = parent_req

     def process_outputs(
-        self,
-        engine_core_outputs: list[EngineCoreOutput],
-        engine_core_timestamp: Optional[float] = None,
-        iteration_stats: Optional[IterationStats] = None,
+            self,
+            engine_core_outputs: list[EngineCoreOutput],
+            engine_core_timestamp: Optional[float] = None,
+            iteration_stats: Optional[IterationStats] = None,
     ) -> OutputProcessorOutput:
         """
         Process the EngineCoreOutputs:
@@ -373,14 +386,68 @@ def process_outputs(
                 # Track per-request stats
                 self._update_stats_from_finished(req_state, finish_reason,
                                                  iteration_stats)
-
+                self.do_tracing(engine_core_output, req_state, iteration_stats)
         self.lora_states.update_iteration_stats(iteration_stats)

         return OutputProcessorOutput(
             request_outputs=request_outputs,
             reqs_to_abort=reqs_to_abort,
         )

+    def do_tracing(self, engine_core_output: EngineCoreOutput,
+                   req_state: RequestState,
+                   iteration_stats: Optional[IterationStats]):
+        if engine_core_output.finish_reason is None or iteration_stats is None:
+            return
+        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
+
+        trace_context = extract_trace_context(engine_core_output.trace_headers)
+        with tracer.start_as_current_span("llm_request",
+                                          kind=SpanKind.SERVER,
+                                          context=trace_context,
+                                          start_time=arrival_time_nano_seconds) as span:
+            metrics = req_state.stats
+            ttft = metrics.first_token_ts - metrics.arrival_time
+            e2e_time = time.time() - metrics.arrival_time
+            # Queued interval is from first QUEUED event to first SCHEDULED
+            queued_time = metrics.scheduled_ts - metrics.queued_ts
+
+            # Prefill interval is from first SCHEDULED to first NEW_TOKEN
+            # Any preemptions during prefill is included in the interval
+            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
+
+            # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
+            # Any preemptions during decode are included
+            decode_time = metrics.last_token_ts - metrics.first_token_ts
+
+            # Inference interval is from first SCHEDULED to last NEW_TOKEN
+            # Any preemptions during prefill or decode are included
+            inference_time = metrics.last_token_ts - metrics.scheduled_ts
+            span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
+                               self.tokenizer.tokenizer_id)
+            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
+                               req_state.request_id)
+            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
+                               req_state.max_tokens_param)
+            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
+                               len(req_state.prompt_token_ids))
+            span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
+                               metrics.num_generation_tokens)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
+                               metrics.queued_ts - metrics.arrival_time)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
+                               ttft)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E,
+                               e2e_time)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
+                               queued_time)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
+                               prefill_time)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
+                               decode_time)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
+                               inference_time)
+
     def _update_stats_from_output(self, req_state: RequestState,
                                   engine_core_output: EngineCoreOutput,
                                   engine_core_timestamp: Optional[float],
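
do_tracing emits one llm_request span per finished request, and the latency attributes are plain differences over the per-request timestamps recorded in the request's stats. A small worked sketch of that arithmetic with made-up timestamps (the variable names mirror the code above, but nothing here imports or calls vLLM):

    # Made-up timestamps in seconds; only the arithmetic mirrors do_tracing.
    arrival_time = 100.00    # request received by the frontend
    queued_ts = 100.01       # first QUEUED event
    scheduled_ts = 100.05    # first SCHEDULED event
    first_token_ts = 100.35  # first NEW_TOKEN
    last_token_ts = 101.20   # last NEW_TOKEN

    ttft = first_token_ts - arrival_time           # 0.35 s to first token
    queued_time = scheduled_ts - queued_ts         # 0.04 s waiting in the queue
    prefill_time = first_token_ts - scheduled_ts   # 0.30 s, preemptions included
    decode_time = last_token_ts - first_token_ts   # 0.85 s, preemptions included
    inference_time = last_token_ts - scheduled_ts  # 1.15 s = prefill + decode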

vllm/v1/engine/processor.py
Lines changed: 0 additions & 2 deletions

@@ -215,8 +215,6 @@ def process_inputs(
         self._validate_params(params, lora_request)
         if priority != 0:
             raise ValueError("V1 does not support priority yet.")
-        if trace_headers is not None:
-            raise ValueError("V1 does not support tracing yet.")
         if prompt_adapter_request is not None:
             raise ValueError("V1 does not support prompt_adapter_request.")

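
With the "V1 does not support tracing yet" guard removed, trace headers supplied by the caller now flow through process_inputs into EngineCoreRequest. A hedged client-side sketch of producing such headers with the W3C trace-context propagator, assuming the server has tracing enabled via ObservabilityConfig.otlp_traces_endpoint as wired up in output_processor.py above; the URL, model name, and payload are assumptions for the example, not part of this commit, and opentelemetry-sdk plus requests are assumed to be installed:

    import requests
    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.trace.propagation.tracecontext import (
        TraceContextTextMapPropagator)

    trace.set_tracer_provider(TracerProvider())
    tracer = trace.get_tracer("client")

    with tracer.start_as_current_span("client_request"):
        headers: dict[str, str] = {}
        TraceContextTextMapPropagator().inject(headers)  # adds "traceparent"
        requests.post(
            "http://localhost:8000/v1/completions",  # assumed local server
            json={"model": "my-model", "prompt": "Hello", "max_tokens": 8},
            headers=headers,
        )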
