@@ -180,18 +180,34 @@ def compare_two_settings(model: str,
         env1: The first set of environment variables to pass to the API server.
         env2: The second set of environment variables to pass to the API server.
     """
+    compare_all_settings(model, [arg1, arg2], [env1, env2], max_wait_seconds)
+
+
+def compare_all_settings(model: str,
+                         all_args: List[List[str]],
+                         all_envs: List[Optional[Dict[str, str]]],
+                         max_wait_seconds: Optional[float] = None) -> None:
+    """
+    Launch API server with several different sets of arguments/environments
+    and compare the results of the API calls with the first set of arguments.
+    Args:
+        model: The model to test.
+        all_args: A list of argument lists to pass to the API server.
+        all_envs: A list of environment dictionaries to pass to the API server.
+    """
 
     trust_remote_code = "--trust-remote-code"
-    if trust_remote_code in arg1 or trust_remote_code in arg2:
+    if any(trust_remote_code in args for args in all_args):
         tokenizer = AutoTokenizer.from_pretrained(model,
                                                   trust_remote_code=True)
     else:
         tokenizer = AutoTokenizer.from_pretrained(model)
 
     prompt = "Hello, my name is"
     token_ids = tokenizer(prompt)["input_ids"]
-    results = []
-    for args, env in ((arg1, env1), (arg2, env2)):
+    ref_results: List = []
+    for i, (args, env) in enumerate(zip(all_args, all_envs)):
+        compare_results: List = []
         with RemoteOpenAIServer(model,
                                 args,
                                 env_dict=env,
@@ -202,10 +218,13 @@ def compare_two_settings(model: str,
             models = client.models.list()
             models = models.data
             served_model = models[0]
-            results.append({
-                "test": "models_list",
-                "id": served_model.id,
-                "root": served_model.root,
+            (ref_results if i == 0 else compare_results).append({
+                "test":
+                "models_list",
+                "id":
+                served_model.id,
+                "root":
+                served_model.root,
             })
 
             # test with text prompt
@@ -214,11 +233,15 @@ def compare_two_settings(model: str,
                                                    max_tokens=5,
                                                    temperature=0.0)
 
-            results.append({
-                "test": "single_completion",
-                "text": completion.choices[0].text,
-                "finish_reason": completion.choices[0].finish_reason,
-                "usage": completion.usage,
+            (ref_results if i == 0 else compare_results).append({
+                "test":
+                "single_completion",
+                "text":
+                completion.choices[0].text,
+                "finish_reason":
+                completion.choices[0].finish_reason,
+                "usage":
+                completion.usage,
             })
 
             # test using token IDs
@@ -229,11 +252,15 @@ def compare_two_settings(model: str,
                 temperature=0.0,
             )
 
-            results.append({
-                "test": "token_ids",
-                "text": completion.choices[0].text,
-                "finish_reason": completion.choices[0].finish_reason,
-                "usage": completion.usage,
+            (ref_results if i == 0 else compare_results).append({
+                "test":
+                "token_ids",
+                "text":
+                completion.choices[0].text,
+                "finish_reason":
+                completion.choices[0].finish_reason,
+                "usage":
+                completion.usage,
             })
 
             # test seeded random sampling
@@ -243,11 +270,15 @@ def compare_two_settings(model: str,
                                                    seed=33,
                                                    temperature=1.0)
 
-            results.append({
-                "test": "seeded_sampling",
-                "text": completion.choices[0].text,
-                "finish_reason": completion.choices[0].finish_reason,
-                "usage": completion.usage,
+            (ref_results if i == 0 else compare_results).append({
+                "test":
+                "seeded_sampling",
+                "text":
+                completion.choices[0].text,
+                "finish_reason":
+                completion.choices[0].finish_reason,
+                "usage":
+                completion.usage,
             })
 
             # test seeded random sampling with multiple prompts
@@ -257,7 +288,7 @@ def compare_two_settings(model: str,
                                                    seed=33,
                                                    temperature=1.0)
 
-            results.append({
+            (ref_results if i == 0 else compare_results).append({
                 "test":
                 "seeded_sampling",
                 "text": [choice.text for choice in completion.choices],
@@ -275,10 +306,13 @@ def compare_two_settings(model: str,
                 temperature=0.0,
             )
 
-            results.append({
-                "test": "simple_list",
-                "text0": batch.choices[0].text,
-                "text1": batch.choices[1].text,
+            (ref_results if i == 0 else compare_results).append({
+                "test":
+                "simple_list",
+                "text0":
+                batch.choices[0].text,
+                "text1":
+                batch.choices[1].text,
             })
 
             # test streaming
@@ -294,18 +328,25 @@ def compare_two_settings(model: str,
                 assert len(chunk.choices) == 1
                 choice = chunk.choices[0]
                 texts[choice.index] += choice.text
-            results.append({
+            (ref_results if i == 0 else compare_results).append({
                 "test": "streaming",
                 "texts": texts,
             })
 
-    n = len(results) // 2
-    arg1_results = results[:n]
-    arg2_results = results[n:]
-    for arg1_result, arg2_result in zip(arg1_results, arg2_results):
-        assert arg1_result == arg2_result, (
-            f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
-            f"{arg1_result=} != {arg2_result=}")
+        if i > 0:
+            # if any setting fails, raise an error early
+            ref_args = all_args[0]
+            ref_envs = all_envs[0]
+            compare_args = all_args[i]
+            compare_envs = all_envs[i]
+            for ref_result, compare_result in zip(ref_results,
                                                  compare_results):
+                assert ref_result == compare_result, (
+                    f"Results for {model=} are not the same.\n"
+                    f"{ref_args=} {ref_envs=}\n"
+                    f"{compare_args=} {compare_envs=}\n"
+                    f"{ref_result=}\n"
+                    f"{compare_result=}\n")
 
 
 def init_test_distributed_environment(
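For context, `compare_two_settings` is now a thin wrapper, so tests that need more than two configurations can call the new helper directly. A minimal usage sketch follows; the model name and server flags are hypothetical examples chosen for illustration, not values taken from this diff:

# Hypothetical usage: compare three launch configurations of one model.
# The first entry in all_args/all_envs is the reference setting; every
# later setting is asserted to produce identical API results.
compare_all_settings(
    "facebook/opt-125m",              # assumed example model
    all_args=[
        ["--enforce-eager"],          # reference setting
        [],                           # compare: default settings
        ["--max-num-seqs", "32"],     # compare: different batching limit
    ],
    all_envs=[None, None, None],      # no extra environment variables
)

Because results are checked as soon as each non-reference server finishes (the `if i > 0` block above), a mismatch in the second setting fails fast instead of waiting for the third server to launch.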