Skip to content

Commit b75beeb

Browse files
committed
Use LogitsProcessors for models.transformers -> outlines.generate.*
1 parent 26142d5 commit b75beeb

18 files changed

+624
-357
lines changed

README.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,9 @@ model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.2")
191191
generator = outlines.generate.json(model, Character)
192192

193193
# Draw a sample
194-
rng = torch.Generator(device="cuda")
195-
rng.manual_seed(789001)
194+
seed = 789001
196195

197-
character = generator("Give me a character description", rng=rng)
196+
character = generator("Give me a character description", seed=seed)
198197

199198
print(repr(character))
200199
# Character(name='Anderson', age=28, armor=<Armor.chainmail: 'chainmail'>, weapon=<Weapon.sword: 'sword'>, strength=8)

docs/reference/models/transformers.md

+52-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Outlines provides an integration with the `torch` implementation of causal model
1515
```python
1616
from outlines import models
1717

18-
model = models.transformers("mistralai/Mistral-7B-v0.1", device="cuda")
18+
model = models.transformers("mistralai/Mistral-7B-v0.3", device="cuda")
1919
```
2020

2121
If you need more fine-grained control you can also initialize the model and tokenizer separately:
@@ -30,4 +30,55 @@ tokenizer = AutoTokenizer.from_pretrained("gpt2")
3030
model = models.Transformers(llm, tokenizer)
3131
```
3232

33+
# Using Logits Processors
34+
35+
There are two ways to use Outlines Structured Generation with HuggingFace Transformers:
36+
1. Use the Outlines generation wrapper, `outlines.models.transformers`
37+
2. Use `OutlinesLogitsProcessor` with `transformers.AutoModelForCausalLM`
38+
39+
Outlines supports a myriad of logits processors for structured generation. In this example, we will use the `RegexLogitsProcessor`, which guarantees that the generated text matches the specified pattern.
40+
41+
## Example: `outlines.models.transformers`
42+
43+
```
44+
import outlines
45+
46+
time_regex_pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"
47+
48+
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct", device="cuda")
49+
generator = outlines.generate.regex(model, time_regex_pattern)
50+
51+
output = generator("The best time to visit a dentist is at ")
52+
print(output)
53+
# 2:30 pm
54+
```
55+
56+
## Example: Direct `transformers` library use
57+
58+
```
59+
import outlines
60+
import transformers
61+
62+
63+
model_uri = "microsoft/Phi-3-mini-4k-instruct"
64+
65+
outlines_tokenizer = outlines.models.TransformerTokenizer(
66+
transformers.AutoTokenizer.from_pretrained(model_uri)
67+
)
68+
phone_number_logits_processor = outlines.processors.RegexLogitsProcessor(
69+
"\\+?[1-9][0-9]{7,14}", # phone number pattern
70+
outlines_tokenizer,
71+
)
72+
73+
generator = transformers.pipeline('text-generation', model=model_uri)
74+
75+
output = generator(
76+
"Jenny gave me her number it's ",
77+
logits_processor=transformers.LogitsProcessorList([phone_number_logits_processor])
78+
)
79+
print(output)
80+
# [{'generated_text': "Jenny gave me her number it's 2125550182"}]
81+
# not quite the 8675309 we expected, but it is a valid phone number
82+
```
83+
3384
[transformers]: https://github.com/huggingface/transformers

docs/reference/text.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ from outlines import models, generate
8080

8181
model = models.transformers("mistralai/Mistral-7B-v0.1")
8282

83-
rng = torch.Generator(device="cuda")
84-
rng.manual_seed(789001)
83+
seed = 789001
8584

86-
answer = generator("What is 2+2?", rng=rng)
85+
answer = generator("What is 2+2?", seed=seed)
8786
```

examples/llamacpp_example.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from enum import Enum
22

3-
import torch
43
from pydantic import BaseModel, constr
54

65
import outlines
@@ -37,10 +36,9 @@ class Character(BaseModel):
3736
generator = outlines.generate.json(model, Character)
3837

3938
# Draw a sample
40-
rng = torch.Generator(device="cpu")
41-
rng.manual_seed(789005)
39+
seed = 789005
4240

4341
prompt = "Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:"
4442

45-
sequence = generator(prompt, rng=rng, max_tokens=512)
43+
sequence = generator(prompt, seed=seed, max_tokens=512)
4644
print(sequence)

outlines/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import outlines.generate
33
import outlines.grammars
44
import outlines.models
5+
import outlines.processors
56
import outlines.types
67
from outlines.base import vectorize
78
from outlines.caching import clear_cache, disable_cache, get_cache

outlines/generate/cfg.py

+8-34
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
from functools import singledispatch
22

3-
from outlines.fsm.guide import CFGGuide
4-
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
3+
from outlines.generate.api import SequenceGeneratorAdapter
54
from outlines.models import OpenAI
6-
from outlines.models.llamacpp import LlamaCpp
7-
from outlines.models.mlxlm import MLXLM
8-
from outlines.models.vllm import VLLM
95
from outlines.samplers import Sampler, multinomial
106

117

128
@singledispatch
13-
def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenerator:
9+
def cfg(
10+
model, cfg_str: str, sampler: Sampler = multinomial()
11+
) -> SequenceGeneratorAdapter:
1412
"""Generate text in the language of a Context-Free Grammar
1513
1614
Arguments
@@ -24,40 +22,16 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera
2422
2523
Returns
2624
-------
27-
A `SequenceGenerator` instance that generates text.
25+
A `SequenceGeneratorAdapter` instance that generates text.
2826
2927
"""
30-
fsm = CFGGuide(cfg_str, model.tokenizer)
31-
device = model.device
32-
generator = SequenceGenerator(fsm, model, sampler, device)
33-
34-
return generator
35-
36-
37-
@cfg.register(MLXLM)
38-
@cfg.register(VLLM)
39-
def cfg_unimplemented(
40-
model,
41-
cfg_str: str,
42-
sampler: Sampler = multinomial(),
43-
):
4428
raise NotImplementedError(
45-
f"The CFG Logits processor is not available for {type(model)}."
29+
f"The CFG Logits processor is not available for {type(model)}. "
30+
+ "Please subscribe to https://github.com/outlines-dev/outlines/issues/684"
31+
+ " for updates on the fix."
4632
)
4733

4834

49-
@cfg.register(LlamaCpp)
50-
def cfg_llamacpp(
51-
model: LlamaCpp,
52-
cfg_str: str,
53-
sampler: Sampler = multinomial(),
54-
):
55-
from outlines.integrations.llamacpp import CFGLogitsProcessor
56-
57-
logits_processor = CFGLogitsProcessor(cfg_str, model.model)
58-
return SequenceGeneratorAdapter(model, logits_processor, sampler)
59-
60-
6135
@cfg.register(OpenAI)
6236
def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()):
6337
raise NotImplementedError(

outlines/generate/regex.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from outlines.models import OpenAI
66
from outlines.models.llamacpp import LlamaCpp
77
from outlines.models.mlxlm import MLXLM
8+
from outlines.models.transformers import Transformers
89
from outlines.models.vllm import VLLM
910
from outlines.samplers import Sampler, multinomial
1011

@@ -39,8 +40,9 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
3940

4041

4142
@regex.register(MLXLM)
42-
def regex_mlxlm(
43-
model: MLXLM,
43+
@regex.register(Transformers)
44+
def regex_unified(
45+
model,
4446
regex_str: str,
4547
sampler: Sampler = multinomial(),
4648
):

outlines/generate/text.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from outlines.fsm.guide import StopAtEOSGuide
44
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
5-
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI
5+
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI, Transformers
66
from outlines.samplers import Sampler, multinomial
77

88

@@ -37,7 +37,8 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
3737

3838

3939
@text.register(MLXLM)
40-
def text_mlxlm(model: MLXLM, sampler: Sampler = multinomial()):
40+
@text.register(Transformers)
41+
def text_unified(model, sampler: Sampler = multinomial()):
4142
return SequenceGeneratorAdapter(model, None, sampler)
4243

4344

outlines/models/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .mamba import Mamba, mamba
1313
from .mlxlm import MLXLM, mlxlm
1414
from .openai import OpenAI, azure_openai, openai
15-
from .transformers import Transformers, transformers
15+
from .transformers import Transformers, TransformerTokenizer, transformers
1616
from .vllm import VLLM, vllm
1717

1818
LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba, MLXLM, VLLM]

outlines/models/mlxlm.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from transformers import PreTrainedTokenizer
1010

1111
from outlines.generate.api import GenerationParameters, SamplingParameters
12-
from outlines.processors import BaseLogitsProcessor
12+
from outlines.processors import OutlinesLogitsProcessor
1313

1414

1515
class MLXLM:
@@ -120,7 +120,7 @@ def generate_step(
120120
temp: Optional[float],
121121
top_p: Optional[float],
122122
sampler: str,
123-
logits_processor: "BaseLogitsProcessor",
123+
logits_processor: "OutlinesLogitsProcessor",
124124
) -> Generator[Tuple[int, float], None, None]:
125125
"""
126126
Adapted from
@@ -135,7 +135,7 @@ def generate_step(
135135
top_p (float, optional): Nucleus sampling, higher means model considers
136136
more less likely words.
137137
sampler (str): The sampler string defined by SequenceGeneratorAdapter
138-
logits_processor (BaseLogitsProcessor): Augment logits before sampling.
138+
logits_processor (OutlinesLogitsProcessor): Augment logits before sampling.
139139
"""
140140
import mlx.core as mx
141141
import mlx_lm

0 commit comments

Comments
 (0)