diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 0f39ff28e42..5ac903516d9 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models. The following sectio - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) +- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index c9c477665a3..c702ae709ee 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -354,6 +354,7 @@ def local_launcher( kv_cache_dtype: Optional[str] = None, revision: Optional[str] = None, max_input_length: Optional[int] = None, + max_input_tokens: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None, max_total_tokens: Optional[int] = None, lora_adapters: Optional[List[str]] = None, @@ -402,6 +403,9 @@ def local_launcher( if max_input_length: args.append("--max-input-length") args.append(str(max_input_length)) + if max_input_tokens: + args.append("--max-input-tokens") + args.append(str(max_input_tokens)) if max_batch_prefill_tokens: args.append("--max-batch-prefill-tokens") args.append(str(max_batch_prefill_tokens)) diff --git a/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json b/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json new file mode 100644 index 00000000000..6bf2b93a2da --- /dev/null +++ b/integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json @@ -0,0 +1,67 @@ 
+{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 9, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 2684, + "logprob": -0.24902344, + "special": false, + "text": " There" + }, + { + "id": 374, + "logprob": -0.0703125, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.23535156, + "special": false, + "text": " a" + }, + { + "id": 35372, + "logprob": -0.125, + "special": false, + "text": " statue" + }, + { + "id": 304, + "logprob": -0.30273438, + "special": false, + "text": " in" + }, + { + "id": 279, + "logprob": -0.20507812, + "special": false, + "text": " the" + }, + { + "id": 2217, + "logprob": -0.076171875, + "special": false, + "text": " image" + }, + { + "id": 13, + "logprob": -0.053710938, + "special": false, + "text": "." + }, + { + "id": 128258, + "logprob": -0.011352539, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " There is a statue in the image." +} diff --git a/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json b/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json new file mode 100644 index 00000000000..17a69d0d409 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json @@ -0,0 +1,61 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 8, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 330, + "logprob": -0.118652344, + "special": false, + "text": " A" + }, + { + "id": 11426, + "logprob": -0.28320312, + "special": false, + "text": " bee" + }, + { + "id": 335, + "logprob": -0.95703125, + "special": false, + "text": " on" + }, + { + "id": 253, + "logprob": -0.06982422, + "special": false, + "text": " a" + }, + { + "id": 11986, + "logprob": -0.49414062, + "special": false, + "text": " pink" + }, + { + "id": 8525, + "logprob": 
-0.07763672, + "special": false, + "text": " flower" + }, + { + "id": 30, + "logprob": -1.0703125, + "special": false, + "text": "." + }, + { + "id": 49154, + "logprob": -0.092285156, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " A bee on a pink flower." +} diff --git a/integration-tests/models/test_idefics3.py b/integration-tests/models/test_idefics3.py new file mode 100644 index 00000000000..80be2350fad --- /dev/null +++ b/integration-tests/models/test_idefics3.py @@ -0,0 +1,31 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_idefics3_next_handle(launcher): + with launcher("HuggingFaceM4/Idefics3-8B-Llama3") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_idefics3_next(flash_idefics3_next_handle): + await flash_idefics3_next_handle.health(300) + return flash_idefics3_next_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_idefics3_next_simple_url(flash_idefics3_next, response_snapshot): + ny_skyline = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" + query = "What is in this image?" + response = await flash_idefics3_next.generate( + f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}\nAssistant:", + max_new_tokens=10, + seed=1337, + ) + print(response) + assert ( + response.generated_text == " There is a statue in the image." 
+ ), f"{repr(response.generated_text)}" + assert response.details.generated_tokens == 9 + assert response == response_snapshot diff --git a/integration-tests/models/test_smolvlm.py b/integration-tests/models/test_smolvlm.py new file mode 100644 index 00000000000..cd105d84cb5 --- /dev/null +++ b/integration-tests/models/test_smolvlm.py @@ -0,0 +1,31 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_smolvlm_next_handle(launcher): + with launcher("HuggingFaceTB/SmolVLM-Instruct") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_smolvlm_next(flash_smolvlm_next_handle): + await flash_smolvlm_next_handle.health(300) + return flash_smolvlm_next_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_smolvlm_next_simple_url(flash_smolvlm_next, response_snapshot): + ny_skyline = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg" + query = "What is in this image?" + response = await flash_smolvlm_next.generate( + f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}\nAssistant:", + max_new_tokens=10, + seed=1337, + ) + print(response) + assert ( + response.generated_text == " A bee on a pink flower." 
+ ), f"{repr(response.generated_text)}" + assert response.details.generated_tokens == 8 + assert response == response_snapshot diff --git a/router/src/config.rs b/router/src/config.rs index 5d07a293ecb..4d5fcfa0639 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -110,6 +110,24 @@ pub struct ClipVisionModel { patch_size: usize, } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct Idefics3 {} + +impl Idefics3 { + pub fn get_max_longest_edge(&self) -> usize { + 364 + } + + pub fn get_number_of_features(&self) -> usize { + 169 + } + + pub fn get_max_longest_edge_for_image_resize(&self) -> usize { + 1456 + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct Idefics2 {} @@ -178,6 +196,7 @@ pub enum Config { Idefics, Mllama, Idefics2(Idefics2), + Idefics3(Idefics3), Ssm, GptBigcode, Granite, diff --git a/router/src/lib.rs b/router/src/lib.rs index 84e9bc48286..21c45241308 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -170,6 +170,7 @@ impl TokenizerConfigToken { #[serde(tag = "processor_class")] pub enum HubPreprocessorConfig { Idefics2Processor(Idefics2Preprocessor), + Idefics3Processor(Idefics2Preprocessor), } impl HubPreprocessorConfig { diff --git a/router/src/validation.rs b/router/src/validation.rs index 8137ac58d2b..6d5b06bd39e 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -614,6 +614,73 @@ fn image_tokens( image_string } + Idefics3(config) => { + const FAKE: &str = ""; + const IMAGE: &str = ""; + const GLOBAL_IMG: &str = ""; + + let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize(); + + // resize image if it is larger than max_longest_edge_for_image_resize keeping aspect ratio + let (height, width) = if height > max_longest_edge_for_image_resize + || width > max_longest_edge_for_image_resize + { + let aspect_ratio = height as f32 / width as f32; + if height > width { + ( + 
max_longest_edge_for_image_resize, + (max_longest_edge_for_image_resize as f32 / aspect_ratio) as usize, + ) + } else { + ( + (max_longest_edge_for_image_resize as f32 * aspect_ratio) as usize, + max_longest_edge_for_image_resize, + ) + } + } else { + (height, width) + }; + + let image_seq_len = config.get_number_of_features(); + let max_edge = config.get_max_longest_edge(); + + let (image_rows, image_cols) = if height > max_edge || width > max_edge { + ( + (height as f32 / max_edge as f32).ceil() as usize, + (width as f32 / max_edge as f32).ceil() as usize, + ) + } else { + (0, 0) + }; + + let mut image_string = String::new(); + + if image_rows == 0 && image_cols == 0 { + // Single image case + image_string.push_str(FAKE); + image_string.push_str(GLOBAL_IMG); + image_string.push_str(&IMAGE.repeat(image_seq_len)); + image_string.push_str(FAKE); + } else { + // Split image case + for n_h in 0..image_rows { + for n_w in 0..image_cols { + image_string.push_str(FAKE); + image_string.push_str(&format!("", n_h + 1, n_w + 1)); + image_string.push_str(&IMAGE.repeat(image_seq_len)); + } + image_string.push('\n'); + } + + image_string.push('\n'); + image_string.push_str(FAKE); + image_string.push_str(GLOBAL_IMG); + image_string.push_str(&IMAGE.repeat(image_seq_len)); + image_string.push_str(FAKE); + } + + image_string + } Paligemma(config) => "".repeat(config.get_number_of_features(height, width)), LlavaNext(config) => "".repeat(config.get_number_of_features(height, width)), Qwen2Vl(config) => format!( @@ -647,7 +714,8 @@ fn prepare_input( static RE: Lazy = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); let (tokenizer_query, input_chunks) = match config { Some( - config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)), + config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_) + | Qwen2Vl(_)), ) => { let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); diff --git 
a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fcc79608645..beefeb01672 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -152,6 +152,9 @@ from text_generation_server.models.custom_modeling.idefics2 import ( Idefics2ForConditionalGeneration, ) + from text_generation_server.models.custom_modeling.idefics3 import ( + Idefics3ForConditionalGeneration, + ) from text_generation_server.models.custom_modeling.qwen2_vl import ( Qwen2VLForConditionalGeneration, ) @@ -188,6 +191,12 @@ class ModelType(enum.Enum): "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b", "multimodal": True, } + IDEFICS3 = { + "type": "idefics3", + "name": "Idefics 3", + "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3", + "multimodal": True, + } LLAVA_NEXT = { "type": "llava_next", "name": "Llava Next (1.6)", @@ -1253,6 +1262,24 @@ def get_model( ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + if model_type == IDEFICS3: + if FLASH_ATTENTION: + return VlmCausalLM( + model_id=model_id, + model_class=Idefics3ForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. 
+ processor_kwargs={"size": {"longest_edge": 1456}}, + ) + else: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == PALIGEMMA: if FLASH_ATTENTION: return VlmCausalLM( diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 10309006af9..28db42fea20 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -515,9 +515,7 @@ def __init__(self, prefix, config, weights): self.layers.append( FlashLlamaLayer( index=0, - prefix=( - "model.layers.0" if not prefix else f"{prefix}.model.layers.0" - ), + prefix=f"{prefix}.layers.0", config=config, weights=weights, ) @@ -533,11 +531,7 @@ def __init__(self, prefix, config, weights): self.layers.append( FlashLlamaCrossLayer( index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.model.layers.{layer_id}" - ), + prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, ) @@ -546,11 +540,7 @@ def __init__(self, prefix, config, weights): self.layers.append( FlashLlamaLayer( index=layer_id, - prefix=( - f"model.layers.{layer_id}" - if not prefix - else f"{prefix}.model.layers.{layer_id}" - ), + prefix=(f"{prefix}.layers.{layer_id}"), config=config, weights=weights, ) @@ -561,18 +551,14 @@ def __init__(self, prefix, config, weights): self.layers.append( FlashLlamaLayer( index=last_layer_id, - prefix=( - f"model.layers.{last_layer_id}" - if not prefix - else f"{prefix}.model.layers.{last_layer_id}" - ), + prefix=(f"{prefix}.layers.{last_layer_id}"), config=config, weights=weights, ) ) self.norm = FastRMSNorm.load( - prefix="model.norm" if not prefix else f"{prefix}.model.norm", + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps, ) @@ -629,19 +615,24 @@ def forward( class 
FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, prefix: str, config, weights): + def __init__(self, prefix: str, config, weights, name=None): + if name is None: + name = "model" super().__init__() - with no_fp8(weights): self.embed_tokens = TensorParallelEmbedding( prefix=( - "model.embed_tokens" + f"{name}.embed_tokens" if not prefix - else f"{prefix}.model.embed_tokens" + else f"{prefix}.{name}.embed_tokens" ), weights=weights, ) - self.model = FlashLlamaModel(prefix, config, weights) + self.model = FlashLlamaModel( + prefix=name if not prefix else f"{prefix}.{name}", + config=config, + weights=weights, + ) if config.tie_word_embeddings: suffix = "model.embed_tokens" else: @@ -652,11 +643,13 @@ def __init__(self, prefix: str, config, weights): if embedding_multiplier is not None: self.embed_tokens.weight.data *= embedding_multiplier + prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}" + with no_fp8(weights): self.lm_head = SpeculativeHead.load( config, - prefix=suffix if not prefix else f"{prefix}.{suffix}", - weights=weights, + prefix, + weights, ) # Used in Granite diff --git a/server/text_generation_server/models/custom_modeling/idefics3.py b/server/text_generation_server/models/custom_modeling/idefics3.py new file mode 100644 index 00000000000..580398cb32e --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/idefics3.py @@ -0,0 +1,584 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Idefics3 model.""" + +from typing import List, Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, +) +from text_generation_server.layers.attention import Seqlen +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Idefics3VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the need to resize them to the same + fixed size. In particular, we start from the original pre-trained SigLIP model + (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions. 
+ """ + + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + self.patch_embedding.bias = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.position_embedding", weights=weights + ) + + def forward( + self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + + pos_ids = ( + bucket_coords_h[:, None] * 
self.num_patches_per_side + bucket_coords_w + ).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics3VisionAttention(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = self.embed_dim // self.num_heads + if self.head_size * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_size**-0.5 + self.dropout = config.attention_dropout + + self.num_heads = self.num_heads // weights.process_group.size() + self.embed_dim = self.embed_dim // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + self.out_proj = TensorParallelRowLinear.load( + config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True + ) + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, q_len, _ = hidden_states.size() + + qkv = self.qkv(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_size * self.num_heads, + self.head_size * self.num_heads, + self.head_size * self.num_heads, + ], + dim=2, + ) + + query_states = query_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + key_states = key_states.view( + batch_size, q_len, self.num_heads, self.head_size + ).transpose(1, 2) + value_states = value_states.view( + batch_size, q_len, self.num_heads, self.head_size + 
).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + ) + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Idefics3VisionMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = 
self.fc2(hidden_states) + return hidden_states + + +class Idefics3EncoderLayer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics3VisionAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.layer_norm1 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights + ) + self.layer_norm2 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights + ) + self.mlp = Idefics3VisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Idefics3Encoder(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + Idefics3EncoderLayer( + prefix=f"{prefix}.layers.{i}", config=config, weights=weights + ) + for i in range(config.num_hidden_layers) + ] + ) + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + return hidden_states + + +class Idefics3VisionTransformer(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embeddings = 
Idefics3VisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) + self.encoder = Idefics3Encoder( + prefix=f"{prefix}.encoder", config=config, weights=weights + ) + self.post_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + ): + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_size = self.config.patch_size + patch_attention_mask = torch.ones( + ( + batch_size, + pixel_values.size(2) // patch_size, + pixel_values.size(3) // patch_size, + ) + ) + patch_attention_mask = patch_attention_mask.to( + dtype=torch.bool, device=pixel_values.device + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, patch_attention_mask=patch_attention_mask + ) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + patch_attention_mask = None + else: + patch_attention_mask = _prepare_4d_attention_mask( + patch_attention_mask, hidden_states.dtype + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=patch_attention_mask, + ) + + last_hidden_state = encoder_outputs + last_hidden_state = self.post_layernorm(last_hidden_state) + + return last_hidden_state + + +class Idefics3SimpleMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + input_size = config.vision_config.hidden_size * (config.scale_factor**2) + output_size = config.text_config.hidden_size + proj = nn.Parameter( + weights.get_tensor(f"{prefix}.modality_projection.proj.weight"), + requires_grad=False, + ).to(weights.dtype) + self.proj = nn.Linear(input_size, output_size, bias=False) + self.proj.weight = proj + + def forward(self, x): + return self.proj(x) + + +class Idefics3Connector(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.modality_projection = Idefics3SimpleMLP(prefix, config, weights) + self.scale_factor = config.scale_factor + + def pixel_shuffle(self, x, scale_factor=2): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape( + bsz, + int(width / scale_factor), + int(height / scale_factor), + embed_dim * (scale_factor**2), + ) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states): + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return 
class Idefics3ForConditionalGeneration(nn.Module):
    """Idefics3 multimodal model: text model + vision transformer + connector.

    Image pixels are encoded by ``vision_model``, projected by ``connector``
    and written into the text embeddings at every position of ``input_ids``
    that equals ``config.image_token_id``. The merged embeddings are then run
    through the text model and its LM head.
    """

    def __init__(self, prefix, config, weights):
        """Load the text, vision and connector sub-models.

        Args:
            prefix: Optional weight-name prefix. When falsy, weights are
                looked up under ``model`` / ``model.vision_model`` /
                ``model.connector``.
            config: Model config exposing ``vision_config``, ``text_config``,
                ``quantize``, ``speculator``, ``image_token_id`` and
                ``pad_token_id``. NOTE: it is mutated in place (see below).
            weights: Weights accessor exposing ``dtype`` and ``use_loader``.
        """
        super().__init__()
        # Only the text model inherits the requested quantization.
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
        # set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight`
        # since Idefics3 uses the `embed_tokens` for the final prediction
        # config.text_config.tie_word_embeddings = True

        vision_config = config.vision_config
        self.text_model = load_text_model(
            prefix="model" if not prefix else f"{prefix}.model",
            config=config.text_config,
            weights=weights,
            name="text_model",
        )
        self.dtype = weights.dtype

        # The vision and connector models are not quantized.
        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
            self.vision_model = Idefics3VisionTransformer(
                prefix=(
                    f"{prefix}.model.vision_model" if prefix else "model.vision_model"
                ),
                config=vision_config,
                weights=weights,
            )

            # The connector receives the top-level config, so quantization is
            # cleared on it too; this also affects ``self.config`` below.
            config.quantize = None
            self.connector = Idefics3Connector(
                prefix=f"{prefix}.model.connector" if prefix else "model.connector",
                config=config,
                weights=weights,
            )

        self.config = config
        self.image_token_id = config.image_token_id
        # -1 is used as the sentinel when the config defines no pad token.
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def _merge_input_ids_with_image_features(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
    ):
        """Write image features into ``inputs_embeds`` in place and return it.

        Every position where ``input_ids == config.image_token_id`` receives
        one row of ``image_features`` (flattened to ``(-1, hidden)``). The
        number of image-token positions must therefore match the number of
        feature rows.
        """
        mask = input_ids == self.config.image_token_id
        # Let's pray we have enabled enough slots !
        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: "Seqlen",
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        # Unused here
        image_sizes: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        """Run the model and return ``(logits, speculative_logits)``.

        When ``pixel_values`` is given it must be 5-D
        ``(batch, num_images, channels, height, width)``; all-zero images are
        treated as padding and dropped. ``pixel_attention_mask``, when given,
        is indexed as ``(batch, num_images, height, width)``; when omitted,
        every pixel of every real image is attended to.
        """
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None:
            # Unpacking five names also enforces the expected rank.
            batch_size, num_images, _channels, _height, _width = pixel_values.shape
            # Cast once for the whole batch (fp16 compatibility) instead of
            # once per request.
            all_pixel_values = pixel_values.to(dtype=self.dtype)
            all_pixel_mask = pixel_attention_mask
            patch_size = self.config.vision_config.patch_size

            all_states = []
            for i in range(batch_size):
                # (num_images, channels, height, width) for request i.
                item_pixel_values = all_pixel_values[i]

                # Remove padding images - padding images are full 0.
                nb_values_per_image = item_pixel_values.shape[1:].numel()
                real_images_inds = (item_pixel_values == 0.0).sum(
                    dim=(-1, -2, -3)
                ) != nb_values_per_image
                item_pixel_values = item_pixel_values[real_images_inds].contiguous()

                # Handle the vision attention mask. The test must be on the
                # caller-supplied mask: testing a variable that is rebound in
                # the loop made every request after the first one index into
                # a ``None`` mask.
                if all_pixel_mask is None:
                    item_pixel_mask = torch.ones(
                        size=(
                            item_pixel_values.size(0),
                            item_pixel_values.size(2),
                            item_pixel_values.size(3),
                        ),
                        dtype=torch.bool,
                        device=item_pixel_values.device,
                    )
                else:
                    # Remove padding images from the mask as well.
                    item_pixel_mask = all_pixel_mask[i][real_images_inds].contiguous()

                # A patch is attended to as soon as one of its pixels is.
                patches_subgrid = item_pixel_mask.unfold(
                    dimension=1, size=patch_size, step=patch_size
                )
                patches_subgrid = patches_subgrid.unfold(
                    dimension=2, size=patch_size, step=patch_size
                )
                patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

                # Get sequence from the vision encoder
                image_hidden_states = self.vision_model(
                    pixel_values=item_pixel_values,
                    patch_attention_mask=patch_attention_mask,
                )

                # Modality projection & resampling
                image_hidden_states = self.connector(
                    image_hidden_states,
                )

                all_states.append(image_hidden_states)

            # Concatenate along the image axis: the merge step flattens to
            # (-1, hidden) anyway, and unlike ``stack`` this also works when
            # requests keep a different number of real images.
            image_hidden_states = torch.cat(all_states, dim=0)

            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, image_hidden_states
            )

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
            true_max_s=max_s,
            # NOTE(review): the caller's ``prefill_cache_indices`` is not
            # forwarded; ``None`` is always passed - confirm this is intended.
            prefill_cache_indices=None,
            adapter_data=adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)
        return logits, speculative_logits
https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60 +def _prompt_split_image( + *, + image_seq_len: int, + image_rows: int, + image_cols: int, + fake_token_around_image: str, + image_token: str, + global_img_token: str, +): + """Prompt with expanded image tokens for when the image is split into patches.""" + text_split_images = "" + for n_h in range(image_rows): + for n_w in range(image_cols): + text_split_images += ( + f"{fake_token_around_image}" + + f"" + + f"{image_token}" * image_seq_len + ) + text_split_images += "\n" + + text_split_images += ( + f"\n{fake_token_around_image}" + + f"{global_img_token}" + + f"{image_token}" * image_seq_len + + f"{fake_token_around_image}" + ) + return text_split_images + def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ @@ -54,10 +89,26 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str if processor.image_processor.do_image_splitting: image_str *= 5 return image_str + if config.model_type == "idefics3": + # TODO: implement this in a more general way + n_rows = image_input["rows"][0][image_id] + n_cols = image_input["cols"][0][image_id] + image_seq_len = int( + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) + / (config.scale_factor**2) + ) + image_str = _prompt_split_image( + image_seq_len=image_seq_len, + image_rows=n_rows, + image_cols=n_cols, + fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN, + image_token=IDEFICS3_IMAGE_TOKEN, + global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN, + ) + return image_str elif config.model_type == "llava_next": height, width = image_input["image_sizes"][image_id] num_features = get_number_of_features(height, width, config) - from loguru import logger log_master( logger.info, @@ -194,12 +245,21 @@ def batch_tokenized_inputs( raise RuntimeError(f"Invalid chunk type {chunk_type}") if images: - image_inputs = 
processor.image_processor(images, return_tensors="pt") + kwargs = {} + if ( + hasattr(processor, "image_processor_class") + and processor.image_processor_class == "Idefics3ImageProcessor" + ): + kwargs["return_row_col_info"] = True + + image_inputs = processor.image_processor( + images, return_tensors="pt", **kwargs + ) else: image_inputs = None - batch_inputs = [] - max_truncation = 0 + batch_tokenized_inputs = [] + max_length = 0 image_id = 0 for r in requests: full_text = "" @@ -214,16 +274,14 @@ def batch_tokenized_inputs( image_id += 1 full_text = image_text_replacement_fixup(config, full_text) - - batch_inputs.append(full_text) - max_truncation = max(max_truncation, r.truncate) - - batch_tokenized_inputs = tokenizer( - batch_inputs, - truncation=True, - max_length=max_truncation, - add_special_tokens=not config.model_type == "paligemma", - )["input_ids"] + input_ids = tokenizer( + full_text, + truncation=True, + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, + )["input_ids"] + max_length = max(max_length, len(input_ids)) + batch_tokenized_inputs.append(input_ids) return batch_tokenized_inputs, image_inputs