@@ -1,4 +1,5 @@
 import asyncio
+import collections
 import json
 import os
 import re
@@ -261,35 +262,35 @@ def step_a_prompt_prompt(context, prompt):
 @step(u'concurrent completion requests')
 @async_run_until_complete()
 async def step_concurrent_completion_requests(context):
-    await concurrent_completion_requests(context,
-                                         request_completion,
-                                         # prompt is inserted automatically
-                                         context.base_url,
-                                         debug=context.debug,
-                                         n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
-                                         server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
-                                         user_api_key=context.user_api_key if hasattr(context,
-                                                                                      'user_api_key') else None)
+    await concurrent_requests(context,
+                              request_completion,
+                              # prompt is inserted automatically
+                              context.base_url,
+                              debug=context.debug,
+                              n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
+                              server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key if hasattr(context,
+                                                                           'user_api_key') else None)


 @step(u'concurrent OAI completions requests')
 @async_run_until_complete
 async def step_oai_chat_completions(context):
-    await concurrent_completion_requests(context, oai_chat_completions,
-                                         # user_prompt is inserted automatically
-                                         context.system_prompt,
-                                         context.base_url,
-                                         True,  # async_client
-                                         model=context.model
-                                         if hasattr(context, 'model') else None,
-                                         n_predict=context.n_predict
-                                         if hasattr(context, 'n_predict') else None,
-                                         enable_streaming=context.enable_streaming
-                                         if hasattr(context, 'enable_streaming') else None,
-                                         server_seed=context.server_seed
-                                         if hasattr(context, 'server_seed') else None,
-                                         user_api_key=context.user_api_key
-                                         if hasattr(context, 'user_api_key') else None)
+    await concurrent_requests(context, oai_chat_completions,
+                              # user_prompt is inserted automatically
+                              context.system_prompt,
+                              context.base_url,
+                              True,  # async_client
+                              model=context.model
+                              if hasattr(context, 'model') else None,
+                              n_predict=context.n_predict
+                              if hasattr(context, 'n_predict') else None,
+                              enable_streaming=context.enable_streaming
+                              if hasattr(context, 'enable_streaming') else None,
+                              server_seed=context.server_seed
+                              if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key
+                              if hasattr(context, 'user_api_key') else None)


 @step(u'all prompts are predicted')
@@ -316,36 +317,58 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
-    content = context.text
-    base_url = context.base_url
-    context.embeddings = await request_embedding(content, base_url)
+    context.embeddings = await request_embedding(context.text, base_url=context.base_url)


 @step(u'embeddings are generated')
 def step_assert_embeddings(context):
-    assert_embeddings(context.embeddings)
+    if len(context.prompts) == 0:
+        assert_embeddings(context.embeddings)
+    else:
+        assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
+                                                                 f"context.prompts={context.prompts}\n"
+                                                                 f"context.embeddings={context.embeddings}")
+        for embedding in context.embeddings:
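+            # each generated embedding consumes one prompt from the list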
+            context.prompts.pop()
+            assert_embeddings(embedding)


 @step(u'an OAI compatible embeddings computation request for')
-def step_oai_compute_embedding(context):
-    openai.api_key = 'nope'  # openai client always expects an api_keu
-    if context.user_api_key is not None:
-        openai.api_key = context.user_api_key
-    openai.api_base = f'{context.base_url}/v1'
-    embeddings = openai.Embedding.create(
-        model=context.model,
-        input=context.text,
-    )
-    context.embeddings = embeddings
+@async_run_until_complete
+async def step_oai_compute_embeddings(context):
+    context.embeddings = await request_oai_embeddings(context.text,
+                                                      base_url=context.base_url,
+                                                      user_api_key=context.user_api_key,
+                                                      model=context.model)
+
+
+@step(u'an OAI compatible embeddings computation request for multiple inputs')
+@async_run_until_complete
+async def step_oai_compute_embeddings_multiple_inputs(context):
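+    # the whole prompts list goes out as one request; the endpoint returns one embedding per entry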
+    context.embeddings = await request_oai_embeddings(context.prompts,
+                                                      base_url=context.base_url,
+                                                      user_api_key=context.user_api_key,
+                                                      model=context.model)


 @step(u'concurrent embedding requests')
 @async_run_until_complete()
 async def step_concurrent_embedding_requests(context):
-    await concurrent_completion_requests(context,
-                                         request_embedding,
-                                         # prompt is inserted automatically
-                                         context.base_url)
+    await concurrent_requests(context,
+                              request_embedding,
+                              # prompt is inserted automatically
+                              base_url=context.base_url)
+
+
+@step(u'concurrent OAI embedding requests')
+@async_run_until_complete()
+async def step_concurrent_oai_embedding_requests(context):
+    await concurrent_requests(context,
+                              request_oai_embeddings,
+                              # prompt is inserted automatically
+                              base_url=context.base_url,
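+                              # async_client=True selects the aiohttp path in request_oai_embeddings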
+                              async_client=True,
+                              model=context.model)


 @step(u'all embeddings are generated')
@@ -401,7 +424,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
     assert context.options_response.headers[cors_header] == cors_header_value


-async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
+async def concurrent_requests(context, f_completion, *args, **kwargs):
     n_prompts = len(context.prompts)
     if context.debug:
         print(f"starting {n_prompts} concurrent completion requests...")
@@ -565,7 +588,7 @@ async def oai_chat_completions(user_prompt,
     return completion_response


-async def request_embedding(content, base_url):
+async def request_embedding(content, base_url=None):
     async with aiohttp.ClientSession() as session:
         async with session.post(f'{base_url}/embedding',
                                 json={
@@ -576,6 +599,46 @@ async def request_embedding(content, base_url):
             return response_json['embedding']


+async def request_oai_embeddings(input,
+                                 base_url=None, user_api_key=None,
+                                 model=None, async_client=False):
+    # openai client always expects an api_key
+    user_api_key = user_api_key if user_api_key is not None else 'nope'
+    if async_client:
+        origin = 'llama.cpp'
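+        # the Origin header lets the assertions below verify the CORS response headers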
+        if user_api_key is not None:
+            headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f'{base_url}/v1/embeddings',
+                                    json={
+                                        "input": input,
+                                        "model": model,
+                                    },
+                                    headers=headers) as response:
+                assert response.status == 200, f"received status code not expected: {response.status}"
+                assert response.headers['Access-Control-Allow-Origin'] == origin
+                assert response.headers['Content-Type'] == "application/json; charset=utf-8"
+                response_json = await response.json()
+                assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
+                assert response_json['object'] == 'list'
+                return response_json['data']
+    else:
+        openai.api_key = user_api_key
+        openai.api_base = f'{base_url}/v1'
+        oai_embeddings = openai.Embedding.create(
+            model=model,
+            input=input,
+        )
+
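+        # note: a plain str is also a collections.abc.Sequence, so single-string
+        # input takes the list branch and comes back as a one-element list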
+        if isinstance(input, collections.abc.Sequence):
+            embeddings = []
+            for an_oai_embeddings in oai_embeddings.data:
+                embeddings.append(an_oai_embeddings.embedding)
+        else:
+            embeddings = oai_embeddings.data.embedding
+        return embeddings
+
+
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']