one model to rule them all
hudson-ai committed Feb 22, 2025
1 parent 03bc11d commit 48c3fbc
Showing 14 changed files with 282 additions and 129 deletions.
139 changes: 125 additions & 14 deletions guidance/experimental/test.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -13,20 +13,88 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.\n",
"Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3cf50b682b7547e69f33d0e751158fad",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7a289b0b8dbe4e3c9032e311412d9efa",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"StitchWidget(initial_height='auto', initial_width='100%', srcdoc='<!doctype html>\\n<html lang=\"en\">\\n<head>\\n …"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Failed to get per token stats: 0\n"
]
}
],
"source": [
"# model = models.Transformers(\"microsoft/Phi-3-mini-4k-instruct\", trust_remote_code=True, attn_implementation=\"eager\")\n",
"model = models.OpenAI(\"gpt-4o-mini\")\n",
"# model = models.LlamaCpp(model_path)"
"model = models.Transformers(\"microsoft/Phi-3-mini-4k-instruct\", trust_remote_code=True, attn_implementation=\"eager\")\n",
"# model = models.OpenAI(\"gpt-4o-mini\")\n",
"# model = models.LlamaCpp(\"hf_hub://lmstudio-community/Phi-3.1-mini-4k-instruct-GGUF/Phi-3.1-mini-4k-instruct-IQ3_M.gguf\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ea540eb74575423085c6cb531e93f13c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"StitchWidget(initial_height='auto', initial_width='100%', srcdoc='<!doctype html>\\n<html lang=\"en\">\\n<head>\\n …"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"You are not running the flash-attention implementation, expect numerical differences.\n",
"gpustat is not installed, run `pip install gpustat` to collect GPU stats.\n",
"Failed to get per token stats: 30\n"
]
}
],
"source": [
"lm = model\n",
"with system():\n",
@@ -61,27 +129,70 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|system|>\n",
"Talk like a pirate!<|end|>\n",
"<|user|>\n",
"Hello, model!<|end|>\n",
"<|assistant|>\n",
" Arrr matey!<|end|><|end|>\n",
"<|user|>\n",
"What is the capital of France?<|end|>\n",
"<|assistant|>\n",
" The capital of France be Paris, arrr! <|end|>\n"
]
}
],
"source": [
"print(lm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"({'role': 'system', 'data': {'content': 'Talk like a pirate!'}},\n",
" {'role': 'user', 'data': {'content': 'Hello, model!'}},\n",
" {'role': 'assistant', 'data': {'content': ' Arrr matey!<|end|>'}},\n",
" {'role': 'user', 'data': {'content': 'What is the capital of France?'}})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lm._state.messages"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"{'role': 'assistant',\n",
" 'data': {'content': ' The capital of France be Paris, arrr! <|end|>'}}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lm._state.active_message"
]
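The swapped comment lines in the model cell are the whole point of the commit title: the same chat program can target a local Transformers checkpoint, the OpenAI API, or a LlamaCpp GGUF file by changing only the constructor. A minimal sketch of that usage follows (the notebook's later cells are collapsed in this diff, so the chat lines below are illustrative, assuming guidance's top-level models, system, user, assistant, and gen):

from guidance import models, system, user, assistant, gen

# Any one of these constructors yields an interchangeable model object:
model = models.Transformers(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
    attn_implementation="eager",
)
# model = models.OpenAI("gpt-4o-mini")  # same program, different backend

lm = model
with system():
    lm += "Talk like a pirate!"
with user():
    lm += "Hello, model!"
with assistant():
    lm += gen(max_tokens=30)  # e.g. " Arrr matey!", as in the printed output above
print(lm)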
6 changes: 4 additions & 2 deletions guidance/models/_base/__init__.py
@@ -1,9 +1,11 @@
from ._client import Client
from ._model import Model
from ._state import BaseState, Message
from ._state import Message, State

__all__ = [
"Model",
"role",
"BaseState",
"State",
"Message",
"Client",
]
13 changes: 13 additions & 0 deletions guidance/models/_base/_client.py
@@ -0,0 +1,13 @@
from abc import ABC, abstractmethod
from typing import Generic, Iterator, TypeVar

from ...experimental.ast import MessageChunk, Node
from ._state import State

S = TypeVar("S", bound=State)


class Client(ABC, Generic[S]):
@abstractmethod
def run(self, state: S, node: Node) -> Iterator[MessageChunk]:
pass
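
The new Client contract is deliberately small: a backend only turns a (state, node) pair into a stream of MessageChunks, while trace bookkeeping, state copies, and role handling stay in Model. A toy conforming client, as a sketch (EchoClient is a made-up name for illustration, and yielding a plain str as a chunk mirrors the isinstance(node, str) branch in EngineClient.run further down):

from typing import Iterator

from guidance.experimental.ast import MessageChunk, Node
from guidance.models._base import Client, State

class EchoClient(Client[State]):
    """Toy client: streams string nodes back verbatim, one chunk each."""

    def run(self, state: State, node: Node) -> Iterator[MessageChunk]:
        if isinstance(node, str):
            yield node  # assumed to be an acceptable MessageChunk
        else:
            raise NotImplementedError(f"unsupported node: {node!r}")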
34 changes: 10 additions & 24 deletions guidance/models/_base/_model.py
@@ -1,11 +1,10 @@
# TODO(nopdive): This module requires a memory review.

import re
from abc import ABC, abstractmethod
from base64 import b64encode
from contextvars import ContextVar
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Generic, Iterator, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union

from typing_extensions import Self, assert_never

@@ -23,7 +22,8 @@
TraceNode,
)
from ...visual import TraceMessage
from ._state import BaseState
from ._client import Client
from ._state import State

_active_role: ContextVar[Optional["RoleStart"]] = ContextVar("active_role", default=None)
_id_counter: int = 0
@@ -37,17 +37,20 @@ def _gen_id():
return _id


S = TypeVar("S", bound=BaseState)
S = TypeVar("S", bound=State)
D = TypeVar("D", bound=Any)


class Model(ABC, Generic[S]):
class Model(Generic[S]):
def __init__(
self,
client: Client[S],
state: S,
echo: bool = True,
) -> None:
self.echo = echo
self._state = self.initial_state()
self._client = client
self._state = state
self._active_role: Optional["RoleStart"] = None

self._parent: Optional["Model"] = None
@@ -56,14 +59,6 @@ def __init__(
self._trace_nodes: set[TraceNode] = set()
self._update_trace_node(self._id, self._parent_id, None)

@abstractmethod
def run(self, state: S, node: Node) -> Iterator[MessageChunk]:
pass

@abstractmethod
def initial_state(self) -> S:
pass

def _update_trace_node(
self, identifier: int, parent_id: Optional[int], node_attr: Optional[NodeAttr] = None
) -> None:
@@ -91,7 +86,7 @@ def __add__(self, other: Node) -> Self:
return self

def _apply_node(self, node: Node) -> Self:
for chunk in self.run(self._state, node):
for chunk in self._client.run(self._state, node):
self = self._apply_chunk(chunk)
return self

@@ -246,12 +241,3 @@ def extract_embedded_nodes(value: str) -> Node:
grammar += part
is_id = not is_id
return grammar


def partial_decode(data: bytes) -> tuple[str, bytes]:
try:
return (data.decode("utf-8"), b"")
except UnicodeDecodeError as e:
valid_part = data[: e.start].decode("utf-8")
delayed_part = data[e.start :]
return (valid_part, delayed_part)
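
With the abstract run and initial_state removed, Model is no longer subclassed per backend; it is one concrete class composed from a Client and a State handed to __init__, which is what the commit title means. A sketch of the new wiring, assuming an Engine instance named engine already exists:

from guidance.models._base import Model
from guidance.models._engine import EngineClient, EngineState

client = EngineClient(engine)  # engine: a placeholder for some concrete Engine
lm = Model(client=client, state=EngineState(), echo=True)
# lm + node now routes through client.run(lm._state, node) via _apply_node.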
2 changes: 1 addition & 1 deletion guidance/models/_base/_state.py
@@ -23,7 +23,7 @@ class CaptureVar(TypedDict):
log_prob: Optional[float]


class BaseState(ABC):
class State(ABC):
def __init__(self) -> None:
self.chunks: list[MessageChunk] = []
self.captures: dict[str, Union[CaptureVar, list[CaptureVar]]] = {}
9 changes: 5 additions & 4 deletions guidance/models/_engine/__init__.py
@@ -1,12 +1,13 @@
from ._tokenizer import Tokenizer # isort:skip
from ._client import EngineClient
from ._engine import Engine
from ._model import Model, ModelWithEngine
from ._state import EngineState
from ._state import EngineState, Llama3VisionState, Phi3VisionState

__all__ = [
"Tokenizer",
"Engine",
"Model",
"ModelWithEngine",
"EngineClient",
"EngineState",
"Llama3VisionState",
"Phi3VisionState",
]
guidance/models/_engine/_client.py
@@ -9,15 +9,14 @@
RoleOpenerInput,
TextOutput,
)
from .._base._model import Model, partial_decode
from .._base import Client
from ._engine import Engine
from ._state import EngineState


class ModelWithEngine(Model[EngineState]):
def __init__(self, engine: Engine, echo: bool = True):
class EngineClient(Client[EngineState]):
def __init__(self, engine: Engine):
self.engine = engine
super().__init__(echo=echo)

def run(self, state: EngineState, node: Node) -> Iterator[MessageChunk]:
if isinstance(node, str):
@@ -107,3 +106,12 @@ def initial_state(self) -> EngineState:
# TODO: for llama_cpp and transformers, we need to provide an interface
# for getting these from something like a model id..?
return EngineState()


def partial_decode(data: bytes) -> tuple[str, bytes]:
try:
return (data.decode("utf-8"), b"")
except UnicodeDecodeError as e:
valid_part = data[: e.start].decode("utf-8")
delayed_part = data[e.start :]
return (valid_part, delayed_part)
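
partial_decode, relocated here from _base/_model.py, exists because engines stream raw bytes and a multibyte UTF-8 sequence can straddle a chunk boundary; eager decoding would raise mid-character. It decodes the longest valid prefix and returns the undecodable tail for the caller to prepend to the next chunk. For example (assuming partial_decode is in scope as defined above):

# "é" is b"\xc3\xa9" in UTF-8; a chunk boundary can fall between the two bytes.
text, pending = partial_decode(b"caf\xc3")
assert (text, pending) == ("caf", b"\xc3")

# Prepend the leftover bytes once the next chunk arrives:
text, pending = partial_decode(pending + b"\xa9!")
assert (text, pending) == ("é!", b"")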
2 changes: 1 addition & 1 deletion guidance/models/_engine/_engine.py
@@ -3,7 +3,7 @@
import logging
import time
import weakref
from abc import ABC, abstractmethod
from abc import ABC
from asyncio import CancelledError
from enum import Enum
from multiprocessing import Manager, Process
23 changes: 21 additions & 2 deletions guidance/models/_engine/_state.py
@@ -1,9 +1,10 @@
from typing import Any

from .._base import BaseState
from ...experimental.ast import ImageBlob
from .._base import State


class EngineState(BaseState):
class EngineState(State):
def __init__(self) -> None:
super().__init__()
self.images: list[Any] = []
@@ -16,3 +17,21 @@ def apply_text(self, text: str) -> None:
else:
self.active_message["data"]["content"] += text
self.text += text


class Llama3VisionState(EngineState):
def apply_image(self, image: ImageBlob) -> None:
self.images.append(image.image)
text = "<|image|>"
EngineState.apply_text(self, text)


class Phi3VisionState(EngineState):
def apply_image(self, image: ImageBlob) -> None:
pil_image = image.image
if pil_image in self.images:
ix = self.images.index(pil_image) + 1
else:
self.images.append(pil_image)
ix = len(self.images)
EngineState.apply_text(self, f"<|image_{ix}|>")
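
The two new subclasses encode each model family's image-placeholder convention: Llama 3 inserts a bare <|image|> tag, while Phi-3 numbers images from 1 and reuses an index when the same PIL image appears again. Roughly, as a sketch (assuming a message is already open so apply_text has somewhere to append, with img_a and img_b standing in for ImageBlob nodes):

state = Phi3VisionState()
state.apply_image(img_a)  # appends "<|image_1|>"; state.images == [img_a.image]
state.apply_image(img_b)  # appends "<|image_2|>"
state.apply_image(img_a)  # same image found at index 0, so "<|image_1|>" again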