Commit fd2ca45

fix: correct issue with handling lock during streaming
move locking for streaming into the get_event_publisher call so the lock is acquired and released in the correct task for the streaming response
1 parent 2bc1d97 commit fd2ca45
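
For context, a minimal standalone sketch of the pattern this commit moves to, using only the standard library and hypothetical names (acquire_model, event_publisher, the queue-based consumer), not the project's API: the publisher coroutine runs as its own task, so it opens its own AsyncExitStack, and the lock guarding the model is acquired and released by the same task that streams the chunks.

import asyncio
import contextlib

model_lock = asyncio.Lock()  # stands in for the server's model lock

@contextlib.asynccontextmanager
async def acquire_model():
    # Hypothetical stand-in for the lock-guarded model dependency:
    # the lock is held for the full duration of the stream.
    async with model_lock:
        yield "model-handle"

async def event_publisher(send: asyncio.Queue):
    # Mirrors the commit: the exit stack is created inside the publisher
    # task, so the resource is acquired *and* released in this task.
    exit_stack = contextlib.AsyncExitStack()
    model = await exit_stack.enter_async_context(acquire_model())
    try:
        for chunk in ("a", "b", "c"):  # stands in for the token iterator
            await send.put(f"{model}:{chunk}")
    finally:
        await exit_stack.aclose()  # released by the same task that acquired it
        await send.put(None)

async def main():
    queue: asyncio.Queue = asyncio.Queue()
    # The publisher runs as its own task, analogous to sse-starlette's
    # data_sender_callable for an EventSourceResponse.
    asyncio.create_task(event_publisher(queue))
    while (item := await queue.get()) is not None:
        print(item)

asyncio.run(main())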

File tree

1 file changed: llama_cpp/server/app.py (+91, -102 lines)

llama_cpp/server/app.py

Lines changed: 91 additions & 102 deletions
@@ -154,19 +154,56 @@ def create_app(
 
     return app
 
+def prepare_request_resources(
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    llama_proxy: LlamaProxy,
+    body_model: str,
+    kwargs) -> llama_cpp.Llama:
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
+    llama = llama_proxy(body_model)
+    if body.logit_bias is not None:
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
+        )
+
+    if body.grammar is not None:
+        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+    return llama
+
 
 async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
-    iterator: Iterator[typing.Any],
-    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    body_model: str,
+    llama_call,
+    kwargs,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
         server_settings.interrupt_requests if server_settings else False
     )
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
     async with inner_send_chan:
         try:
+            iterator = await run_in_threadpool(llama_call, llama, **kwargs)
             async for chunk in iterate_in_threadpool(iterator):
                 await inner_send_chan.send(dict(data=json.dumps(chunk)))
                 if await request.is_disconnected():
@@ -181,8 +218,7 @@ async def get_event_publisher(
                 print(f"Disconnected from client (via refresh/close) {request.client}")
                 raise e
         finally:
-            if on_complete:
-                await on_complete()
+            await exit_stack.aclose()
 
 
 def _logit_bias_tokens_to_input_ids(
@@ -267,18 +303,11 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
 
-    llama = llama_proxy(
+    body_model = (
         body.model
         if request.url.path != "/v1/engines/copilot-codex/completions"
         else "copilot-codex"
@@ -293,60 +322,42 @@ async def create_completion(
     }
     kwargs = body.model_dump(exclude=exclude)
 
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.CreateCompletionResponse,
-            Iterator[llama_cpp.CreateCompletionStreamResponse],
-        ] = await run_in_threadpool(llama, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
-            yield first_response
-            yield from iterator_or_completion
-
+    # handle streaming request
+    if kwargs.get("stream", False):
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.__call__,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
+
+    # handle regular request
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+    if await request.is_disconnected():
+        print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}")
         await exit_stack.aclose()
-        return iterator_or_completion
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Client closed request",
+        )
+
+    try:
+        completion: llama_cpp.CreateCompletionResponse = await run_in_threadpool(llama, **kwargs)
+    finally:
+        await exit_stack.aclose()
+    return completion
 
 
 @router.post(
@@ -474,74 +485,52 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
+
+    body_model = body.model
     exclude = {
         "n",
         "logit_bias_type",
         "user",
         "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)
-    llama = llama_proxy(body.model)
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
-        ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
-            yield first_response
-            yield from iterator_or_completion
 
+    # handle streaming request
+    if kwargs.get("stream", False):
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.create_chat_completion,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
+
+    # handle regular request
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+    if await request.is_disconnected():
+        print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}")
+        await exit_stack.aclose()
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Client closed request",
+        )
+
+    try:
+        completion: llama_cpp.ChatCompletion = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+    finally:
         await exit_stack.aclose()
-        return iterator_or_completion
+    return completion
 
 
 @router.get(
