From fd2ca45f859631a544bf543eb1be266eb2216356 Mon Sep 17 00:00:00 2001
From: Graeme Power
Date: Mon, 23 Dec 2024 13:06:53 +0000
Subject: [PATCH 1/4] fix: correct issue with handling lock during streaming

move locking for streaming into get_event_publisher call so it is locked and unlocked in the correct task for the streaming response
---
 llama_cpp/server/app.py | 193 +++++++++++++++++++---------------------
 1 file changed, 91 insertions(+), 102 deletions(-)

diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index b6db453b8..d4ab23ac0 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -154,19 +154,56 @@ def create_app(
     return app

+def prepare_request_resources(
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    llama_proxy: LlamaProxy,
+    body_model: str,
+    kwargs) -> llama_cpp.Llama:
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
+    llama = llama_proxy(body_model)
+    if body.logit_bias is not None:
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
+        )
+
+    if body.grammar is not None:
+        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+    return llama
+
 async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
-    iterator: Iterator[typing.Any],
-    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    body_model: str,
+    llama_call,
+    kwargs,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
         server_settings.interrupt_requests if server_settings else False
     )
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
     async with inner_send_chan:
         try:
+            iterator = await run_in_threadpool(llama_call, llama, **kwargs)
             async for chunk in iterate_in_threadpool(iterator):
                 await inner_send_chan.send(dict(data=json.dumps(chunk)))
                 if await request.is_disconnected():
@@ -181,8 +218,7 @@ async def get_event_publisher(
             print(f"Disconnected from client (via refresh/close) {request.client}")
             raise e
         finally:
-            if on_complete:
-                await on_complete()
+            await exit_stack.aclose()


 def _logit_bias_tokens_to_input_ids(
@@ -267,18 +303,11 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

-    llama = llama_proxy(
+    body_model = (
         body.model
         if request.url.path != "/v1/engines/copilot-codex/completions"
         else "copilot-codex"
@@ -293,44 +322,8 @@ async
def create_completion( } kwargs = body.model_dump(exclude=exclude) - if body.logit_bias is not None: - kwargs["logit_bias"] = ( - _logit_bias_tokens_to_input_ids(llama, body.logit_bias) - if body.logit_bias_type == "tokens" - else body.logit_bias - ) - - if body.grammar is not None: - kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar) - - if body.min_tokens > 0: - _min_tokens_logits_processor = llama_cpp.LogitsProcessorList( - [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())] - ) - if "logits_processor" not in kwargs: - kwargs["logits_processor"] = _min_tokens_logits_processor - else: - kwargs["logits_processor"].extend(_min_tokens_logits_processor) - - try: - iterator_or_completion: Union[ - llama_cpp.CreateCompletionResponse, - Iterator[llama_cpp.CreateCompletionStreamResponse], - ] = await run_in_threadpool(llama, **kwargs) - except Exception as err: - await exit_stack.aclose() - raise err - - if isinstance(iterator_or_completion, Iterator): - # EAFP: It's easier to ask for forgiveness than permission - first_response = await run_in_threadpool(next, iterator_or_completion) - - # If no exception was raised from first_response, we can assume that - # the iterator is valid and we can use it to stream the response. - def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]: - yield first_response - yield from iterator_or_completion - + # handle streaming request + if kwargs.get("stream", False): send_chan, recv_chan = anyio.create_memory_object_stream(10) return EventSourceResponse( recv_chan, @@ -338,15 +331,33 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]: get_event_publisher, request=request, inner_send_chan=send_chan, - iterator=iterator(), - on_complete=exit_stack.aclose, + body=body, + body_model=body_model, + llama_call=llama_cpp.Llama.__call__, + kwargs=kwargs, ), sep="\n", ping_message_factory=_ping_message_factory, ) - else: + + # handle regular request + exit_stack = contextlib.AsyncExitStack() + llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)()) + llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) + + if await request.is_disconnected(): + print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}") await exit_stack.aclose() - return iterator_or_completion + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Client closed request", + ) + + try: + completion: llama_cpp.CreateCompletionResponse = await run_in_threadpool(llama, **kwargs) + finally: + await exit_stack.aclose() + return completion @router.post( @@ -474,13 +485,8 @@ async def create_chat_completion( # where the dependency is cleaned up before a StreamingResponse # is complete. 
# https://github.com/tiangolo/fastapi/issues/11143 - exit_stack = contextlib.AsyncExitStack() - llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)()) - if llama_proxy is None: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Service is not available", - ) + + body_model = body.model exclude = { "n", "logit_bias_type", @@ -488,44 +494,9 @@ async def create_chat_completion( "min_tokens", } kwargs = body.model_dump(exclude=exclude) - llama = llama_proxy(body.model) - if body.logit_bias is not None: - kwargs["logit_bias"] = ( - _logit_bias_tokens_to_input_ids(llama, body.logit_bias) - if body.logit_bias_type == "tokens" - else body.logit_bias - ) - - if body.grammar is not None: - kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar) - - if body.min_tokens > 0: - _min_tokens_logits_processor = llama_cpp.LogitsProcessorList( - [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())] - ) - if "logits_processor" not in kwargs: - kwargs["logits_processor"] = _min_tokens_logits_processor - else: - kwargs["logits_processor"].extend(_min_tokens_logits_processor) - - try: - iterator_or_completion: Union[ - llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk] - ] = await run_in_threadpool(llama.create_chat_completion, **kwargs) - except Exception as err: - await exit_stack.aclose() - raise err - - if isinstance(iterator_or_completion, Iterator): - # EAFP: It's easier to ask for forgiveness than permission - first_response = await run_in_threadpool(next, iterator_or_completion) - - # If no exception was raised from first_response, we can assume that - # the iterator is valid and we can use it to stream the response. - def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]: - yield first_response - yield from iterator_or_completion + # handle streaming request + if kwargs.get("stream", False): send_chan, recv_chan = anyio.create_memory_object_stream(10) return EventSourceResponse( recv_chan, @@ -533,15 +504,33 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]: get_event_publisher, request=request, inner_send_chan=send_chan, - iterator=iterator(), - on_complete=exit_stack.aclose, + body=body, + body_model=body_model, + llama_call=llama_cpp.Llama.create_chat_completion, + kwargs=kwargs, ), sep="\n", ping_message_factory=_ping_message_factory, ) - else: + + # handle regular request + exit_stack = contextlib.AsyncExitStack() + llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)()) + llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) + + if await request.is_disconnected(): + print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}") + await exit_stack.aclose() + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Client closed request", + ) + + try: + completion: llama_cpp.ChatCompletion = await run_in_threadpool(llama.create_chat_completion, **kwargs) + finally: await exit_stack.aclose() - return iterator_or_completion + return completion @router.get( From 6f9cfc31d9ef6318edd1619b24153f5d51aaa8cc Mon Sep 17 00:00:00 2001 From: Graeme Power Date: Mon, 23 Dec 2024 13:19:07 +0000 Subject: [PATCH 2/4] fix: simplify exit stack management for create_chat_completion and create_completion --- llama_cpp/server/app.py | 94 ++++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 53 deletions(-) diff --git 
a/llama_cpp/server/app.py b/llama_cpp/server/app.py index d4ab23ac0..6b0816a4e 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -198,27 +198,25 @@ async def get_event_publisher( interrupt_requests = ( server_settings.interrupt_requests if server_settings else False ) - exit_stack = contextlib.AsyncExitStack() - llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)()) - llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) - async with inner_send_chan: - try: - iterator = await run_in_threadpool(llama_call, llama, **kwargs) - async for chunk in iterate_in_threadpool(iterator): - await inner_send_chan.send(dict(data=json.dumps(chunk))) - if await request.is_disconnected(): - raise anyio.get_cancelled_exc_class()() - if interrupt_requests and llama_outer_lock.locked(): - await inner_send_chan.send(dict(data="[DONE]")) - raise anyio.get_cancelled_exc_class()() - await inner_send_chan.send(dict(data="[DONE]")) - except anyio.get_cancelled_exc_class() as e: - print("disconnected") - with anyio.move_on_after(1, shield=True): - print(f"Disconnected from client (via refresh/close) {request.client}") - raise e - finally: - await exit_stack.aclose() + async with contextlib.AsyncExitStack() as exit_stack: + llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)()) + llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) + async with inner_send_chan: + try: + iterator = await run_in_threadpool(llama_call, llama, **kwargs) + async for chunk in iterate_in_threadpool(iterator): + await inner_send_chan.send(dict(data=json.dumps(chunk))) + if await request.is_disconnected(): + raise anyio.get_cancelled_exc_class()() + if interrupt_requests and llama_outer_lock.locked(): + await inner_send_chan.send(dict(data="[DONE]")) + raise anyio.get_cancelled_exc_class()() + await inner_send_chan.send(dict(data="[DONE]")) + except anyio.get_cancelled_exc_class() as e: + print("disconnected") + with anyio.move_on_after(1, shield=True): + print(f"Disconnected from client (via refresh/close) {request.client}") + raise e def _logit_bias_tokens_to_input_ids( @@ -341,23 +339,18 @@ async def create_completion( ) # handle regular request - exit_stack = contextlib.AsyncExitStack() - llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)()) - llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) + async with contextlib.AsyncExitStack() as exit_stack: + llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)()) + llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) - if await request.is_disconnected(): - print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}") - await exit_stack.aclose() - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Client closed request", - ) + if await request.is_disconnected(): + print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Client closed request", + ) - try: - completion: llama_cpp.CreateCompletionResponse = await run_in_threadpool(llama, **kwargs) - finally: - await exit_stack.aclose() - return completion + return await run_in_threadpool(llama, **kwargs) @router.post( @@ -514,23 +507,18 @@ async def 
create_chat_completion( ) # handle regular request - exit_stack = contextlib.AsyncExitStack() - llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)()) - llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) - - if await request.is_disconnected(): - print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}") - await exit_stack.aclose() - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Client closed request", - ) - - try: - completion: llama_cpp.ChatCompletion = await run_in_threadpool(llama.create_chat_completion, **kwargs) - finally: - await exit_stack.aclose() - return completion + with contextlib.AsyncExitStack() as exit_stack: + llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)()) + llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) + + if await request.is_disconnected(): + print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Client closed request", + ) + + return await run_in_threadpool(llama.create_chat_completion, **kwargs) @router.get( From f4fb0ce9a7685aa09fda5e2fc7691c1ec1459e69 Mon Sep 17 00:00:00 2001 From: Graeme Power Date: Mon, 23 Dec 2024 13:38:10 +0000 Subject: [PATCH 3/4] fix: correct missing `async with` and format code --- llama_cpp/server/app.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 6b0816a4e..4ef906782 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -7,7 +7,7 @@ from anyio import Lock from functools import partial -from typing import Iterator, List, Optional, Union, Dict +from typing import List, Optional, Union, Dict import llama_cpp @@ -154,11 +154,13 @@ def create_app( return app + def prepare_request_resources( body: CreateCompletionRequest | CreateChatCompletionRequest, llama_proxy: LlamaProxy, body_model: str, - kwargs) -> llama_cpp.Llama: + kwargs, +) -> llama_cpp.Llama: if llama_proxy is None: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, @@ -199,7 +201,9 @@ async def get_event_publisher( server_settings.interrupt_requests if server_settings else False ) async with contextlib.AsyncExitStack() as exit_stack: - llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)()) + llama_proxy: LlamaProxy = await exit_stack.enter_async_context( + contextlib.asynccontextmanager(get_llama_proxy)() + ) llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) async with inner_send_chan: try: @@ -215,7 +219,9 @@ async def get_event_publisher( except anyio.get_cancelled_exc_class() as e: print("disconnected") with anyio.move_on_after(1, shield=True): - print(f"Disconnected from client (via refresh/close) {request.client}") + print( + f"Disconnected from client (via refresh/close) {request.client}" + ) raise e @@ -340,11 +346,15 @@ async def create_completion( # handle regular request async with contextlib.AsyncExitStack() as exit_stack: - llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)()) + llama_proxy: LlamaProxy = await exit_stack.enter_async_context( + contextlib.asynccontextmanager(get_llama_proxy)() + ) llama = prepare_request_resources(body, llama_proxy, 
body_model, kwargs) if await request.is_disconnected(): - print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}") + print( + f"Disconnected from client (via refresh/close) before llm invoked {request.client}" + ) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Client closed request", @@ -507,12 +517,16 @@ async def create_chat_completion( ) # handle regular request - with contextlib.AsyncExitStack() as exit_stack: - llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)()) + async with contextlib.AsyncExitStack() as exit_stack: + llama_proxy: LlamaProxy = await exit_stack.enter_async_context( + contextlib.asynccontextmanager(get_llama_proxy)() + ) llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) if await request.is_disconnected(): - print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}") + print( + f"Disconnected from client (via refresh/close) before llm invoked {request.client}" + ) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Client closed request", From 1ee719acf7c44682d1e1ed86684e0dbf2d146d9f Mon Sep 17 00:00:00 2001 From: Graeme Power Date: Tue, 24 Dec 2024 12:22:22 +0000 Subject: [PATCH 4/4] fix: remove unnecessary explicit use of AsyncExitStack fix: correct type hints for body_model --- llama_cpp/server/app.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 4ef906782..5120f2416 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -158,7 +158,7 @@ def create_app( def prepare_request_resources( body: CreateCompletionRequest | CreateChatCompletionRequest, llama_proxy: LlamaProxy, - body_model: str, + body_model: str | None, kwargs, ) -> llama_cpp.Llama: if llama_proxy is None: @@ -192,7 +192,7 @@ async def get_event_publisher( request: Request, inner_send_chan: MemoryObjectSendStream[typing.Any], body: CreateCompletionRequest | CreateChatCompletionRequest, - body_model: str, + body_model: str | None, llama_call, kwargs, ): @@ -200,10 +200,7 @@ async def get_event_publisher( interrupt_requests = ( server_settings.interrupt_requests if server_settings else False ) - async with contextlib.AsyncExitStack() as exit_stack: - llama_proxy: LlamaProxy = await exit_stack.enter_async_context( - contextlib.asynccontextmanager(get_llama_proxy)() - ) + async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy: llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) async with inner_send_chan: try: @@ -345,10 +342,7 @@ async def create_completion( ) # handle regular request - async with contextlib.AsyncExitStack() as exit_stack: - llama_proxy: LlamaProxy = await exit_stack.enter_async_context( - contextlib.asynccontextmanager(get_llama_proxy)() - ) + async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy: llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) if await request.is_disconnected(): @@ -517,10 +511,7 @@ async def create_chat_completion( ) # handle regular request - async with contextlib.AsyncExitStack() as exit_stack: - llama_proxy: LlamaProxy = await exit_stack.enter_async_context( - contextlib.asynccontextmanager(get_llama_proxy)() - ) + async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy: llama = prepare_request_resources(body, llama_proxy, body_model, kwargs) if await 
request.is_disconnected():