
Commit e0be395

agtsumitd2 authored and committed
[Frontend] OpenAI server: propagate usage accounting to FastAPI middleware layer (vllm-project#8672)
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent ea8a389 · commit e0be395
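
The change stores a RequestResponseMetadata object on raw_request.state before generation starts and fills in its final_usage_info once token counts are known, so any FastAPI middleware running on the OpenAI server can read the usage after a request completes. Below is a minimal sketch of such a consumer; the middleware function and the "vllm_api_usage" logger name are illustrative and not part of this commit, and for streaming responses final_usage_info is only populated after the SSE stream has been fully consumed, so it may still be unset when call_next returns.

    import logging

    from fastapi import FastAPI, Request

    logger = logging.getLogger("vllm_api_usage")  # illustrative name

    app = FastAPI()  # stand-in for the vLLM OpenAI API server app


    @app.middleware("http")
    async def log_token_usage(request: Request, call_next):
        response = await call_next(request)

        # The serving layer attaches RequestResponseMetadata to
        # request.state.request_metadata; final_usage_info is filled in once
        # aggregate usage is known (after the stream ends for streaming calls).
        metadata = getattr(request.state, "request_metadata", None)
        if metadata is not None and metadata.final_usage_info is not None:
            usage = metadata.final_usage_info
            logger.info("request %s: prompt=%s completion=%s total=%s",
                        metadata.request_id, usage.prompt_tokens,
                        usage.completion_tokens, usage.total_tokens)
        return response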

File tree: 3 files changed (+57, -11 lines)


vllm/entrypoints/openai/protocol.py (5 additions, 0 deletions)

@@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel):
     completion_tokens: Optional[int] = 0
 
 
+class RequestResponseMetadata(BaseModel):
+    request_id: str
+    final_usage_info: Optional[UsageInfo] = None
+
+
 class JsonSchemaResponseFormat(OpenAIBaseModel):
     name: str
     description: Optional[str] = None
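
RequestResponseMetadata is a plain pydantic model, so it can be constructed and inspected independently of the server. A minimal sketch under that assumption, with made-up values, using pydantic v2's model_dump_json:

    from vllm.entrypoints.openai.protocol import (RequestResponseMetadata,
                                                  UsageInfo)

    # Build the metadata object the way the serving layer does, then attach
    # usage once the token counts are known.
    metadata = RequestResponseMetadata(request_id="chat-1234")
    metadata.final_usage_info = UsageInfo(prompt_tokens=12,
                                          completion_tokens=34,
                                          total_tokens=46)

    # Serializes request_id plus the nested usage block as JSON.
    print(metadata.model_dump_json())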

vllm/entrypoints/openai/serving_chat.py (23 additions, 3 deletions)

@@ -22,7 +22,8 @@
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
-    DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
+    DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
+    ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     LoRAModulePath,
                                                     OpenAIServing,
@@ -175,6 +176,11 @@ async def create_chat_completion(
                 "--enable-auto-tool-choice and --tool-call-parser to be set")
 
         request_id = f"chat-{random_uuid()}"
+
+        request_metadata = RequestResponseMetadata(request_id=request_id)
+        if raw_request:
+            raw_request.state.request_metadata = request_metadata
+
         try:
             guided_decode_logits_processor = (
                 await self._guided_decode_logits_processor(request, tokenizer))
@@ -241,11 +247,13 @@ async def create_chat_completion(
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id, conversation, tokenizer)
+                request, result_generator, request_id, conversation, tokenizer,
+                request_metadata)
 
         try:
             return await self.chat_completion_full_generator(
-                request, result_generator, request_id, conversation, tokenizer)
+                request, result_generator, request_id, conversation, tokenizer,
+                request_metadata)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
@@ -262,6 +270,7 @@ async def chat_completion_stream_generator(
         request_id: str,
         conversation: List[ConversationMessage],
         tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
     ) -> AsyncGenerator[str, None]:
         model_name = self.base_model_paths[0].name
         created_time = int(time.time())
@@ -580,6 +589,13 @@ async def chat_completion_stream_generator(
                     exclude_unset=True, exclude_none=True))
                 yield f"data: {final_usage_data}\n\n"
 
+            # report to FastAPI middleware aggregate usage across all choices
+            num_completion_tokens = sum(previous_num_tokens)
+            request_metadata.final_usage_info = UsageInfo(
+                prompt_tokens=num_prompt_tokens,
+                completion_tokens=num_completion_tokens,
+                total_tokens=num_prompt_tokens + num_completion_tokens)
+
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             logger.error("error in chat completion stream generator: %s", e)
@@ -595,6 +611,7 @@ async def chat_completion_full_generator(
         request_id: str,
         conversation: List[ConversationMessage],
         tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
     ) -> Union[ErrorResponse, ChatCompletionResponse]:
 
         model_name = self.base_model_paths[0].name
@@ -714,6 +731,9 @@ async def chat_completion_full_generator(
             completion_tokens=num_generated_tokens,
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )
+
+        request_metadata.final_usage_info = usage
+
         response = ChatCompletionResponse(
             id=request_id,
             created=created_time,
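
In the streaming chat path, usage is tracked per choice in previous_num_tokens and only reduced to a single UsageInfo once the stream ends. A small self-contained sketch of that aggregation, with invented token counts (the values are illustrative, not from the commit):

    from vllm.entrypoints.openai.protocol import UsageInfo

    # Hypothetical request: one 20-token prompt and n=3 choices that
    # generated 5, 7 and 9 tokens respectively, mirroring the
    # previous_num_tokens list kept by chat_completion_stream_generator.
    num_prompt_tokens = 20
    previous_num_tokens = [5, 7, 9]

    num_completion_tokens = sum(previous_num_tokens)  # 21
    final_usage_info = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_completion_tokens,
        total_tokens=num_prompt_tokens + num_completion_tokens)  # total 41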

vllm/entrypoints/openai/serving_completion.py (29 additions, 8 deletions)

@@ -18,7 +18,9 @@
                                               CompletionResponseChoice,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
-                                              ErrorResponse, UsageInfo)
+                                              ErrorResponse,
+                                              RequestResponseMetadata,
+                                              UsageInfo)
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     LoRAModulePath,
@@ -94,6 +96,10 @@ async def create_completion(
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())
 
+        request_metadata = RequestResponseMetadata(request_id=request_id)
+        if raw_request:
+            raw_request.state.request_metadata = request_metadata
+
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
@@ -165,13 +171,15 @@ async def create_completion(
 
         # Streaming response
         if stream:
-            return self.completion_stream_generator(request,
-                                                    result_generator,
-                                                    request_id,
-                                                    created_time,
-                                                    model_name,
-                                                    num_prompts=len(prompts),
-                                                    tokenizer=tokenizer)
+            return self.completion_stream_generator(
+                request,
+                result_generator,
+                request_id,
+                created_time,
+                model_name,
+                num_prompts=len(prompts),
+                tokenizer=tokenizer,
+                request_metadata=request_metadata)
 
         # Non-streaming response
         final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
@@ -198,6 +206,7 @@ async def create_completion(
                 created_time,
                 model_name,
                 tokenizer,
+                request_metadata,
             )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
@@ -227,6 +236,7 @@ async def completion_stream_generator(
         model_name: str,
         num_prompts: int,
         tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
     ) -> AsyncGenerator[str, None]:
         num_choices = 1 if request.n is None else request.n
         previous_text_lens = [0] * num_choices * num_prompts
@@ -346,6 +356,14 @@ async def completion_stream_generator(
                     exclude_unset=False, exclude_none=True))
                 yield f"data: {final_usage_data}\n\n"
 
+            # report to FastAPI middleware aggregate usage across all choices
+            total_prompt_tokens = sum(num_prompt_tokens)
+            total_completion_tokens = sum(previous_num_tokens)
+            request_metadata.final_usage_info = UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens)
+
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             data = self.create_streaming_error_response(str(e))
@@ -360,6 +378,7 @@ def request_output_to_completion_response(
         created_time: int,
         model_name: str,
         tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
     ) -> CompletionResponse:
         choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
@@ -433,6 +452,8 @@ def request_output_to_completion_response(
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )
 
+        request_metadata.final_usage_info = usage
+
         return CompletionResponse(
             id=request_id,
             created=created_time,
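
The completions path differs from chat in that /v1/completions accepts a batch of prompts, so the stream generator also keeps a per-prompt list of prompt token counts and sums it before building the final UsageInfo. A small sketch with invented numbers, mirroring the aggregation above (values are illustrative only):

    from vllm.entrypoints.openai.protocol import (RequestResponseMetadata,
                                                  UsageInfo)

    # Hypothetical batch: two prompts of 10 and 15 tokens, n=2 choices per
    # prompt, so four streamed choices generating 3, 4, 5 and 6 tokens.
    num_prompt_tokens = [10, 15]        # one entry per prompt
    previous_num_tokens = [3, 4, 5, 6]  # one entry per (prompt, choice) pair

    request_metadata = RequestResponseMetadata(request_id="cmpl-1234")

    total_prompt_tokens = sum(num_prompt_tokens)        # 25
    total_completion_tokens = sum(previous_num_tokens)  # 18
    request_metadata.final_usage_info = UsageInfo(
        prompt_tokens=total_prompt_tokens,
        completion_tokens=total_completion_tokens,
        total_tokens=total_prompt_tokens + total_completion_tokens)  # 43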
