@@ -310,14 +310,38 @@ def compare_two_settings(model: str,
310
310
env2: The second set of environment variables to pass to the API server.
311
311
"""
312
312
313
+ compare_all_settings (
314
+ model ,
315
+ [arg1 , arg2 ],
316
+ [env1 , env2 ],
317
+ method = method ,
318
+ max_wait_seconds = max_wait_seconds ,
319
+ )
320
+
321
+
322
+ def compare_all_settings (model : str ,
323
+ all_args : List [List [str ]],
324
+ all_envs : List [Optional [Dict [str , str ]]],
325
+ * ,
326
+ method : Literal ["generate" , "encode" ] = "generate" ,
327
+ max_wait_seconds : Optional [float ] = None ) -> None :
328
+ """
329
+ Launch API server with several different sets of arguments/environments
330
+ and compare the results of the API calls with the first set of arguments.
331
+ Args:
332
+ model: The model to test.
333
+ all_args: A list of argument lists to pass to the API server.
334
+ all_envs: A list of environment dictionaries to pass to the API server.
335
+ """
336
+
313
337
trust_remote_code = False
314
- for args in ( arg1 , arg2 ) :
338
+ for args in all_args :
315
339
if "--trust-remote-code" in args :
316
340
trust_remote_code = True
317
341
break
318
342
319
343
tokenizer_mode = "auto"
320
- for args in ( arg1 , arg2 ) :
344
+ for args in all_args :
321
345
if "--tokenizer-mode" in args :
322
346
tokenizer_mode = args [args .index ("--tokenizer-mode" ) + 1 ]
323
347
break
@@ -330,8 +354,10 @@ def compare_two_settings(model: str,
330
354
331
355
prompt = "Hello, my name is"
332
356
token_ids = tokenizer (prompt ).input_ids
333
- results = []
334
- for args , env in ((arg1 , env1 ), (arg2 , env2 )):
357
+ ref_results : List = []
358
+ for i , (args , env ) in enumerate (zip (all_args , all_envs )):
359
+ compare_results : List = []
360
+ results = ref_results if i == 0 else compare_results
335
361
with RemoteOpenAIServer (model ,
336
362
args ,
337
363
env_dict = env ,
@@ -355,13 +381,20 @@ def compare_two_settings(model: str,
355
381
else :
356
382
assert_never (method )
357
383
358
- n = len (results ) // 2
359
- arg1_results = results [:n ]
360
- arg2_results = results [n :]
361
- for arg1_result , arg2_result in zip (arg1_results , arg2_results ):
362
- assert arg1_result == arg2_result , (
363
- f"Results for { model = } are not the same with { arg1 = } and { arg2 = } . "
364
- f"{ arg1_result = } != { arg2_result = } " )
384
+ if i > 0 :
385
+ # if any setting fails, raise an error early
386
+ ref_args = all_args [0 ]
387
+ ref_envs = all_envs [0 ]
388
+ compare_args = all_args [i ]
389
+ compare_envs = all_envs [i ]
390
+ for ref_result , compare_result in zip (ref_results ,
391
+ compare_results ):
392
+ assert ref_result == compare_result , (
393
+ f"Results for { model = } are not the same.\n "
394
+ f"{ ref_args = } { ref_envs = } \n "
395
+ f"{ compare_args = } { compare_envs = } \n "
396
+ f"{ ref_result = } \n "
397
+ f"{ compare_result = } \n " )
365
398
366
399
367
400
def init_test_distributed_environment (
0 commit comments