
Commit 783acf6

Committed Jun 17, 2024
Fix mamba integration by making it a variant of outlines.models.transformers
1 parent 1537695 commit 783acf6

File tree

6 files changed: +116 -90 lines changed
 

Diff for: docs/reference/models/mamba.md (-7)

This file was deleted.

Diff for: docs/reference/models/transformers.md (+59)
@@ -82,3 +82,62 @@

The new section below is appended to the end of the file, after the existing `[transformers]: https://github.com/huggingface/transformers` link:

# Alternative Model Classes

`outlines.models.transformers` defaults to `transformers.AutoModelForCausalLM`, which is the appropriate class for most standard large language models, including Llama 3, Mistral, Phi-3, etc.

However, other variants with unique behavior can be used as well by passing the appropriate class.

### Mamba

[Mamba](https://github.com/state-spaces/mamba) is a transformer alternative that employs memory-efficient, linear-time decoding.

To use Mamba with Outlines you must first install the necessary requirements:

```
pip install causal-conv1d>=1.2.0 mamba-ssm torch transformers
```

Then you can either create a Mamba Outlines model via

```
import outlines

model = outlines.models.mamba("state-spaces/mamba-2.8b-hf")
```

or explicitly with

```
import outlines
from transformers import MambaForCausalLM

model = outlines.models.transformers(
    "state-spaces/mamba-2.8b-hf",
    model_class=MambaForCausalLM
)
```

Further Reading:

- https://huggingface.co/docs/transformers/en/model_doc/mamba
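Either way, the returned object is an ordinary Outlines `transformers` model, so it plugs directly into structured generation. A minimal sketch (the generation call and regex below are illustrative, not part of this commit's docs):

```python
import outlines

# Smaller checkpoint for a quick local check.
model = outlines.models.mamba("state-spaces/mamba-130m-hf")

# Constrain Mamba's output to a phone-number-shaped string.
generator = outlines.generate.regex(model, r"[0-9]{3}-[0-9]{3}-[0-9]{4}")
print(generator("A fake phone number: "))
```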
### Encoder-Decoder Models

You can use encoder-decoder (seq2seq) models like T5 and BART with Outlines.

Be cautious with model selection, though: some models such as `t5-base` don't include certain characters (`{`) in their vocabulary, and you may get an error when trying to perform structured generation.

Example:

```
import outlines
from transformers import AutoModelForSeq2SeqLM

model = outlines.models.transformers(
    model_name="EleutherAI/pile-t5-large",
    model_class=AutoModelForSeq2SeqLM,
)
```
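One way to vet a candidate model up front is to check whether the characters your schema needs appear anywhere in the tokenizer's vocabulary. A rough sketch (assumes the standard `get_vocab` tokenizer method; not part of this commit):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")

# If no token contains "{", the model cannot emit JSON braces and
# JSON-structured generation will fail. Expected False for t5-base
# per the caveat above.
print(any("{" in token for token in tokenizer.get_vocab()))
```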
### Multi-Modal Models

*Coming soon*

Diff for: outlines/models/__init__.py (+2 -3)

@@ -9,10 +9,9 @@

 from .exllamav2 import ExLlamaV2Model, exl2
 from .llamacpp import LlamaCpp, llamacpp
-from .mamba import Mamba, mamba
 from .mlxlm import MLXLM, mlxlm
 from .openai import OpenAI, azure_openai, openai
-from .transformers import Transformers, TransformerTokenizer, transformers
+from .transformers import Transformers, TransformerTokenizer, transformers, mamba
 from .vllm import VLLM, vllm

-LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba]
+LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model]

Diff for: outlines/models/mamba.py (-61)

This file was deleted.

Diff for: outlines/models/transformers.py (+38 -8)

@@ -369,6 +369,8 @@ def transformers(
     device: Optional[str] = None,
     model_kwargs: dict = {},
     tokenizer_kwargs: dict = {},
+    model_class=None,
+    tokenizer_class=None,
 ):
     """Instantiate a model from the `transformers` library and its tokenizer.

@@ -391,19 +393,47 @@
     A `TransformersModel` model instance.

     """
-    try:
-        from transformers import AutoModelForCausalLM, AutoTokenizer
-    except ImportError:
-        raise ImportError(
-            "The `transformers` library needs to be installed in order to use `transformers` models."
-        )
+    if model_class is None or tokenizer_class is None:
+        try:
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "The `transformers` library needs to be installed in order to use `transformers` models."
+            )
+        if model_class is None:
+            model_class = AutoModelForCausalLM
+        if tokenizer_class is None:
+            tokenizer_class = AutoTokenizer

     if device is not None:
         model_kwargs["device_map"] = device

-    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
+    model = model_class.from_pretrained(model_name, **model_kwargs)

     tokenizer_kwargs.setdefault("padding_side", "left")
-    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
+    tokenizer = tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs)

     return Transformers(model, tokenizer)
+
+
+def mamba(
+    model_name: str,
+    device: Optional[str] = None,
+    model_kwargs: dict = {},
+    tokenizer_kwargs: dict = {},
+):
+    try:
+        from transformers import MambaForCausalLM
+    except ImportError:
+        raise ImportError(
+            "The `mamba_ssm`, `torch` and `transformers` libraries need to be installed in order to use Mamba models."
+        )
+
+    return transformers(
+        model_name=model_name,
+        device=device,
+        model_kwargs=model_kwargs,
+        tokenizer_kwargs=tokenizer_kwargs,
+        model_class=MambaForCausalLM,
+    )
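With both hooks supplied, the `Auto*` fallback (and its import guard) is bypassed entirely. A minimal sketch of the new parameters (the model and classes below are illustrative choices, not from this commit):

```python
import outlines
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Explicit classes: AutoModelForCausalLM / AutoTokenizer are never resolved.
model = outlines.models.transformers(
    "openai-community/gpt2",
    model_class=GPT2LMHeadModel,
    tokenizer_class=GPT2TokenizerFast,
)
```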

Diff for: tests/generate/test_generate.py (+17 -11)

@@ -31,24 +31,28 @@ def model_vllm(tmp_path_factory):
     return models.vllm("facebook/opt-125m")


-# TODO: mamba / exllamav2 failing in main, address in https://github.com/outlines-dev/outlines/issues/808
-"""
 @pytest.fixture(scope="session")
-def model_exllamav2(tmp_path_factory):
-    return models.exllamav2(
-        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
-        device="cpu"
+def model_mamba(tmp_path_factory):
+    return models.mamba(model_name="state-spaces/mamba-130m-hf", device="cpu")
+
+
+@pytest.fixture(scope="session")
+def model_t5(tmp_path_factory):
+    from transformers import T5ForConditionalGeneration
+
+    return models.transformers(
+        "google/t5-efficient-mini", device="cpu", model_class=T5ForConditionalGeneration
     )


+# TODO: exllamav2 failing in main, address in https://github.com/outlines-dev/outlines/issues/808
+"""
 @pytest.fixture(scope="session")
-def model_mamba(tmp_path_factory):
-    return models.mamba(
-        model_name="state-spaces/mamba-130m-hf",
+def model_exllamav2(tmp_path_factory):
+    return models.exllamav2(
+        model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2",
         device="cpu"
     )
-
-ALL_MODEL_FIXTURES = ("model_llamacpp", "model_mlxlm", "model_transformers", "model_vllm", "model_exllamav2", "model_mamba")
 """


@@ -57,6 +61,8 @@ def model_mamba(tmp_path_factory):
     "model_mlxlm",
     "model_transformers",
     "model_vllm",
+    "model_mamba",
+    "model_t5",
 )
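For context, `ALL_MODEL_FIXTURES` is consumed by parametrized tests in this file; a condensed sketch of that pattern (assumed from the fixture names above, not shown in this diff):

```python
import pytest
from outlines import generate

@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_text(request, model_fixture):
    # Resolve the session-scoped model fixture by name.
    model = request.getfixturevalue(model_fixture)
    generator = generate.text(model)
    assert isinstance(generator("Hello", max_tokens=10), str)
```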
