|
2 | 2 | from dataclasses import dataclass
|
3 | 3 | from typing import Dict, List, Optional, Tuple
|
4 | 4 |
|
| 5 | +import numpy as np |
5 | 6 | import torch
|
6 | 7 |
|
7 | 8 | from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
|
@@ -457,16 +458,20 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
|
457 | 458 | if do_penalties:
|
458 | 459 | prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
|
459 | 460 | default=0)
|
460 |
| - prompt_padded_tokens = [ |
461 |
| - tokens + [vocab_size] * (prompt_max_len - len(tokens)) |
462 |
| - for tokens in prompt_tokens |
463 |
| - ] |
| 461 | + prompt_padded_tokens = np.full( |
| 462 | + (len(prompt_tokens), prompt_max_len), |
| 463 | + vocab_size, |
| 464 | + dtype=np.int64) |
| 465 | + for i, tokens in enumerate(prompt_tokens): |
| 466 | + prompt_padded_tokens[i, :len(tokens)] = tokens |
464 | 467 | output_max_len = max([len(tokens) for tokens in output_tokens],
|
465 | 468 | default=0)
|
466 |
| - output_padded_tokens = [ |
467 |
| - tokens + [vocab_size] * (output_max_len - len(tokens)) |
468 |
| - for tokens in output_tokens |
469 |
| - ] |
| 469 | + output_padded_tokens = np.full( |
| 470 | + (len(output_tokens), output_max_len), |
| 471 | + vocab_size, |
| 472 | + dtype=np.int64) |
| 473 | + for i, tokens in enumerate(output_tokens): |
| 474 | + output_padded_tokens[i, :len(tokens)] = tokens |
470 | 475 |
|
471 | 476 | temperatures_t = torch.tensor(
|
472 | 477 | temperatures,
|
@@ -517,18 +522,8 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
|
517 | 522 | pin_memory=pin_memory,
|
518 | 523 | )
|
519 | 524 | if do_penalties:
|
520 |
| - prompt_tensor = torch.tensor( |
521 |
| - prompt_padded_tokens, |
522 |
| - device="cpu", |
523 |
| - dtype=torch.long, |
524 |
| - pin_memory=pin_memory, |
525 |
| - ) |
526 |
| - output_tensor = torch.tensor( |
527 |
| - output_padded_tokens, |
528 |
| - device="cpu", |
529 |
| - dtype=torch.long, |
530 |
| - pin_memory=pin_memory, |
531 |
| - ) |
| 525 | + prompt_tensor = torch.from_numpy(prompt_padded_tokens) |
| 526 | + output_tensor = torch.from_numpy(output_padded_tokens) |
532 | 527 | else:
|
533 | 528 | prompt_tensor = None
|
534 | 529 | output_tensor = None
|
|
0 commit comments