
Commit 22bab11

sumitd2 authored and youkaichao committed
[core] remove beam search from the core (vllm-project#9105)
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent 2ab4a63 commit 22bab11

25 files changed: +98 -596 lines
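
Note: after this change, SamplingParams no longer carries a use_beam_search flag, so the benchmarks and examples below simply sample with temperature=1.0. A minimal sketch of the post-change usage, assuming vLLM is installed; the model name and prompt are illustrative, not part of this commit:

# Sketch only: model name and prompt are placeholders, not from this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

# SamplingParams no longer accepts use_beam_search; plain sampling only.
sampling_params = SamplingParams(
    n=1,
    temperature=1.0,
    top_p=1.0,
    max_tokens=64,
)

outputs = llm.generate(["The capital of France is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)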

benchmarks/backend_request_func.py

Lines changed: 0 additions & 6 deletions
@@ -23,7 +23,6 @@ class RequestFuncInput:
     output_len: int
     model: str
     best_of: int = 1
-    use_beam_search: bool = False
     logprobs: Optional[int] = None
     multi_modal_content: Optional[dict] = None
     ignore_eos: bool = False
@@ -49,7 +48,6 @@ async def async_request_tgi(
     assert api_url.endswith("generate_stream")
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
-        assert not request_func_input.use_beam_search
         params = {
             "best_of": request_func_input.best_of,
             "max_new_tokens": request_func_input.output_len,
@@ -121,7 +119,6 @@ async def async_request_trt_llm(
     assert api_url.endswith("generate_stream")
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
-        assert not request_func_input.use_beam_search
         assert request_func_input.best_of == 1
         payload = {
             "accumulate_tokens": True,
@@ -187,7 +184,6 @@ async def async_request_deepspeed_mii(
 ) -> RequestFuncOutput:
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert request_func_input.best_of == 1
-        assert not request_func_input.use_beam_search
 
         payload = {
             "prompt": request_func_input.prompt,
@@ -235,7 +231,6 @@ async def async_request_openai_completions(
     ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
-        assert not request_func_input.use_beam_search
         payload = {
             "model": request_func_input.model,
             "prompt": request_func_input.prompt,
@@ -317,7 +312,6 @@ async def async_request_openai_chat_completions(
     ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
-        assert not request_func_input.use_beam_search
         content = [{"type": "text", "text": request_func_input.prompt}]
         if request_func_input.multi_modal_content:
             content.append(request_func_input.multi_modal_content)
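
For orientation, a sketch of the trimmed RequestFuncInput as the hunks above leave it; only the fields visible in the diff are listed, and the real dataclass has additional fields (e.g. prompt) that the hunks do not show:

# Sketch of the trimmed request-input container (fields not shown in the
# hunks above are omitted here).
from dataclasses import dataclass
from typing import Optional


@dataclass
class RequestFuncInput:
    output_len: int
    model: str
    best_of: int = 1  # use_beam_search: bool = False was removed
    logprobs: Optional[int] = None
    multi_modal_content: Optional[dict] = None
    ignore_eos: bool = False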

benchmarks/benchmark_latency.py

Lines changed: 1 addition & 2 deletions
@@ -51,9 +51,8 @@ def main(args: argparse.Namespace):
 
     sampling_params = SamplingParams(
         n=args.n,
-        temperature=0.0 if args.use_beam_search else 1.0,
+        temperature=1.0,
         top_p=1.0,
-        use_beam_search=args.use_beam_search,
         ignore_eos=True,
         max_tokens=args.output_len,
     )

benchmarks/benchmark_prioritization.py

Lines changed: 11 additions & 13 deletions
@@ -68,7 +68,6 @@ def run_vllm(
     tensor_parallel_size: int,
     seed: int,
     n: int,
-    use_beam_search: bool,
     trust_remote_code: bool,
     dtype: str,
     max_model_len: Optional[int],
@@ -114,9 +113,8 @@ def run_vllm(
         sampling_params.append(
             SamplingParams(
                 n=n,
-                temperature=0.0 if use_beam_search else 1.0,
+                temperature=1.0,
                 top_p=1.0,
-                use_beam_search=use_beam_search,
                 ignore_eos=True,
                 max_tokens=output_len,
             ))
@@ -144,15 +142,16 @@ def main(args: argparse.Namespace):
                                    args.output_len)
 
     if args.backend == "vllm":
-        elapsed_time = run_vllm(
-            requests, args.model, args.tokenizer, args.quantization,
-            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
-            args.trust_remote_code, args.dtype, args.max_model_len,
-            args.enforce_eager, args.kv_cache_dtype,
-            args.quantization_param_path, args.device,
-            args.enable_prefix_caching, args.enable_chunked_prefill,
-            args.max_num_batched_tokens, args.gpu_memory_utilization,
-            args.download_dir)
+        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
+                                args.quantization, args.tensor_parallel_size,
+                                args.seed, args.n, args.trust_remote_code,
+                                args.dtype, args.max_model_len,
+                                args.enforce_eager, args.kv_cache_dtype,
+                                args.quantization_param_path, args.device,
+                                args.enable_prefix_caching,
+                                args.enable_chunked_prefill,
+                                args.max_num_batched_tokens,
+                                args.gpu_memory_utilization, args.download_dir)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
     total_num_tokens = sum(prompt_len + output_len
@@ -203,7 +202,6 @@ def main(args: argparse.Namespace):
                         type=int,
                         default=1,
                         help="Number of generated sequences per prompt.")
-    parser.add_argument("--use-beam-search", action="store_true")
     parser.add_argument("--num-prompts",
                         type=int,
                         default=200,

benchmarks/benchmark_serving.py

Lines changed: 0 additions & 7 deletions
@@ -391,7 +391,6 @@ async def benchmark(
     input_requests: List[Tuple[str, int, int]],
     logprobs: Optional[int],
     best_of: int,
-    use_beam_search: bool,
     request_rate: float,
     disable_tqdm: bool,
     profile: bool,
@@ -419,7 +418,6 @@ async def benchmark(
         output_len=test_output_len,
         logprobs=logprobs,
         best_of=best_of,
-        use_beam_search=use_beam_search,
         multi_modal_content=test_mm_content,
         ignore_eos=ignore_eos,
     )
@@ -441,7 +439,6 @@ async def benchmark(
             output_len=test_output_len,
             logprobs=logprobs,
             best_of=best_of,
-            use_beam_search=use_beam_search,
             multi_modal_content=test_mm_content,
         )
         profile_output = await request_func(request_func_input=profile_input)
@@ -464,7 +461,6 @@ async def benchmark(
             output_len=output_len,
             logprobs=logprobs,
             best_of=best_of,
-            use_beam_search=use_beam_search,
             multi_modal_content=mm_content,
         )
         tasks.append(
@@ -483,7 +479,6 @@ async def benchmark(
             output_len=test_output_len,
             logprobs=logprobs,
             best_of=best_of,
-            use_beam_search=use_beam_search,
         )
         profile_output = await request_func(request_func_input=profile_input)
         if profile_output.success:
@@ -679,7 +674,6 @@ def main(args: argparse.Namespace):
             input_requests=input_requests,
             logprobs=args.logprobs,
             best_of=args.best_of,
-            use_beam_search=args.use_beam_search,
             request_rate=args.request_rate,
             disable_tqdm=args.disable_tqdm,
             profile=args.profile,
@@ -701,7 +695,6 @@ def main(args: argparse.Namespace):
         result_json["model_id"] = model_id
         result_json["tokenizer_id"] = tokenizer_id
         result_json["best_of"] = args.best_of
-        result_json["use_beam_search"] = args.use_beam_search
         result_json["num_prompts"] = args.num_prompts
 
         # Metadata

benchmarks/benchmark_throughput.py

Lines changed: 9 additions & 20 deletions
@@ -73,7 +73,6 @@ def run_vllm(
     tensor_parallel_size: int,
     seed: int,
     n: int,
-    use_beam_search: bool,
     trust_remote_code: bool,
     dtype: str,
     max_model_len: Optional[int],
@@ -91,7 +90,6 @@ def run_vllm(
     download_dir: Optional[str] = None,
     load_format: str = EngineArgs.load_format,
     disable_async_output_proc: bool = False,
-    use_new_beam_search_impl: bool = False,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -127,19 +125,19 @@ def run_vllm(
         sampling_params.append(
             SamplingParams(
                 n=n,
-                temperature=0.0 if use_beam_search else 1.0,
+                temperature=1.0,
                 top_p=1.0,
-                use_beam_search=use_beam_search,
                 ignore_eos=True,
                 max_tokens=output_len,
             ))
 
-    if not use_new_beam_search_impl:
+    use_beam_search = False
+
+    if not use_beam_search:
         start = time.perf_counter()
         llm.generate(prompts, sampling_params, use_tqdm=True)
         end = time.perf_counter()
     else:
-        assert use_beam_search
         prompts = [prompt for prompt, _, _ in requests]
         # output_len should be the same for all requests.
         output_len = requests[0][2]
@@ -165,7 +163,6 @@ async def run_vllm_async(
     tensor_parallel_size: int,
     seed: int,
     n: int,
-    use_beam_search: bool,
     trust_remote_code: bool,
     dtype: str,
     max_model_len: Optional[int],
@@ -224,9 +221,8 @@ async def run_vllm_async(
         sampling_params.append(
             SamplingParams(
                 n=n,
-                temperature=0.0 if use_beam_search else 1.0,
+                temperature=1.0,
                 top_p=1.0,
-                use_beam_search=use_beam_search,
                 ignore_eos=True,
                 max_tokens=output_len,
             ))
@@ -248,11 +244,9 @@ def run_hf(
     model: str,
     tokenizer: PreTrainedTokenizerBase,
     n: int,
-    use_beam_search: bool,
     max_batch_size: int,
     trust_remote_code: bool,
 ) -> float:
-    assert not use_beam_search
     llm = AutoModelForCausalLM.from_pretrained(
         model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
     if llm.config.model_type == "llama":
@@ -284,7 +278,7 @@ def run_hf(
                 padding=True).input_ids
             llm_outputs = llm.generate(
                 input_ids=input_ids.cuda(),
-                do_sample=not use_beam_search,
+                do_sample=True,
                 num_return_sequences=n,
                 temperature=1.0,
                 top_p=1.0,
@@ -340,7 +334,7 @@ def main(args: argparse.Namespace):
     if args.backend == "vllm":
         run_args = [
             requests, args.model, args.tokenizer, args.quantization,
-            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
+            args.tensor_parallel_size, args.seed, args.n,
             args.trust_remote_code, args.dtype, args.max_model_len,
             args.enforce_eager, args.kv_cache_dtype,
             args.quantization_param_path, args.device,
@@ -355,12 +349,11 @@ def main(args: argparse.Namespace):
             run_args.append(args.disable_frontend_multiprocessing)
             elapsed_time = uvloop.run(run_vllm_async(*run_args))
         else:
-            elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl)
+            elapsed_time = run_vllm(*run_args)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
-                              args.use_beam_search, args.hf_max_batch_size,
-                              args.trust_remote_code)
+                              args.hf_max_batch_size, args.trust_remote_code)
     elif args.backend == "mii":
         elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                                args.output_len)
@@ -414,8 +407,6 @@ def main(args: argparse.Namespace):
                         type=int,
                         default=1,
                         help="Number of generated sequences per prompt.")
-    parser.add_argument("--use-beam-search", action="store_true")
-    parser.add_argument("--use-new-beam-search-impl", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
@@ -570,8 +561,6 @@ def main(args: argparse.Namespace):
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
-        if args.use_beam_search:
-            raise ValueError("Beam search is not supported for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
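
With --use-beam-search and --use-new-beam-search-impl gone, run_vllm always takes the plain-sampling path: build one SamplingParams per request and time a single llm.generate call. A standalone sketch of that timing loop, assuming vLLM is installed; the model name and prompts are illustrative:

# Sketch of the timing path run_vllm takes after this change
# (model and prompts are placeholders).
import time

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
prompts = ["Hello, my name is", "The future of AI is"]
sampling_params = [
    SamplingParams(n=1, temperature=1.0, top_p=1.0,
                   ignore_eos=True, max_tokens=128)
    for _ in prompts
]

start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
print(f"Elapsed: {end - start:.2f} s")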

examples/llm_engine_example.py

Lines changed: 0 additions & 3 deletions
@@ -18,9 +18,6 @@ def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
                         temperature=0.8,
                         top_p=0.95,
                         frequency_penalty=0.1)),
-        ("It is only with the heart that one can see rightly",
-         SamplingParams(n=3, best_of=3, use_beam_search=True,
-                        temperature=0.0)),
     ]
 
 

examples/multilora_inference.py

Lines changed: 0 additions & 18 deletions
@@ -43,15 +43,6 @@ def create_test_prompts(
                         max_tokens=128,
                         stop_token_ids=[32003]),
          LoRARequest("sql-lora", 1, lora_path)),
-        (
-            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
-            SamplingParams(n=3,
-                           best_of=3,
-                           use_beam_search=True,
-                           temperature=0,
-                           max_tokens=128,
-                           stop_token_ids=[32003]),
-            LoRARequest("sql-lora", 1, lora_path)),
         (
             "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
             SamplingParams(temperature=0.0,
@@ -60,15 +51,6 @@ def create_test_prompts(
                            max_tokens=128,
                            stop_token_ids=[32003]),
             LoRARequest("sql-lora2", 2, lora_path)),
-        (
-            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
-            SamplingParams(n=3,
-                           best_of=3,
-                           use_beam_search=True,
-                           temperature=0,
-                           max_tokens=128,
-                           stop_token_ids=[32003]),
-            LoRARequest("sql-lora", 1, lora_path)),
     ]
 
 