diff --git a/prompts/templates.py b/prompts/templates.py index 094c385..762b8e7 100644 --- a/prompts/templates.py +++ b/prompts/templates.py @@ -1,14 +1,11 @@ import inspect import re -import warnings from dataclasses import dataclass, field from functools import lru_cache from typing import Callable, Dict, Hashable, Optional from jinja2 import Environment, StrictUndefined -from prompts.tokens import SPECIAL_TOKENS, Special - @dataclass class Template: @@ -148,10 +145,6 @@ def render( allow users to enter prompts more naturally than if they used Python's constructs directly. See the examples for a detailed explanation. - We also define the `bos` and `eos` special variables which, when used, will - be replaced by the model's BOS and EOS tokens respectively. This allows you - to write prompts that are model-agnostic. - Examples -------- @@ -252,28 +245,12 @@ def render( # used to continue to the next line without linebreak. cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template) - # Warn the user when the model is not present in the special token registry - if model_name not in SPECIAL_TOKENS: - warnings.warn( - UserWarning( - f"The model {model_name} is not present in the special token registry." - "As a result, EOS and BOS tokens will be rendered as the empty string." - "Please open an issue: https://github.com/outlines-dev/prompts/issues" - "And ask for the model to be added to the registry." - ) - ) - env = Environment( trim_blocks=True, lstrip_blocks=True, keep_trailing_newline=True, undefined=StrictUndefined, ) - env.globals["bos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.begin - env.globals["eos"] = SPECIAL_TOKENS.get(model_name, Special()).sequence.end - env.globals["user"] = SPECIAL_TOKENS.get(model_name, Special()).user - env.globals["assistant"] = SPECIAL_TOKENS.get(model_name, Special()).assistant - env.globals["system"] = SPECIAL_TOKENS.get(model_name, Special()).system jinja_template = env.from_string(cleaned_template) return jinja_template.render(**values) diff --git a/prompts/tokens.py b/prompts/tokens.py deleted file mode 100644 index c9f3a98..0000000 --- a/prompts/tokens.py +++ /dev/null @@ -1,29 +0,0 @@ -from dataclasses import dataclass, field -from typing import Dict, Optional - - -@dataclass -class Limits: - begin: str = "" - end: str = "" - - -@dataclass -class Special: - sequence: Limits = field(default_factory=lambda: Limits()) - user: Limits = field(default_factory=lambda: Limits()) - assistant: Limits = field(default_factory=lambda: Limits()) - system: Limits = field(default_factory=lambda: Limits()) - - -SPECIAL_TOKENS: Dict[Optional[str], Special] = { - None: Special(), - "google/gemma-2-9b": Special(Limits("", "")), - "openai-community/gpt2": Special(Limits("", "<|endoftext|>")), - "mistralai/Mistral-7B-v0.1": Special(Limits("", "")), - "mistralai/Mistral-7B-Instruct-v0.1": Special( - Limits("", ""), - Limits("[INST]", "[/INST]"), - Limits("", ""), - ), -} diff --git a/tests/test_templates.py b/tests/test_templates.py index 4784134..3032b36 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -192,24 +192,3 @@ def simple_prompt_name(query: str): assert simple_prompt("test") == "test" assert simple_prompt["gpt2"]("test") == "test" assert simple_prompt["provider/name"]("test") == "name: test" - - -def test_special_tokens(): - - @prompts.template - def simple_prompt(query: str): - return """{{ bos + query + eos }}""" - - assert simple_prompt("test") == "test" - assert simple_prompt["openai-community/gpt2"]("test") == "test<|endoftext|>" - assert simple_prompt["mistralai/Mistral-7B-v0.1"]("test") == "test" - - -def test_warn(): - - @prompts.template - def simple_prompt(): - return """test""" - - with pytest.warns(UserWarning, match="not present in the special token"): - simple_prompt["non-existent-model"]() diff --git a/tests/test_tokens.py b/tests/test_tokens.py deleted file mode 100644 index c9391e9..0000000 --- a/tests/test_tokens.py +++ /dev/null @@ -1,6 +0,0 @@ -from prompts.tokens import Special - - -def test_simple(): - special = Special() - assert special.assistant.begin == ""