Commit 4304e4f

Merge pull request #217 from mistralai/pixtral
Pixtral
2 parents 3fd585d + 510f7ae

10 files changed (+576, -160 lines)


pyproject.toml

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mistral_inference"
-version = "1.3.1"
+version = "1.4.0"
 description = ""
 authors = ["bam4d <bam4d@mistral.ai>"]
 readme = "README.md"
@@ -27,8 +27,9 @@ python = "^3.9.10"
 xformers = ">=0.0.24"
 simple-parsing = ">=0.1.5"
 fire = ">=0.6.0"
-mistral_common = "^1.3.0"
+mistral_common = ">=1.4.0"
 safetensors = ">=0.4.0"
+pillow = ">=10.3.0"
 
 [tool.poetry.group.dev.dependencies]
 types-protobuf = "4.24.0.20240129"
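
Dependency note: for installs from PyPI, the matching upgrade would presumably be `pip install --upgrade mistral-inference "mistral-common>=1.4.0"` (assuming the published package names); the new pillow requirement is pulled in automatically for image loading.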

src/mistral_inference/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "1.3.1"
+__version__ = "1.4.0"

src/mistral_inference/args.py

Lines changed: 17 additions & 2 deletions
@@ -7,6 +7,19 @@
 from mistral_inference.moe import MoeArgs
 
 
+@dataclass
+class VisionEncoderArgs:
+    hidden_size: int
+    num_channels: int
+    image_size: int
+    patch_size: int
+    intermediate_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    rope_theta: float = 1e4  # for rope-2D
+    image_token_id: int = 10
+
+
 @dataclass
 class TransformerArgs(Serializable):
     dim: int
@@ -28,7 +41,9 @@ class TransformerArgs(Serializable):
     lora: Optional[LoraArgs] = None
     model_type: str = "transformer"
 
-    def __post_init__(self):
+    vision_encoder: Optional[VisionEncoderArgs] = None
+
+    def __post_init__(self) -> None:
         assert self.model_type == "transformer", self.model_type
 
 
@@ -45,5 +60,5 @@ class MambaArgs(Serializable):
     tie_embeddings: bool
     model_type: str = "mamba"
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         assert self.model_type == "mamba", self.model_type
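
For illustration, a vision-encoder config can be constructed directly. The field values below are placeholders chosen for the example, not taken from any released checkpoint; in practice they are read from the model's params.json:

from mistral_inference.args import VisionEncoderArgs

# Placeholder values, for illustration only; real models load these
# from their checkpoint config (params.json).
vision_args = VisionEncoderArgs(
    hidden_size=1024,
    num_channels=3,         # RGB input
    image_size=1024,        # maximum image side, in pixels
    patch_size=16,          # each patch covers 16x16 pixels
    intermediate_size=4096,
    num_hidden_layers=24,
    num_attention_heads=16,
)
# rope_theta defaults to 1e4 (used by the 2D RoPE below) and
# image_token_id to 10, per the dataclass definition above.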

src/mistral_inference/generate.py

Lines changed: 13 additions & 0 deletions
@@ -1,5 +1,6 @@
 from typing import List, Optional, Tuple
 
+import numpy as np
 import torch
 
 from mistral_inference.cache import BufferCache
@@ -43,12 +44,21 @@ def generate_mamba(
 def generate(
     encoded_prompts: List[List[int]],
     model: Transformer,
+    images: List[List[np.ndarray]] = [],
     *,
     max_tokens: int,
     temperature: float,
     chunk_size: Optional[int] = None,
     eos_id: Optional[int] = None,
 ) -> Tuple[List[List[int]], List[List[float]]]:
+    images_torch: List[List[torch.Tensor]] = []
+    if images:
+        assert chunk_size is None
+        images_torch = [
+            [torch.tensor(im, device=model.device, dtype=model.dtype) for im in images_for_sample]
+            for images_for_sample in images
+        ]
+
     model = model.eval()
     B, V = len(encoded_prompts), model.args.vocab_size
 
@@ -75,12 +85,15 @@ def generate(
     if chunk_size is None:
         chunk_size = max_prompt_len
 
+    flattened_images: List[torch.Tensor] = sum(images_torch, [])
+
     # Encode prompt by chunks
     for s in range(0, max_prompt_len, chunk_size):
         prompt_chunks = [p[s : s + chunk_size] for p in encoded_prompts]
         assert all(len(p) > 0 for p in prompt_chunks)
         prelogits = model.forward(
             torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long),
+            images=flattened_images,
             seqlens=[len(p) for p in prompt_chunks],
             cache=cache,
         )
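
A minimal usage sketch of the extended signature, assuming a loaded multimodal Transformer (`model`), a MistralTokenizer (`mistral_tokenizer`), and a ChatCompletionRequest (`request`); `encode_chat_completion` returns the images alongside the tokens, mirroring how main.py calls this below:

tokenized = mistral_tokenizer.encode_chat_completion(request)

out_tokens, logprobs = generate(
    [tokenized.tokens],   # one encoded prompt
    model,
    [tokenized.images],   # the np.ndarray images belonging to that prompt
    max_tokens=64,
    temperature=0.0,
    eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id,
)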

src/mistral_inference/main.py

Lines changed: 60 additions & 8 deletions
@@ -3,19 +3,31 @@
 import os
 import warnings
 from pathlib import Path
-from typing import List, Optional, Type, Union
+from typing import List, Optional, Tuple, Type, Union
 
 import fire  # type: ignore
 import torch
 import torch.distributed as dist
-from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
+from mistral_common.protocol.instruct.messages import (
+    AssistantMessage,
+    ContentChunk,
+    ImageChunk,
+    ImageURLChunk,
+    TextChunk,
+    UserMessage,
+)
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.tokens.tokenizers.base import Tokenizer
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
-from mistral_common.tokens.tokenizers.tekken import Tekkenizer, SpecialTokenPolicy
 from mistral_common.tokens.tokenizers.sentencepiece import is_sentencepiece
-from mistral_common.tokens.tokenizers.tekken import is_tekken
-
+from mistral_common.tokens.tokenizers.tekken import (
+    SpecialTokenPolicy,
+    Tekkenizer,
+    is_tekken,
+)
+from PIL import Image
+
+from mistral_inference.args import TransformerArgs
 from mistral_inference.generate import generate, generate_mamba
 from mistral_inference.mamba import Mamba
 from mistral_inference.transformer import Transformer
@@ -62,6 +74,31 @@ def pad_and_convert_to_tensor(list_of_lists: List[List[int]], pad_id: int) -> Li
     return padded_lists
 
 
+def _get_multimodal_input() -> Tuple[UserMessage, bool]:
+    chunks: List[ContentChunk] = []
+
+    response = input("Text prompt: ")
+    if response:
+        chunks.append(TextChunk(text=response))
+
+    print("[You can input zero, one or more images now.]")
+    while True:
+        did_something = False
+        response = input("Image path or url [Leave empty and press enter to finish image input]: ")
+        if response:
+            if Path(response).is_file():
+                chunks.append(ImageChunk(image=Image.open(response)))
+            else:
+                assert response.startswith("http"), f"{response} does not seem to be a valid url."
+                chunks.append(ImageURLChunk(image_url=response))
+            did_something = True
+
+        if not did_something:
+            break
+
+    return UserMessage(content=chunks), not chunks
+
+
 def interactive(
     model_path: str,
     max_tokens: int = 35,
@@ -85,6 +122,10 @@ def interactive(
 
     model_cls = get_model_cls(model_path)
     model = model_cls.from_folder(Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks)
+    is_multimodal = isinstance(model.args, TransformerArgs) and model.args.vision_encoder is not None
+
+    if is_multimodal:
+        assert instruct, "Multimodal models should only be used in instruct mode"
 
     # load LoRA
     if lora_path is not None:
@@ -95,17 +136,27 @@ def interactive(
 
     while True:
         if should_print:
-            user_input = input("Prompt: ")
+            if not is_multimodal:
+                user_input = input("Prompt: ")
 
             if instruct:
-                messages += [UserMessage(content=user_input)]
+                if is_multimodal:
+                    mm_input, finished = _get_multimodal_input()
+                    if finished:
+                        break
+                    messages += [mm_input]
+                else:
+                    messages += [UserMessage(content=user_input)]
                 chat_completion_request = ChatCompletionRequest(messages=messages)
 
-                tokens = mistral_tokenizer.encode_chat_completion(chat_completion_request).tokens
+                tokenized = mistral_tokenizer.encode_chat_completion(chat_completion_request)
+                tokens = tokenized.tokens
+                images = tokenized.images
             else:
                 prompt += user_input
 
                 tokens = tokenizer.encode(prompt, bos=True, eos=False)
+                images = []
 
             length_tensor = torch.tensor([len(tokens)], dtype=torch.int)
         else:
@@ -121,6 +172,7 @@ def interactive(
         generated_tokens, _ = generate_fn(  # type: ignore[operator]
             [tokens],
             model,
+            [images],
            max_tokens=max_tokens,
             temperature=temperature,
             eos_id=tokenizer.eos_id,
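
Taken together: `interactive` now detects a vision encoder on the loaded checkpoint and, in instruct mode, routes input through `_get_multimodal_input`, which collects a text prompt plus any number of local image paths (ImageChunk) or URLs (ImageURLChunk). A plausible invocation via the package's mistral-chat entry point would be `mistral-chat /path/to/multimodal-model --instruct --max_tokens 256`.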

src/mistral_inference/rope.py

Lines changed: 30 additions & 2 deletions
@@ -18,6 +18,34 @@ def apply_rotary_emb(
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = freqs_cis[:, None, :]
-    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+def precompute_freqs_cis_2d(
+    dim: int,
+    height: int,
+    width: int,
+    theta: float,
+) -> torch.Tensor:
+    """
+    freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by
+        (height, width) position tuples
+    """
+    # (dim / 2) frequency bases
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+
+    h = torch.arange(height, device=freqs.device)
+    w = torch.arange(width, device=freqs.device)
+
+    freqs_h = torch.outer(h, freqs[::2]).float()
+    freqs_w = torch.outer(w, freqs[1::2]).float()
+    freqs_2d = torch.cat(
+        [
+            freqs_h[:, None, :].repeat(1, width, 1),
+            freqs_w[None, :, :].repeat(height, 1, 1),
+        ],
+        dim=-1,
+    )
+    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
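
A small self-contained check of the new helper, assuming nothing beyond the function above: it builds the (height, width, dim // 2) grid for a hypothetical 4x6 patch grid, then flattens it row-major into one complex row per patch, matching the (seqlen, dim // 2) layout that apply_rotary_emb indexes along the sequence dimension:

import torch

from mistral_inference.rope import precompute_freqs_cis_2d

# Hypothetical sizes, chosen only for the example.
H, W, head_dim = 4, 6, 64

freqs_2d = precompute_freqs_cis_2d(dim=head_dim, height=H, width=W, theta=1e4)
assert freqs_2d.shape == (H, W, head_dim // 2)
assert freqs_2d.is_complex()

# Row-major flattening gives one complex frequency row per image patch.
freqs_per_patch = freqs_2d.reshape(H * W, head_dim // 2)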
