
Commit 6f9cfc3

fix: simplify exit stack management for create_chat_completion and create_completion

1 parent fd2ca45

1 file changed: llama_cpp/server/app.py (+41 additions, -53 deletions)
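The commit replaces manually managed AsyncExitStack objects, which had to be closed with an explicit await exit_stack.aclose() on every return, raise, and disconnect path, with async with contextlib.AsyncExitStack() blocks that close the stack on any exit. Below is a minimal standalone sketch of the before/after shape; acquire_resource() and handler_before/handler_after are hypothetical stand-ins, not names from the server code.

import contextlib


@contextlib.asynccontextmanager
async def acquire_resource():
    # Hypothetical stand-in for the server's per-request dependency.
    resource = {"ok": True}
    try:
        yield resource
    finally:
        pass  # release the resource here


async def handler_before(disconnected: bool) -> str:
    # Old shape: every early exit must remember to close the stack by hand.
    exit_stack = contextlib.AsyncExitStack()
    resource = await exit_stack.enter_async_context(acquire_resource())
    if disconnected:
        await exit_stack.aclose()
        raise RuntimeError("client closed request")
    try:
        return f"used {resource}"
    finally:
        await exit_stack.aclose()


async def handler_after(disconnected: bool) -> str:
    # New shape: leaving the async with block closes the stack on
    # return, raise, or cancellation, with no manual aclose() calls.
    async with contextlib.AsyncExitStack() as exit_stack:
        resource = await exit_stack.enter_async_context(acquire_resource())
        if disconnected:
            raise RuntimeError("client closed request")
        return f"used {resource}"

The diff below applies this shape to the streaming event publisher and to the two non-streaming handlers.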
@@ -198,27 +198,25 @@ async def get_event_publisher(
     interrupt_requests = (
         server_settings.interrupt_requests if server_settings else False
     )
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
-    async with inner_send_chan:
-        try:
-            iterator = await run_in_threadpool(llama_call, llama, **kwargs)
-            async for chunk in iterate_in_threadpool(iterator):
-                await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                if await request.is_disconnected():
-                    raise anyio.get_cancelled_exc_class()()
-                if interrupt_requests and llama_outer_lock.locked():
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                    raise anyio.get_cancelled_exc_class()()
-            await inner_send_chan.send(dict(data="[DONE]"))
-        except anyio.get_cancelled_exc_class() as e:
-            print("disconnected")
-            with anyio.move_on_after(1, shield=True):
-                print(f"Disconnected from client (via refresh/close) {request.client}")
-                raise e
-        finally:
-            await exit_stack.aclose()
+    async with contextlib.AsyncExitStack() as exit_stack:
+        llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+        async with inner_send_chan:
+            try:
+                iterator = await run_in_threadpool(llama_call, llama, **kwargs)
+                async for chunk in iterate_in_threadpool(iterator):
+                    await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                    if await request.is_disconnected():
+                        raise anyio.get_cancelled_exc_class()()
+                    if interrupt_requests and llama_outer_lock.locked():
+                        await inner_send_chan.send(dict(data="[DONE]"))
+                        raise anyio.get_cancelled_exc_class()()
+                await inner_send_chan.send(dict(data="[DONE]"))
+            except anyio.get_cancelled_exc_class() as e:
+                print("disconnected")
+                with anyio.move_on_after(1, shield=True):
+                    print(f"Disconnected from client (via refresh/close) {request.client}")
+                    raise e


 def _logit_bias_tokens_to_input_ids(
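Both the old and new code build the per-request LlamaProxy by wrapping the get_llama_proxy dependency with contextlib.asynccontextmanager and entering the result on the exit stack. The sketch below shows that wrapping pattern with a hypothetical yield-style dependency standing in for get_llama_proxy, assuming (as the call above suggests) that it is an async generator function.

import contextlib


async def get_resource_dependency():
    # Hypothetical yield-style dependency, analogous to get_llama_proxy.
    # FastAPI would normally drive this generator itself; here it is wrapped
    # so it can be entered on an AsyncExitStack inside a handler.
    resource = {"name": "shared resource"}
    try:
        yield resource
    finally:
        pass  # cleanup would go here


async def handler() -> str:
    async with contextlib.AsyncExitStack() as exit_stack:
        # asynccontextmanager() turns the async generator function into a
        # context-manager factory; calling it yields one enterable instance.
        resource = await exit_stack.enter_async_context(
            contextlib.asynccontextmanager(get_resource_dependency)()
        )
        return resource["name"]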
@@ -341,23 +339,18 @@ async def create_completion(
     )

     # handle regular request
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+    async with contextlib.AsyncExitStack() as exit_stack:
+        llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)

-    if await request.is_disconnected():
-        print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}")
-        await exit_stack.aclose()
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="Client closed request",
-        )
+        if await request.is_disconnected():
+            print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Client closed request",
+            )

-    try:
-        completion: llama_cpp.CreateCompletionResponse = await run_in_threadpool(llama, **kwargs)
-    finally:
-        await exit_stack.aclose()
-    return completion
+        return await run_in_threadpool(llama, **kwargs)


 @router.post(
@@ -514,23 +507,18 @@ async def create_chat_completion(
     )

     # handle regular request
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
-
-    if await request.is_disconnected():
-        print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}")
-        await exit_stack.aclose()
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="Client closed request",
-        )
-
-    try:
-        completion: llama_cpp.ChatCompletion = await run_in_threadpool(llama.create_chat_completion, **kwargs)
-    finally:
-        await exit_stack.aclose()
-    return completion
+    with contextlib.AsyncExitStack() as exit_stack:
+        llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+        if await request.is_disconnected():
+            print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Client closed request",
+            )
+
+        return await run_in_threadpool(llama.create_chat_completion, **kwargs)


 @router.get(
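The streaming handler's cancellation handling, unchanged by this commit apart from indentation, relies on anyio.move_on_after(1, shield=True) so that disconnect logging can finish even while the surrounding task is being cancelled. Below is a standalone sketch of that pattern; send_chunk and request_is_disconnected are hypothetical callables, not names from the server code.

import anyio


async def stream_chunks(send_chunk, request_is_disconnected):
    # Mirror the shape of get_event_publisher: re-raise the cancellation so
    # the server tears the stream down, but run the logging inside a shielded
    # scope (capped at 1 second) so it is not cancelled with the task.
    try:
        for i in range(3):
            await send_chunk(i)
            if await request_is_disconnected():
                raise anyio.get_cancelled_exc_class()()
    except anyio.get_cancelled_exc_class():
        with anyio.move_on_after(1, shield=True):
            print("Disconnected from client (via refresh/close)")
        raise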
