From 48c3fbc745db6b86c9f7ac75bd334ec7a9d33ec0 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Fri, 21 Feb 2025 16:02:29 -0800 Subject: [PATCH] one model to rule them all --- guidance/experimental/test.ipynb | 139 ++++++++++++++++-- guidance/models/_base/__init__.py | 6 +- guidance/models/_base/_client.py | 13 ++ guidance/models/_base/_model.py | 34 ++--- guidance/models/_base/_state.py | 2 +- guidance/models/_engine/__init__.py | 9 +- .../models/_engine/{_model.py => _client.py} | 16 +- guidance/models/_engine/_engine.py | 2 +- guidance/models/_engine/_state.py | 23 ++- guidance/models/_mock.py | 38 +++-- guidance/models/_openai.py | 75 +++++----- guidance/models/llama_cpp/_llama_cpp.py | 10 +- guidance/models/transformers/_model.py | 23 ++- guidance/models/transformers/_state.py | 21 --- 14 files changed, 282 insertions(+), 129 deletions(-) create mode 100644 guidance/models/_base/_client.py rename guidance/models/_engine/{_model.py => _client.py} (91%) delete mode 100644 guidance/models/transformers/_state.py diff --git a/guidance/experimental/test.ipynb b/guidance/experimental/test.ipynb index e320175fb..fae88a82a 100644 --- a/guidance/experimental/test.ipynb +++ b/guidance/experimental/test.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -13,20 +13,88 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.\n", + "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3cf50b682b7547e69f33d0e751158fad", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00\\n\\n\\n …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to get per token stats: 0\n" + ] + } + ], "source": [ - "# model = models.Transformers(\"microsoft/Phi-3-mini-4k-instruct\", trust_remote_code=True, attn_implementation=\"eager\")\n", - "model = models.OpenAI(\"gpt-4o-mini\")\n", - "# model = models.LlamaCpp(model_path)" + "model = models.Transformers(\"microsoft/Phi-3-mini-4k-instruct\", trust_remote_code=True, attn_implementation=\"eager\")\n", + "# model = models.OpenAI(\"gpt-4o-mini\")\n", + "# model = models.LlamaCpp(\"hf_hub://lmstudio-community/Phi-3.1-mini-4k-instruct-GGUF/Phi-3.1-mini-4k-instruct-IQ3_M.gguf\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ea540eb74575423085c6cb531e93f13c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "StitchWidget(initial_height='auto', initial_width='100%', srcdoc='\\n\\n\\n …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You are not running the flash-attention implementation, expect numerical differences.\n", + "gpustat is not installed, run `pip install gpustat` to collect GPU stats.\n", + "Failed to get per token stats: 30\n" + ] + } + ], "source": [ "lm = model\n", "with 
system():\n", @@ -61,27 +129,70 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<|system|>\n", + "Talk like a pirate!<|end|>\n", + "<|user|>\n", + "Hello, model!<|end|>\n", + "<|assistant|>\n", + " Arrr matey!<|end|><|end|>\n", + "<|user|>\n", + "What is the capital of France?<|end|>\n", + "<|assistant|>\n", + " The capital of France be Paris, arrr! <|end|>\n" + ] + } + ], "source": [ "print(lm)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "({'role': 'system', 'data': {'content': 'Talk like a pirate!'}},\n", + " {'role': 'user', 'data': {'content': 'Hello, model!'}},\n", + " {'role': 'assistant', 'data': {'content': ' Arrr matey!<|end|>'}},\n", + " {'role': 'user', 'data': {'content': 'What is the capital of France?'}})" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "lm._state.messages" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'role': 'assistant',\n", + " 'data': {'content': ' The capital of France be Paris, arrr! <|end|>'}}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "lm._state.active_message" ] diff --git a/guidance/models/_base/__init__.py b/guidance/models/_base/__init__.py index 48169f4e1..df9b6a800 100644 --- a/guidance/models/_base/__init__.py +++ b/guidance/models/_base/__init__.py @@ -1,9 +1,11 @@ +from ._client import Client from ._model import Model -from ._state import BaseState, Message +from ._state import Message, State __all__ = [ "Model", "role", - "BaseState", + "State", "Message", + "Client", ] diff --git a/guidance/models/_base/_client.py b/guidance/models/_base/_client.py new file mode 100644 index 000000000..8bbc035ee --- /dev/null +++ b/guidance/models/_base/_client.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod +from typing import Generic, Iterator, TypeVar + +from ...experimental.ast import MessageChunk, Node +from ._state import State + +S = TypeVar("S", bound=State) + + +class Client(ABC, Generic[S]): + @abstractmethod + def run(self, state: S, node: Node) -> Iterator[MessageChunk]: + pass diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 7ad559ee7..806b43364 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -1,11 +1,10 @@ # TODO(nopdive): This module requires a memory review. 
import re -from abc import ABC, abstractmethod from base64 import b64encode from contextvars import ContextVar from copy import deepcopy -from typing import TYPE_CHECKING, Any, Generic, Iterator, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union from typing_extensions import Self, assert_never @@ -23,7 +22,8 @@ TraceNode, ) from ...visual import TraceMessage -from ._state import BaseState +from ._client import Client +from ._state import State _active_role: ContextVar[Optional["RoleStart"]] = ContextVar("active_role", default=None) _id_counter: int = 0 @@ -37,17 +37,20 @@ def _gen_id(): return _id -S = TypeVar("S", bound=BaseState) +S = TypeVar("S", bound=State) D = TypeVar("D", bound=Any) -class Model(ABC, Generic[S]): +class Model(Generic[S]): def __init__( self, + client: Client[S], + state: S, echo: bool = True, ) -> None: self.echo = echo - self._state = self.initial_state() + self._client = client + self._state = state self._active_role: Optional["RoleStart"] = None self._parent: Optional["Model"] = None @@ -56,14 +59,6 @@ def __init__( self._trace_nodes: set[TraceNode] = set() self._update_trace_node(self._id, self._parent_id, None) - @abstractmethod - def run(self, state: S, node: Node) -> Iterator[MessageChunk]: - pass - - @abstractmethod - def initial_state(self) -> S: - pass - def _update_trace_node( self, identifier: int, parent_id: Optional[int], node_attr: Optional[NodeAttr] = None ) -> None: @@ -91,7 +86,7 @@ def __add__(self, other: Node) -> Self: return self def _apply_node(self, node: Node) -> Self: - for chunk in self.run(self._state, node): + for chunk in self._client.run(self._state, node): self = self._apply_chunk(chunk) return self @@ -246,12 +241,3 @@ def extract_embedded_nodes(value: str) -> Node: grammar += part is_id = not is_id return grammar - - -def partial_decode(data: bytes) -> tuple[str, bytes]: - try: - return (data.decode("utf-8"), b"") - except UnicodeDecodeError as e: - valid_part = data[: e.start].decode("utf-8") - delayed_part = data[e.start :] - return (valid_part, delayed_part) diff --git a/guidance/models/_base/_state.py b/guidance/models/_base/_state.py index 1842e3a7d..51168c869 100644 --- a/guidance/models/_base/_state.py +++ b/guidance/models/_base/_state.py @@ -23,7 +23,7 @@ class CaptureVar(TypedDict): log_prob: Optional[float] -class BaseState(ABC): +class State(ABC): def __init__(self) -> None: self.chunks: list[MessageChunk] = [] self.captures: dict[str, Union[CaptureVar, list[CaptureVar]]] = {} diff --git a/guidance/models/_engine/__init__.py b/guidance/models/_engine/__init__.py index 68cab2779..e2a0454ec 100644 --- a/guidance/models/_engine/__init__.py +++ b/guidance/models/_engine/__init__.py @@ -1,12 +1,13 @@ from ._tokenizer import Tokenizer # isort:skip +from ._client import EngineClient from ._engine import Engine -from ._model import Model, ModelWithEngine -from ._state import EngineState +from ._state import EngineState, Llama3VisionState, Phi3VisionState __all__ = [ "Tokenizer", "Engine", - "Model", - "ModelWithEngine", + "EngineClient", "EngineState", + "Llama3VisionState", + "Phi3VisionState", ] diff --git a/guidance/models/_engine/_model.py b/guidance/models/_engine/_client.py similarity index 91% rename from guidance/models/_engine/_model.py rename to guidance/models/_engine/_client.py index d4b8259e1..aee9fbe75 100644 --- a/guidance/models/_engine/_model.py +++ b/guidance/models/_engine/_client.py @@ -9,15 +9,14 @@ RoleOpenerInput, TextOutput, ) -from .._base._model import 
Model, partial_decode +from .._base import Client from ._engine import Engine from ._state import EngineState -class ModelWithEngine(Model[EngineState]): - def __init__(self, engine: Engine, echo: bool = True): +class EngineClient(Client[EngineState]): + def __init__(self, engine: Engine): self.engine = engine - super().__init__(echo=echo) def run(self, state: EngineState, node: Node) -> Iterator[MessageChunk]: if isinstance(node, str): @@ -107,3 +106,12 @@ def initial_state(self) -> EngineState: # TODO: for llama_cpp and transformers, we need to provide an interface # for getting these from something like a model id..? return EngineState() + + +def partial_decode(data: bytes) -> tuple[str, bytes]: + try: + return (data.decode("utf-8"), b"") + except UnicodeDecodeError as e: + valid_part = data[: e.start].decode("utf-8") + delayed_part = data[e.start :] + return (valid_part, delayed_part) diff --git a/guidance/models/_engine/_engine.py b/guidance/models/_engine/_engine.py index a3166aa50..e1604779b 100644 --- a/guidance/models/_engine/_engine.py +++ b/guidance/models/_engine/_engine.py @@ -3,7 +3,7 @@ import logging import time import weakref -from abc import ABC, abstractmethod +from abc import ABC from asyncio import CancelledError from enum import Enum from multiprocessing import Manager, Process diff --git a/guidance/models/_engine/_state.py b/guidance/models/_engine/_state.py index e235a9de7..2b2340aba 100644 --- a/guidance/models/_engine/_state.py +++ b/guidance/models/_engine/_state.py @@ -1,9 +1,10 @@ from typing import Any -from .._base import BaseState +from ...experimental.ast import ImageBlob +from .._base import State -class EngineState(BaseState): +class EngineState(State): def __init__(self) -> None: super().__init__() self.images: list[Any] = [] @@ -16,3 +17,21 @@ def apply_text(self, text: str) -> None: else: self.active_message["data"]["content"] += text self.text += text + + +class Llama3VisionState(EngineState): + def apply_image(self, image: ImageBlob) -> None: + self.images.append(image.image) + text = "<|image|>" + EngineState.apply_text(self, text) + + +class Phi3VisionState(EngineState): + def apply_image(self, image: ImageBlob) -> None: + pil_image = image.image + if pil_image in self.images: + ix = self.images.index(pil_image) + 1 + else: + self.images.append(pil_image) + ix = len(self.images) + EngineState.apply_text(self, f"<|image_{ix}|>") diff --git a/guidance/models/_mock.py b/guidance/models/_mock.py index 099cdac2c..37980b967 100644 --- a/guidance/models/_mock.py +++ b/guidance/models/_mock.py @@ -1,17 +1,16 @@ -from typing import Sequence, Optional -import numpy as np import logging +from typing import Optional, Sequence +import numpy as np -from .._utils import softmax from .._schema import EngineOutput, GenToken, GenTokenExtra -from ..visual._renderer import DoNothingRenderer +from .._utils import softmax from ..trace import TraceHandler - -from ._engine import ModelWithEngine, Engine, Tokenizer +from ..visual._renderer import DoNothingRenderer +from ._base import Model +from ._engine import Engine, EngineClient, EngineState, Tokenizer from ._remote import RemoteEngine - logger = logging.getLogger(__name__) # TODO: this import pattern happens in a few places, should be cleaned up @@ -64,7 +63,12 @@ def recode(self, tokens: Sequence[int]) -> list[int]: class MockEngine(Engine): def __init__(self, tokenizer, byte_patterns, compute_log_probs, force): renderer = DoNothingRenderer(trace_handler=TraceHandler()) - super().__init__(tokenizer, 
compute_log_probs=compute_log_probs, enable_monitoring=False, renderer=renderer) + super().__init__( + tokenizer, + compute_log_probs=compute_log_probs, + enable_monitoring=False, + renderer=renderer, + ) self._valid_mask = np.zeros(len(tokenizer.tokens)) for i, t in enumerate(tokenizer.tokens): @@ -105,7 +109,9 @@ def get_next_token_with_top_k( force_return_unmasked_probs: bool = False, ) -> EngineOutput: self.called_temperatures.append(temperature) - return super().get_next_token_with_top_k(logits, logits_lat_ms, token_ids, mask, temperature, k, force_return_unmasked_probs) + return super().get_next_token_with_top_k( + logits, logits_lat_ms, token_ids, mask, temperature, k, force_return_unmasked_probs + ) def get_logits(self, token_ids: list[int]) -> np.ndarray: """Pretends to compute the logits for the given token state.""" @@ -134,11 +140,13 @@ def get_logits(self, token_ids: list[int]) -> np.ndarray: return logits - def get_per_token_topk_probs(self, token_ids: list[int], top_k: int = 5) -> list[GenTokenExtra]: + def get_per_token_topk_probs( + self, token_ids: list[int], top_k: int = 5 + ) -> list[GenTokenExtra]: result_list = [] if len(token_ids) == 0: return result_list - + added_bos = False if self.tokenizer.bos_token is not None and token_ids[0] != self.tokenizer.bos_token_id: token_ids = [self.tokenizer.bos_token_id] + token_ids @@ -209,7 +217,7 @@ def _get_next_tokens(self, byte_string): yield i -class Mock(ModelWithEngine): +class Mock(Model): def __init__( self, byte_patterns=[], @@ -233,7 +241,11 @@ def __init__( tokenizer = MockTokenizer(tokens) engine = MockEngine(tokenizer, byte_patterns, compute_log_probs, force) - super().__init__(engine, echo=echo) + super().__init__( + client=EngineClient(engine), + state=EngineState(), + echo=echo, + ) # class MockChat(Mock, Chat): diff --git a/guidance/models/_openai.py b/guidance/models/_openai.py index a7653560d..5431b08d3 100644 --- a/guidance/models/_openai.py +++ b/guidance/models/_openai.py @@ -5,18 +5,9 @@ from .._grammar import Gen, Join from ..experimental.ast import ContentChunk, ImageBlob, Node, RoleEnd, RoleStart from ..trace import LiteralInput, TextOutput, RoleOpenerInput, RoleCloserInput -from ._base import Model, BaseState - -class OpenAIState(BaseState): - @classmethod - def from_model_id(cls, model_id: str) -> "OpenAIState": - if "audio-preview" in model_id: - return OpenAIAudioState() - if model_id.startswith("gpt-4o") or model_id.startswith("o1"): - return OpenAIImageState() - else: - return OpenAIState() +from ._base import Model, State, Client +class OpenAIState(State): def apply_content_chunk(self, chunk: ContentChunk) -> None: if self.active_message["role"] is None: raise ValueError("OpenAI models require chat blocks (e.g. use `with assistant(): ...`)") @@ -54,42 +45,21 @@ def __init__(self) -> None: raise NotImplementedError("OpenAI audio not yet implemented") -class OpenAI(Model): +class OpenAIClient(Client[OpenAIState]): def __init__( self, model: str, - echo: bool = True, api_key: Optional[str] = None, **kwargs, ): - """Build a new OpenAI model object that represents a model in a given state. - - Parameters - ---------- - model : str - The name of the OpenAI model to use (e.g. gpt-4o-mini). - echo : bool - If true the final result of creating this model state will be displayed (as HTML in a notebook). - api_key : None or str - The OpenAI API key to use for remote requests, passed directly to the `openai.OpenAI` constructor. 
- - **kwargs : - All extra keyword arguments are passed directly to the `openai.OpenAI` constructor. Commonly used argument - names include `base_url` and `organization` - """ - try: import openai except ImportError: raise Exception( "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.OpenAI!" ) - self.client = openai.OpenAI(api_key=api_key, **kwargs) self.model = model - super().__init__(echo=echo) - - def initial_state(self) -> OpenAIState: - return OpenAIState.from_model_id(self.model) + self.client = openai.OpenAI(api_key=api_key, **kwargs) def run( self, state: OpenAIState, node: Node @@ -180,3 +150,40 @@ def inner(node): ) yield from inner(node) + +class OpenAI(Model): + def __init__( + self, + model: str, + echo: bool = True, + api_key: Optional[str] = None, + **kwargs, + ): + """Build a new OpenAI model object that represents a model in a given state. + + Parameters + ---------- + model : str + The name of the OpenAI model to use (e.g. gpt-4o-mini). + echo : bool + If true the final result of creating this model state will be displayed (as HTML in a notebook). + api_key : None or str + The OpenAI API key to use for remote requests, passed directly to the `openai.OpenAI` constructor. + + **kwargs : + All extra keyword arguments are passed directly to the `openai.OpenAI` constructor. Commonly used argument + names include `base_url` and `organization` + """ + + if model.startswith("gpt-4o") or model.startswith("o1"): + state = OpenAIImageState() + elif "audio-preview" in model: + state = OpenAIAudioState() + else: + state = OpenAIState() + + super().__init__( + client = OpenAIClient(model, api_key=api_key, **kwargs), + state = state, + echo=echo + ) diff --git a/guidance/models/llama_cpp/_llama_cpp.py b/guidance/models/llama_cpp/_llama_cpp.py index 2d311f22f..3681e9c76 100644 --- a/guidance/models/llama_cpp/_llama_cpp.py +++ b/guidance/models/llama_cpp/_llama_cpp.py @@ -11,8 +11,8 @@ from ..._schema import GenToken, GenTokenExtra from ..._utils import normalize_notebook_stdout_stderr, softmax -from .._base import Message -from .._engine import Engine, ModelWithEngine, Tokenizer +from .._base import Message, Model +from .._engine import Engine, EngineClient, EngineState, Tokenizer from .._remote import RemoteEngine try: @@ -359,7 +359,7 @@ def apply_chat_template( ) -class LlamaCpp(ModelWithEngine): +class LlamaCpp(Model): def __init__( self, model=None, @@ -386,7 +386,9 @@ def __init__( enable_monitoring=enable_monitoring, **llama_cpp_kwargs, ) - super().__init__(engine, echo=echo) + state = EngineState() + client = EngineClient(engine) + super().__init__(client=client, state=state, echo=echo) def get_chat_formatter(model_obj: "Llama") -> "Jinja2ChatFormatter": diff --git a/guidance/models/transformers/_model.py b/guidance/models/transformers/_model.py index 397fee11f..007a00525 100644 --- a/guidance/models/transformers/_model.py +++ b/guidance/models/transformers/_model.py @@ -1,9 +1,11 @@ -from .._engine import ModelWithEngine +import re + +from .._base import Model +from .._engine import EngineClient, EngineState, Llama3VisionState, Phi3VisionState from ._engine import TransformersEngine -# TODO: Expose a non-chat version -class Transformers(ModelWithEngine): +class Transformers(Model): def __init__( self, model=None, @@ -17,7 +19,14 @@ def __init__( **kwargs, ): """Build a new Transformers model object that represents a model in a given state.""" - super().__init__( + if re.search("Llama-3.*-Vision", model): + state = 
Llama3VisionState() + elif re.search("Phi-3-vision", model): + state = Phi3VisionState() + else: + state = EngineState() + + client = EngineClient( TransformersEngine( model, tokenizer, @@ -27,6 +36,10 @@ def __init__( enable_ff_tokens=enable_ff_tokens, enable_monitoring=enable_monitoring, **kwargs, - ), + ) + ) + super().__init__( + client=client, + state=state, echo=echo, ) diff --git a/guidance/models/transformers/_state.py b/guidance/models/transformers/_state.py deleted file mode 100644 index 80dd80508..000000000 --- a/guidance/models/transformers/_state.py +++ /dev/null @@ -1,21 +0,0 @@ -from guidance.models._engine import EngineState - -from ...experimental.ast import ImageBlob - - -class Llama3(EngineState): - def apply_image(self, image: ImageBlob) -> None: - self.images.append(image.image) - text = "<|image|>" - EngineState.apply_text(self, text) - - -class Phi3(EngineState): - def apply_image(self, image: ImageBlob) -> None: - pil_image = image.image - if pil_image in self.images: - ix = self.images.index(pil_image) + 1 - else: - self.images.append(pil_image) - ix = len(self.images) - EngineState.apply_text(self, f"<|image_{ix}|>")
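
A minimal usage sketch of the refactored API this patch introduces: concrete backends now pass a Client and a State into the shared Model base class (for example, Transformers wraps a TransformersEngine in an EngineClient and pairs it with an EngineState or a vision-specific state) rather than subclassing a backend-specific model, and call-site usage is unchanged. The snippet below is adapted from the test notebook in this diff; the role/gen imports and the gen(max_tokens=20) call are assumptions about the surrounding guidance API, not part of this change.

from guidance import models, system, user, assistant, gen  # role/gen imports assumed, not shown in the diff

# Build a backend; internally this now wires client=EngineClient(...) and
# state=EngineState() (or a vision state) into Model.__init__.
model = models.Transformers(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
    attn_implementation="eager",
)

lm = model
with system():
    lm += "Talk like a pirate!"
with user():
    lm += "What is the capital of France?"
with assistant():
    lm += gen(max_tokens=20)  # assumed generation call; the notebook cell that produced the reply is not shown

# The State object owns the structured chat history, as exercised in the notebook:
print(lm._state.messages)        # tuple of {'role': ..., 'data': {'content': ...}} entries
print(lm._state.active_message)  # the assistant message most recently completed or in progress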