diff --git a/lightllm/models/qwen2_5_vl/__init__.py b/lightllm/models/qwen2_5_vl/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
new file mode 100644
index 000000000..23a63b0ce
--- /dev/null
+++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
@@ -0,0 +1,552 @@
+import os
+import re
+import json
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from typing import Any, Dict, List, Optional, Tuple, Union
+from torchvision import transforms as T
+from torchvision.transforms.functional import InterpolationMode
+from transformers import AutoModel, AutoTokenizer
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+from io import BytesIO
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_utils import PreTrainedModel
+import torch.nn as nn
+from torch.nn import LayerNorm
+from transformers.activations import ACT2FN
+import math
+from lightllm.models.qwen2_vl.vision_process import get_image, Qwen2VLImageProcessor
+from transformers import AutoProcessor
+from safetensors import safe_open
+from transformers.utils import TensorType
+from lightllm.server.multimodal_params import MultimodalParams, ImageItem
+from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
+
+
+# adapted from
+# https://github.com/huggingface/transformers/blob/
+# be37d34f44ff1bc928e59ffb8a30adecab8835a8/src
+# /transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py#L30C1-L31C1
+class Qwen2_5_VLVisionConfig(PretrainedConfig):
+    model_type = "qwen2_5_vl"
+
+    def __init__(
+        self,
+        depth=32,
+        hidden_size=3584,
+        hidden_act="silu",
+        intermediate_size=3420,
+        num_heads=16,
+        in_channels=3,
+        patch_size=14,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        tokens_per_second=4,
+        window_size=112,
+        out_hidden_size=3584,
+        fullatt_block_indexes=[7, 15, 23, 31],
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.depth = depth
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+        self.tokens_per_second = tokens_per_second
+        self.window_size = window_size
+        self.fullatt_block_indexes = fullatt_block_indexes
+        self.out_hidden_size = out_hidden_size
+
+
+class Qwen2RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Qwen2RMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Qwen2_5_VLMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        bias: bool = False,
+    ):
+        super().__init__()
+        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
+        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
+        self.act_fn = ACT2FN[hidden_act]
+
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed
+
+
+class Qwen2_5_VLVisionAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 16) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+        if position_embeddings is None:
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length],
+            torch.finfo(q.dtype).min,
+            device=q.device,
+            dtype=q.dtype,
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                ...,
+                cu_seqlens[i - 1] : cu_seqlens[i],
+                cu_seqlens[i - 1] : cu_seqlens[i],
+            ] = 0
+
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
+
+
+class Qwen2_5_VLVisionSdpaAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 16) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+        if position_embeddings is None:
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+
+        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[
+                ...,
+                cu_seqlens[i - 1] : cu_seqlens[i],
+                cu_seqlens[i - 1] : cu_seqlens[i],
+            ] = True
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
+
+
+QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
+    "eager": Qwen2_5_VLVisionAttention,
+    # "flash_attention_2": Qwen2_5_VLVisionFlashAttention2,
+    "sdpa": Qwen2_5_VLVisionSdpaAttention,
+}
+
+
+class Qwen2_5_VLVisionBlock(nn.Module):
+    def __init__(
+        self,
+        hidden_size,
+        intermediate_size,
+        num_heads,
+        hidden_act,
+        attn_implementation: str = "eager",
+    ) -> None:
+        super().__init__()
+        self.norm1 = Qwen2RMSNorm(hidden_size, eps=1e-6)
+        self.norm2 = Qwen2RMSNorm(hidden_size, eps=1e-6)
+
+        self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](hidden_size, num_heads=num_heads)
+        self.mlp = Qwen2_5_VLMLP(
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            hidden_act=hidden_act,
+            bias=True,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+            position_embeddings=position_embeddings,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class Qwen2_5_VLPatchMerger(nn.Module):
+    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size ** 2)
+        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+        self.mlp = nn.Sequential(
+            nn.Linear(self.hidden_size, self.hidden_size),
+            nn.GELU(),
+            nn.Linear(self.hidden_size, dim),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
+        return x
+
+
+class Qwen2_5_VisionTransformerPretrainedModel(nn.Module):
+    def __init__(
+        self,
+        depth=32,
+        hidden_size=3584,
+        hidden_act="silu",
+        intermediate_size=3420,
+        num_heads=16,
+        in_channels=3,
+        patch_size=14,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        tokens_per_second=4,
+        window_size=112,
+        out_hidden_size=3584,
+        fullatt_block_indexes=[7, 15, 23, 31],
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.depth = depth
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+        self.tokens_per_second = tokens_per_second
+        self.window_size = window_size
+        self.fullatt_block_indexes = fullatt_block_indexes
+        self.out_hidden_size = out_hidden_size
+
+        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+        self.attn_implementation = "eager"
+
+        self.patch_embed = PatchEmbed(
+            patch_size=self.patch_size,
+            temporal_patch_size=self.temporal_patch_size,
+            in_channels=self.in_channels,
+            embed_dim=self.hidden_size,
+        )
+
+        head_dim = self.hidden_size // self.num_heads
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = nn.ModuleList(
+            [
+                Qwen2_5_VLVisionBlock(
+                    self.hidden_size,
+                    self.intermediate_size,
+                    self.num_heads,
+                    self.hidden_act,
+                    self.attn_implementation,
+                )
+                for _ in range(self.depth)
+            ]
+        )
+
+        self.merger = Qwen2_5_VLPatchMerger(
+            dim=self.out_hidden_size,
+            context_dim=self.hidden_size,
+            spatial_merge_size=self.spatial_merge_size,
+        )
+
+        self.gradient_checkpointing = False
+
+        self.device = self.get_device()
+        self.dtype = self.get_dtype()
+
+    def get_dtype(self) -> torch.dtype:
+        return self.blocks[0].mlp.down_proj.weight.dtype
+
+    def get_device(self) -> torch.device:
+        return self.blocks[0].mlp.down_proj.weight.device
+
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def get_window_index(self, grid_thw):
+        window_index: list = []
+        cu_window_seqlens: list = [0]
+        window_index_id = 0
+        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
+
+        for grid_t, grid_h, grid_w in grid_thw:
+            llm_grid_h, llm_grid_w = (
+                grid_h // self.spatial_merge_size,
+                grid_w // self.spatial_merge_size,
+            )
+            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+            index_padded = index_padded.reshape(
+                grid_t,
+                num_windows_h,
+                vit_merger_window_size,
+                num_windows_w,
+                vit_merger_window_size,
+            )
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t,
+                num_windows_h * num_windows_w,
+                vit_merger_window_size,
+                vit_merger_window_size,
+            )
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+
+        return window_index, cu_window_seqlens
+
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.patch_embed(hidden_states)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+        cu_window_seqlens = torch.tensor(
+            cu_window_seqlens,
+            device=hidden_states.device,
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+        seq_len, _ = hidden_states.size()
+        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        hidden_states = hidden_states[window_index, :, :]
+        hidden_states = hidden_states.reshape(seq_len, -1)
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())
+
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0,
+            # Select dtype based on the following factors:
+            # - FA2 requires that cu_seqlens_q must have dtype int32
+            # - torch.onnx.export requires that cu_seqlens_q must have same
+            #   dtype as grid_thw
+            # See https://github.com/huggingface/transformers/pull/34852
+            # for more information
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        for layer_num, blk in enumerate(self.blocks):
+            if layer_num in self.fullatt_block_indexes:
+                cu_seqlens_now = cu_seqlens
+            else:
+                cu_seqlens_now = cu_window_seqlens
+            if self.gradient_checkpointing and self.training:
+                hidden_states = self._gradient_checkpointing_func(
+                    blk.__call__,
+                    hidden_states,
+                    cu_seqlens_now,
+                    None,
+                    position_embeddings,
+                )
+            else:
+                hidden_states = blk(
+                    hidden_states,
+                    cu_seqlens=cu_seqlens_now,
+                    position_embeddings=position_embeddings,
+                )
+
+        hidden_states = self.merger(hidden_states)
+        reverse_indices = torch.argsort(window_index)
+        hidden_states = hidden_states[reverse_indices, :]
+
+        return hidden_states
+
+    def load_model(self, weight_dir):
+
+        processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
+        with open(processor_config_path, "r") as f:
+            processor_config_dict = json.load(f)
+        self.processor = Qwen2VLImageProcessor(**processor_config_dict)
+
+        bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
+        if bin_weight_files:
+            weight_dict = {}
+            for file_ in bin_weight_files:
+                f = torch.load(os.path.join(weight_dir, file_), "cpu")
+                for k, v in f.items():
+                    if "visual" in k:
+                        weight_dict[k[len("visual.") :]] = v
+
+        else:
+            hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
+            weight_dict = {}
+            for file_ in hf_weight_files:
+                f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
+                for k in f.keys():
+                    if "visual" in k:
+                        weight_dict[k[len("visual.") :]] = f.get_tensor(k)
+
+        self.load_state_dict(weight_dict)
+
+    def encode(self, images: List[ImageItem]):
+        img_tensors = []
+        valid_ids = []
+        valid_id = 0
+        img_grids = []
+        uuids = []
+
+        for i, img in enumerate(images):
+            if isinstance(img, ImageItem):
+                uuids.append(img.uuid)
+                image_data = read_shm(get_shm_name_data(img.uuid))
+                image_data = Image.open(BytesIO(image_data))
+                image_data = get_image(image_data)
+                image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
+                pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16)
+                image_grid_thw = image_inputs["image_grid_thw"]
+                img_tensors.append(pixel_values)
+                img_grids.append(image_grid_thw)
+            else:
+                raise Exception("Unsupported input type: {} for {}".format(type(img), img))
+
+            # patch count must be divisible by merge_length (spatial_merge_size ** 2)
+            cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2)
+
+            valid_ids.append([valid_id, valid_id + cur_num])
+            valid_id += cur_num
+
+        if len(img_tensors) <= 0:
+            return None
+
+        imgs = torch.cat(img_tensors, dim=0)
+        grid_thw = torch.cat(img_grids, dim=0)
+
+        pixel_values = imgs.cuda().to(dtype=torch.float32)
+        image_grid_thw = grid_thw.cuda()
+
+        pixel_values = pixel_values.type(self.get_dtype())
+        all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw)
+
+        return all_img_embeds, uuids, valid_ids
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 2d8e68156..aa29eb9f5 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -177,7 +177,7 @@ def init_model(self, kvargs):
                 self.model = Qwen2RewardTpPartModel(model_kvargs)
             else:
                 self.model = Qwen2TpPartModel(model_kvargs)
-        elif self.model_type == "qwen2_vl":
+        elif self.model_type in ["qwen2_vl", "qwen2_5_vl"]:
             self.model = Qwen2VLTpPartModel(model_kvargs)
             self.is_multimodal = True
         elif self.model_type == "gemma":
diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py
index 35c4e846d..37cd6b5ef 100644
--- a/lightllm/server/tokenizer.py
+++ b/lightllm/server/tokenizer.py
@@ -79,7 +79,7 @@ def get_tokenizer(
         tokenizer = LlavaTokenizer(tokenizer, model_cfg)
     elif model_type == "qwen" and "visual" in model_cfg:
         tokenizer = QWenVLTokenizer(tokenizer, model_cfg)
-    elif model_type == "qwen2_vl" and "vision_config" in model_cfg:
+    elif model_type in ["qwen2_vl", "qwen2_5_vl"] and "vision_config" in model_cfg:
         from transformers import AutoProcessor
 
         image_processor = AutoProcessor.from_pretrained(tokenizer_name)
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index 9eae497c7..3dcb19a38 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -14,6 +14,7 @@
 from lightllm.models.vit.model import VisionTransformer
 from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel
+from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel
 from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
 from lightllm.utils.infer_utils import set_random_seed
 from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
@@ -46,6 +47,8 @@ def exposed_init_model(self, kvargs):
             self.model = QWenVisionTransformer(**model_cfg["visual"]).eval().bfloat16()
         elif self.model_type == "qwen2_vl":
             self.model = Qwen2VisionTransformerPretrainedModel(**model_cfg["vision_config"]).eval().bfloat16()
+        elif self.model_type == "qwen2_5_vl":
+            self.model = Qwen2_5_VisionTransformerPretrainedModel(**model_cfg["vision_config"]).eval().bfloat16()
         elif self.model_type == "llava":
             self.model = LlavaVisionModel()
         elif self.model_type == "internvl_chat":
diff --git a/requirements.txt b/requirements.txt
index 10e3f3046..96fc9da12 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -63,7 +63,7 @@ toolz==0.12.0
 torch==2.5.1
 torchvision==0.20.1
 tqdm==4.65.0
-transformers==4.48.3
+transformers==4.50.1
 tokenizers==0.21.0
 huggingface-hub==0.26.5
 triton==3.1.0
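
A minimal, hypothetical smoke test for the new vision tower, kept outside the patch itself: it mirrors how model_rpc.py builds the model from model_cfg["vision_config"] and how encode() preprocesses and forwards an image. The checkpoint directory and image path are placeholders, and the in-server path additionally goes through shared-memory ImageItem inputs, which this sketch skips.

import json
import os
from PIL import Image
from lightllm.models.qwen2_vl.vision_process import get_image
from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel

# Placeholder paths, not files added by this diff.
weight_dir = "/path/to/Qwen2.5-VL-checkpoint"
with open(os.path.join(weight_dir, "config.json")) as f:
    model_cfg = json.load(f)

# Same construction as model_rpc.py, then weight/processor loading from the checkpoint dir.
model = Qwen2_5_VisionTransformerPretrainedModel(**model_cfg["vision_config"]).eval().bfloat16()
model.load_model(weight_dir)  # reads preprocessor_config.json and the visual.* weights
model = model.cuda()

# Mirror encode(): preprocess one image, then run the vision transformer directly.
img = get_image(Image.open("demo.jpg"))
inputs = model.processor.preprocess(images=img, return_tensors="pt")
pixel_values = inputs["pixel_values"].cuda().type(model.get_dtype())
grid_thw = inputs["image_grid_thw"].cuda()
embeds = model.forward(pixel_values, grid_thw=grid_thw)
# embeds should have shape (num_patches // spatial_merge_size**2, out_hidden_size)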