@@ -1,6 +1,8 @@
+import itertools
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
-                    overload)
+from dataclasses import dataclass
+from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
+                    Union, cast, overload)

 from tqdm import tqdm

@@ -30,6 +32,37 @@
 logger = init_logger(__name__)


+@dataclass
+class BeamSearchSequence:
+    """A sequence for beam search.
+    It keeps track of the tokens and the cumulative log probability of the
+    sequence. The text field is optional and is only filled in when the
+    sequence is about to be returned to the user.
+    """
+    # The tokens include the prompt.
+    tokens: List[int]
+    cum_logprob: float = 0.0
+    text: Optional[str] = None
+
+
+@dataclass
+class BeamSearchOutput:
+    """The output of beam search.
+    It contains the list of the best beam search sequences.
+    The length of the list is equal to the beam width.
+    """
+    sequences: List[BeamSearchSequence]
+
+
+class BeamSearchInstance:
+
+    def __init__(self, prompt_tokens: List[int]):
+        self.beams: List[BeamSearchSequence] = [
+            BeamSearchSequence(tokens=prompt_tokens)
+        ]
+        self.completed: List[BeamSearchSequence] = []
+
+
 class LLM:
     """An LLM for generating texts from given prompts and sampling parameters.

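As a quick illustration of how the new dataclasses fit together, here is a minimal sketch (not part of the diff; the token IDs and log probability are made up, and it assumes the classes above are in scope) showing a beam being extended and scored by its cumulative log probability:

```python
import math

# Hypothetical values for illustration only.
prompt_tokens = [101, 2023, 2003]             # an already-encoded prompt
instance = BeamSearchInstance(prompt_tokens)  # starts with a single beam

# Extending a beam appends the candidate token and accumulates its logprob.
beam = instance.beams[0]
candidate_token, candidate_logprob = 2307, math.log(0.6)
new_beam = BeamSearchSequence(
    tokens=beam.tokens + [candidate_token],
    cum_logprob=beam.cum_logprob + candidate_logprob,
)
instance.beams = [new_beam]  # keep only the top beam_width beams
```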
@@ -354,6 +387,105 @@ def generate(
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)

+    def beam_search(
+        self,
+        prompts: List[Union[str, List[int]]],
+        beam_width: int,
+        max_tokens: int,
+        ignore_eos: bool = False,
+    ) -> List[BeamSearchOutput]:
+        """
+        Generate sequences using beam search.
+
+        Args:
+            prompts: A list of prompts. Each prompt can be a string or a list
+                of token IDs.
+            beam_width: The number of beams to keep at each step.
+            max_tokens: The max number of tokens to generate for each prompt.
+
+        TODO: how does beam search work together with length penalty,
+        frequency penalty, and stopping criteria, etc.?
+        """
+
+        tokenizer = self.get_tokenizer()
+        # generate 2 * beam_width candidates at each step
+        # following the huggingface transformers implementation
+        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=0.0)
+        instances: List[BeamSearchInstance] = []
+
+        for prompt in prompts:
+            prompt_tokens = prompt if isinstance(
+                prompt, list) else tokenizer.encode(prompt)
+            instances.append(BeamSearchInstance(prompt_tokens))
+
+        for _ in range(max_tokens):
+            all_beams: List[BeamSearchSequence] = list(
+                sum((instance.beams for instance in instances), []))
+            pos = [0] + list(
+                itertools.accumulate(
+                    len(instance.beams) for instance in instances))
+            instance_start_and_end: List[Tuple[int, int]] = list(
+                zip(pos[:-1], pos[1:]))
+
+            if len(all_beams) == 0:
+                break
+
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            # only runs for one step
+            # we don't need to use tqdm here
+            output = self.generate(prompts_batch,
+                                   sampling_params=beam_search_params,
+                                   use_tqdm=False)
+
+            for (start, end), instance in zip(instance_start_and_end,
+                                              instances):
+                instance_new_beams = []
+                for i in range(start, end):
+                    current_beam = all_beams[i]
+                    result = output[i]
+
+                    if result.outputs[0].logprobs is not None:
+                        # if `result.outputs[0].logprobs` is None, it means
+                        # the sequence was completed because of max-model-len
+                        # or was aborted; we don't add it to the new beams.
+                        logprobs = result.outputs[0].logprobs[0]
+                        for token_id, logprob_obj in logprobs.items():
+                            new_beam = BeamSearchSequence(
+                                tokens=current_beam.tokens + [token_id],
+                                cum_logprob=current_beam.cum_logprob +
+                                logprob_obj.logprob)
+
+                            if token_id == tokenizer.eos_token_id and \
+                                not ignore_eos:
+                                instance.completed.append(new_beam)
+                            else:
+                                instance_new_beams.append(new_beam)
+                sorted_beams = sorted(instance_new_beams,
+                                      key=lambda x: x.cum_logprob,
+                                      reverse=True)
+                instance.beams = sorted_beams[:beam_width]
+
+        outputs = []
+        for instance in instances:
+            instance.completed.extend(instance.beams)
+            sorted_completed = sorted(instance.completed,
+                                      key=lambda x: x.cum_logprob,
+                                      reverse=True)
+            best_beams = sorted_completed[:beam_width]
+
+            for beam in best_beams:
+                beam.text = tokenizer.decode(beam.tokens)
+            outputs.append(BeamSearchOutput(sequences=best_beams))
+
+        return outputs
+
     def chat(
         self,
         messages: List[ChatCompletionMessageParam],
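For context, here is a hedged usage sketch of the new `LLM.beam_search` API; the model name and prompt below are placeholders, not taken from this diff:

```python
from vllm import LLM

# Hypothetical usage sketch; model and prompt are placeholders.
llm = LLM(model="facebook/opt-125m")
outputs = llm.beam_search(
    prompts=["The capital of France is"],
    beam_width=4,
    max_tokens=32,
)
for output in outputs:
    # `sequences` holds up to `beam_width` beams, sorted by cum_logprob
    # (best first); `text` is decoded just before being returned.
    for seq in output.sequences:
        print(seq.cum_logprob, seq.text)
```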