@@ -154,19 +154,56 @@ def create_app(
     return app


+def prepare_request_resources(
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    llama_proxy: LlamaProxy,
+    body_model: str,
+    kwargs,
+) -> llama_cpp.Llama:
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
+    llama = llama_proxy(body_model)
+    if body.logit_bias is not None:
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
+        )
+
+    if body.grammar is not None:
+        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+    return llama
+
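For context (not part of the patch): the helper above is meant to be driven through an AsyncExitStack, so the proxy obtained from get_llama_proxy is released only when the caller decides the request is finished. A minimal, self-contained sketch of that pattern, with fake_llama_proxy as a hypothetical stand-in for the real dependency:

import asyncio
import contextlib


async def fake_llama_proxy():
    # Hypothetical stand-in for get_llama_proxy: acquire on entry, release on exit.
    print("acquire model lock")
    try:
        yield "llama-proxy"
    finally:
        print("release model lock")


async def handle_request():
    exit_stack = contextlib.AsyncExitStack()
    # Wrap the generator-style dependency as an async context manager and enter it,
    # keeping the exit stack open so cleanup can be deferred until aclose().
    proxy = await exit_stack.enter_async_context(
        contextlib.asynccontextmanager(fake_llama_proxy)()
    )
    try:
        print(f"serving request with {proxy!r}")
    finally:
        await exit_stack.aclose()  # runs fake_llama_proxy's finally block


asyncio.run(handle_request())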


 async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
-    iterator: Iterator[typing.Any],
-    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    body_model: str,
+    llama_call,
+    kwargs,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
         server_settings.interrupt_requests if server_settings else False
     )
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
     async with inner_send_chan:
         try:
+            iterator = await run_in_threadpool(llama_call, llama, **kwargs)
             async for chunk in iterate_in_threadpool(iterator):
                 await inner_send_chan.send(dict(data=json.dumps(chunk)))
                 if await request.is_disconnected():
@@ -181,8 +218,7 @@ async def get_event_publisher(
                 print(f"Disconnected from client (via refresh/close) {request.client}")
                 raise e
         finally:
-            if on_complete:
-                await on_complete()
+            await exit_stack.aclose()


 def _logit_bias_tokens_to_input_ids(
@@ -267,18 +303,11 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

-    llama = llama_proxy(
+    body_model = (
         body.model
         if request.url.path != "/v1/engines/copilot-codex/completions"
         else "copilot-codex"
@@ -293,60 +322,42 @@ async def create_completion(
     }
     kwargs = body.model_dump(exclude=exclude)

-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.CreateCompletionResponse,
-            Iterator[llama_cpp.CreateCompletionStreamResponse],
-        ] = await run_in_threadpool(llama, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
-            yield first_response
-            yield from iterator_or_completion
-
+    # handle streaming request
+    if kwargs.get("stream", False):
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.__call__,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
+
+    # handle regular request
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+    if await request.is_disconnected():
+        print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}")
         await exit_stack.aclose()
-        return iterator_or_completion
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Client closed request",
+        )
+
+    try:
+        completion: llama_cpp.CreateCompletionResponse = await run_in_threadpool(llama, **kwargs)
+    finally:
+        await exit_stack.aclose()
+    return completion
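A hedged usage sketch (not from this patch) of the streaming branch above, as seen from a client: with stream set, the endpoint returns server-sent events that can be read line by line. The host, port, prompt, and sampling parameters are placeholders for whatever the server was started with.

import json
import requests  # assumes the server runs at the default localhost:8000

with requests.post(
    "http://localhost:8000/v1/completions",
    json={"prompt": "Hello", "max_tokens": 32, "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue  # skip blank separator lines between SSE events
        payload = line.decode("utf-8")
        if not payload.startswith("data: "):
            continue  # ignore SSE comment/ping lines
        data = payload[len("data: "):]
        if data == "[DONE]":
            break  # sentinel the server sends when the stream ends
        chunk = json.loads(data)
        print(chunk["choices"][0]["text"], end="", flush=True)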


 @router.post(
@@ -474,74 +485,52 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
+
+    body_model = body.model
     exclude = {
         "n",
         "logit_bias_type",
         "user",
         "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)
-    llama = llama_proxy(body.model)
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
-        ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
-            yield first_response
-            yield from iterator_or_completion

+    # handle streaming request
+    if kwargs.get("stream", False):
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.create_chat_completion,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
+
+    # handle regular request
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+    if await request.is_disconnected():
+        print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}")
+        await exit_stack.aclose()
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Client closed request",
+        )
+
+    try:
+        completion: llama_cpp.ChatCompletion = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+    finally:
         await exit_stack.aclose()
-        return iterator_or_completion
+    return completion
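And a hedged sketch (not from this patch) of the non-streaming chat branch from the client side, here using the official openai 1.x package pointed at a local server; the base URL, API key, and model alias are placeholders.

from openai import OpenAI  # openai>=1.0 client, pointed at the local server

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed-locally")
chat = client.chat.completions.create(
    model="local-model",  # whichever alias the server was configured to serve
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    stream=False,  # exercises the exit_stack / try-finally path above
)
print(chat.choices[0].message.content)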


 @router.get(