|
6 | 6 | import tempfile
|
7 | 7 | from enum import Enum, auto
|
8 | 8 | from io import BufferedWriter
|
9 |
| -from typing import IO, Any, Sequence |
| 9 | +from typing import IO, Any, Sequence, Mapping |
| 10 | +from string import ascii_letters, digits |
10 | 11 |
|
11 | 12 | import numpy as np
|
12 | 13 |
|
@@ -466,7 +467,33 @@ def add_add_eos_token(self, value: bool) -> None:
|
def add_add_space_prefix(self, value: bool) -> None:
    """Record the tokenizer's add-prefix flag as a boolean key/value entry.

    The flag is stored via ``add_bool`` under ``Keys.Tokenizer.ADD_PREFIX``.
    """
    prefix_key = Keys.Tokenizer.ADD_PREFIX
    self.add_bool(prefix_key, value)
|
468 | 469 |
|
469 |
| - def add_chat_template(self, value: str) -> None: |
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
    """Write a single chat template, or a collection of named chat templates.

    Args:
        value: Either the template text itself, or a sequence of mappings
            that each carry a ``'name'`` and a ``'template'`` entry.

    When a sequence is given:
      * every name is reduced to ASCII letters, digits and ``'_'``;
      * entries with an empty (post-filter) name or no template are skipped;
      * the entry named ``'default'`` is written under
        ``Keys.Tokenizer.CHAT_TEMPLATE``;
      * every other entry is written under
        ``Keys.Tokenizer.CHAT_TEMPLATE_N`` formatted with its name, and the
        names are listed (sorted) under ``Keys.Tokenizer.CHAT_TEMPLATES``;
      * if there is no ``'default'`` entry, no default template is written.
    """
    # Test for "not a plain string" rather than "is a list" so that any
    # sequence of mappings (tuple, list, ...) matches the annotation instead
    # of being handed to add_string as though it were template text.
    if not isinstance(value, str):
        # Built once: membership tests are O(1) and nothing is rebuilt per
        # character of every name.
        allowed_chars = frozenset(ascii_letters + digits + '_')
        template_default = None
        template_names = set()

        for choice in value:
            name = choice.get('name', '')
            template = choice.get('template')

            # Allowing non-alphanumerical characters in template name is
            # probably not a good idea, so filter it
            name = ''.join(c for c in name if c in allowed_chars)

            if not name or template is None:
                continue

            if name == 'default':
                template_default = template
            else:
                template_names.add(name)
                self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template)

        if template_names:
            # Sorted so the emitted array does not depend on set iteration
            # order, which varies between runs for strings.
            self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, sorted(template_names))

        if template_default is None:
            return

        value = template_default

    self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
|
471 | 498 |
|
472 | 499 | def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
|
|
0 commit comments