Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exllamav2 Integration #1010

Closed
wants to merge 33 commits into from
Closed
Changes from 5 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
8191d21
Exllamav2_filter
isamu-isozaki Jun 29, 2024
42978c4
Fix comment
isamu-isozaki Jun 29, 2024
d271fff
Fixed precommit issues
isamu-isozaki Jun 29, 2024
1bdcd4e
Removed text
isamu-isozaki Jul 10, 2024
ecf1d3c
Merge branch 'main' of https://github.com/outlines-dev/outlines into …
isamu-isozaki Jul 10, 2024
1a193a5
Merge branch 'main' of https://github.com/outlines-dev/outlines into …
isamu-isozaki Jul 21, 2024
37d2471
Basic draft done
isamu-isozaki Aug 1, 2024
a68ddd7
Passed local test
isamu-isozaki Aug 6, 2024
197718f
Fixed tests+precommit
isamu-isozaki Aug 13, 2024
df4bd6a
Revert change for pyairports
isamu-isozaki Aug 13, 2024
4ffdf34
Fixed precommit
isamu-isozaki Aug 13, 2024
39ecf7d
Wrap up
isamu-isozaki Aug 13, 2024
ab731bd
Remove | for union
isamu-isozaki Aug 13, 2024
4cda254
Attempt changing to List
isamu-isozaki Aug 13, 2024
f402f33
Fixed for 3.8
isamu-isozaki Aug 13, 2024
e014a63
Adding exllamav2 to optional dependency
isamu-isozaki Aug 19, 2024
06e9b64
Fixed model
isamu-isozaki Aug 19, 2024
f43e4d2
Changed to fork
isamu-isozaki Aug 19, 2024
7b29e8c
Fix format
isamu-isozaki Aug 19, 2024
8d1fca6
Changed order
isamu-isozaki Aug 19, 2024
09e4843
Skip exllamav2 tests
isamu-isozaki Aug 20, 2024
511591a
Merge branch 'main' into exllamav2_filter
isamu-isozaki Aug 20, 2024
785d7de
Attempt fixing coverage
isamu-isozaki Aug 30, 2024
91c3e7a
Merge branch 'main' of https://github.com/outlines-dev/outlines into …
isamu-isozaki Aug 30, 2024
faadf5b
Attempt fix coverage
isamu-isozaki Aug 30, 2024
2a909af
Merge branch 'exllamav2_filter' of https://github.com/isamu-isozaki/o…
isamu-isozaki Aug 30, 2024
7ca151c
Remove flash-attn requirement
isamu-isozaki Aug 30, 2024
2c241ff
Fixed fixture tests
isamu-isozaki Aug 30, 2024
a289b5a
Removed lora
isamu-isozaki Aug 30, 2024
c3681a8
Passed coverage
isamu-isozaki Aug 30, 2024
e6b3af6
Added back transformers install
isamu-isozaki Sep 20, 2024
5508c92
Fixed per review
isamu-isozaki Sep 20, 2024
b7e92a1
Made coverage 100%
isamu-isozaki Sep 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions outlines/integrations/exllamav2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
"""Make ExllamaV2 compatible with Outlines' structured generation.

_______________________________
/ Don't want to self-host? \
\\ Try .json at http://dottxt.co /
-------------------------------
\\ ^__^
\\ (oo)\\_______
(__)\\ )\\/\
||----w |
|| ||

Copyright 2024- the Outlines developers

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import copy
import re
from collections import defaultdict
from typing import DefaultDict, Optional, Type, Union

import torch
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from outlines.fsm.guide import Guide, RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.generate.generator import is_generation_finished
from outlines.integrations.utils import adapt_tokenizer, convert_json_schema_to_str


class FSMFilter:
    """Bias ExLlamaV2 generation based on a finite state machine (FSM).

    Attributes
    ----------
    fsm
        The finite state machine which is used to bias the logits.
    token_sequence
        The token ids fed so far in the current generation.
    seq_id
        Hash of ``token_sequence``; used as the key for the per-sequence
        FSM state.
    """

    token_sequence: list[int]
    seq_id: int

    def __init__(self, fsm: Guide):
        """Store the FSM that drives generation.

        Parameters
        ----------
        fsm
            The fsm of the model.
        """
        self.fsm = fsm
        # Maps hash(token_sequence) -> FSM state reached after feeding that
        # sequence. defaultdict(int) makes any unseen sequence start at the
        # FSM's initial state 0.
        self._fsm_state: DefaultDict[int, int] = defaultdict(int)
        self.token_sequence = []

    def begin(self, prefix_str: str = "") -> None:
        """Reset the filter before a new generation.

        Parameters
        ----------
        prefix_str
            Unused; kept for interface compatibility with ExLlamaV2 filters.
        """
        self._fsm_state = defaultdict(int)
        # Fix: also reset the token sequence. Previously only the state map
        # and seq_id were reset, so a reused filter kept accumulating token
        # ids from earlier generations.
        self.token_sequence = []
        self.seq_id = hash(tuple(self.token_sequence))

    def feed(self, token: torch.Tensor) -> None:
        """Advance the FSM with the sampled token.

        Parameters
        ----------
        token
            Tensor whose element ``[0][0]`` is the sampled token id.
        """
        # .item() works for both CPU and CUDA tensors; .numpy() raises for
        # tensors that live on a device.
        int_token = int(token[0][0].item())

        last_seq_id = self.seq_id
        self.token_sequence.append(int_token)
        self.seq_id = hash(tuple(self.token_sequence))
        self._fsm_state[self.seq_id] = self.fsm.get_next_state(
            state=self._fsm_state[last_seq_id], token_id=int_token
        )

    def clone(self):
        """Return a deep copy so branched sequences evolve independently."""
        return copy.deepcopy(self)

    def next(self) -> tuple[set[int], set[int]]:
        """Return (allowed token ids, token ids that finish generation)."""
        allowed_tokens = self.fsm.get_next_instruction(
            state=self._fsm_state[self.seq_id]
        ).tokens
        if allowed_tokens is None:
            allowed_tokens = []
        end_tokens = []
        for token in allowed_tokens:
            next_state = self.fsm.get_next_state(
                state=self._fsm_state[self.seq_id], token_id=token
            )
            if is_generation_finished([self.fsm], [next_state]):
                end_tokens.append(token)
        return set(allowed_tokens), set(end_tokens)


class RegexFilter(FSMFilter):
    """Bias transformers generation based on a regular expression.

    Attributes
    ----------
    fsm
        The finite state machine which is used to bias the logits.
    """

    def __init__(
        self,
        regex_string: str,
        tokenizer: PreTrainedTokenizerBase,
    ):
        """Compile the FSM that drives the regex-structured generation.

        Parameters
        ----------
        regex_string
            The regular expression the generated text must match.
        tokenizer
            The tokenizer of the model.

        Raises
        ------
        ValueError
            If the `tokenizer` parameter is not a tokenizer.
        """
        # Raise explicitly instead of asserting: `assert` is stripped under
        # `python -O`, and the documented contract is a ValueError.
        if not isinstance(tokenizer, PreTrainedTokenizerBase):
            raise ValueError(
                "The tokenizer must be an instance of PreTrainedTokenizerBase."
            )
        tokenizer = adapt_tokenizer(tokenizer=tokenizer)
        fsm = RegexGuide(regex_string=regex_string, tokenizer=tokenizer)
        super().__init__(fsm)


class JSONFilter(RegexFilter):
    """Bias exllamav2 generation based on a JSON schema.

    Attributes
    ----------
    fsm
        The finite state machine which is used to bias the logits.
    """

    def __init__(
        self,
        schema: Union[dict, Type[BaseModel], str],
        tokenizer: PreTrainedTokenizerBase,
        whitespace_pattern: Optional[str] = None,
    ):
        """Compile the FSM that drives the JSON-guided generation.

        Parameters
        ----------
        schema
            A schema that encodes the structure we want the model to generate.
        tokenizer
            The tokenizer of the model.
        whitespace_pattern
            Pattern to use for JSON syntactic whitespace (doesn't impact string
            literals). For example, to allow only a single space or newline with
            `whitespace_pattern=r"[\n ]?"`
        """
        # Normalize the schema (dict / pydantic model / str) to a JSON string,
        # turn it into an equivalent regex, then delegate to the regex filter.
        schema_as_str = convert_json_schema_to_str(json_schema=schema)
        super().__init__(
            regex_string=build_regex_from_schema(schema_as_str, whitespace_pattern),
            tokenizer=tokenizer,
        )


class ChoiceFilter(RegexFilter):
    """Bias exllamav2 generation based on choices.

    Attributes
    ----------
    fsm
        The finite state machine which is used to bias the logits.
    """

    def __init__(
        self,
        choices: list[str],
        tokenizer: PreTrainedTokenizerBase,
    ):
        """Compile the FSM that restricts generation to one of `choices`.

        Parameters
        ----------
        choices
            The literal strings the model is allowed to generate.
        tokenizer
            The tokenizer of the model.
        """
        # Escape each choice so regex metacharacters ('.', '+', '(', ...)
        # inside a choice are matched literally instead of being interpreted
        # as regex syntax.
        regex_string = (
            r"(" + r"|".join(re.escape(choice) for choice in choices) + r")"
        )
        super().__init__(regex_string=regex_string, tokenizer=tokenizer)
Loading