Skip to content

Commit 22ef180

Browse files
youkaichao authored and garg-amit committed
[misc] update utils to support comparing multiple settings (vllm-project#9140)
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent cbb5af3 commit 22ef180

File tree

1 file changed

+44
-11
lines changed

1 file changed

+44
-11
lines changed

tests/utils.py

Lines changed: 44 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -310,14 +310,38 @@ def compare_two_settings(model: str,
310310
env2: The second set of environment variables to pass to the API server.
311311
"""
312312

313+
compare_all_settings(
314+
model,
315+
[arg1, arg2],
316+
[env1, env2],
317+
method=method,
318+
max_wait_seconds=max_wait_seconds,
319+
)
320+
321+
322+
def compare_all_settings(model: str,
323+
all_args: List[List[str]],
324+
all_envs: List[Optional[Dict[str, str]]],
325+
*,
326+
method: Literal["generate", "encode"] = "generate",
327+
max_wait_seconds: Optional[float] = None) -> None:
328+
"""
329+
Launch API server with several different sets of arguments/environments
330+
and compare the results of the API calls with the first set of arguments.
331+
Args:
332+
model: The model to test.
333+
all_args: A list of argument lists to pass to the API server.
334+
all_envs: A list of environment dictionaries to pass to the API server.
335+
"""
336+
313337
trust_remote_code = False
314-
for args in (arg1, arg2):
338+
for args in all_args:
315339
if "--trust-remote-code" in args:
316340
trust_remote_code = True
317341
break
318342

319343
tokenizer_mode = "auto"
320-
for args in (arg1, arg2):
344+
for args in all_args:
321345
if "--tokenizer-mode" in args:
322346
tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
323347
break
@@ -330,8 +354,10 @@ def compare_two_settings(model: str,
330354

331355
prompt = "Hello, my name is"
332356
token_ids = tokenizer(prompt).input_ids
333-
results = []
334-
for args, env in ((arg1, env1), (arg2, env2)):
357+
ref_results: List = []
358+
for i, (args, env) in enumerate(zip(all_args, all_envs)):
359+
compare_results: List = []
360+
results = ref_results if i == 0 else compare_results
335361
with RemoteOpenAIServer(model,
336362
args,
337363
env_dict=env,
@@ -355,13 +381,20 @@ def compare_two_settings(model: str,
355381
else:
356382
assert_never(method)
357383

358-
n = len(results) // 2
359-
arg1_results = results[:n]
360-
arg2_results = results[n:]
361-
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
362-
assert arg1_result == arg2_result, (
363-
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
364-
f"{arg1_result=} != {arg2_result=}")
384+
if i > 0:
385+
# if any setting fails, raise an error early
386+
ref_args = all_args[0]
387+
ref_envs = all_envs[0]
388+
compare_args = all_args[i]
389+
compare_envs = all_envs[i]
390+
for ref_result, compare_result in zip(ref_results,
391+
compare_results):
392+
assert ref_result == compare_result, (
393+
f"Results for {model=} are not the same.\n"
394+
f"{ref_args=} {ref_envs=}\n"
395+
f"{compare_args=} {compare_envs=}\n"
396+
f"{ref_result=}\n"
397+
f"{compare_result=}\n")
365398

366399

367400
def init_test_distributed_environment(

0 commit comments

Comments
 (0)