|
| 1 | +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from .llama import * |
| 16 | +import os |
| 17 | +from ._utils import __version__ |
| 18 | +from .llama import ( |
| 19 | + LlamaRotaryEmbedding, |
| 20 | + LlamaLinearScalingRotaryEmbedding, |
| 21 | +) |
| 22 | +try: |
| 23 | + from transformers.models.qwen3.modeling_qwen3 import ( |
| 24 | + Qwen3Attention, |
| 25 | + Qwen3DecoderLayer, |
| 26 | + Qwen3Model, |
| 27 | + Qwen3ForCausalLM, |
| 28 | + ) |
| 29 | +except: |
| 30 | + from packaging.version import Version |
| 31 | + transformers_version = Version(transformers_version) |
| 32 | + if not transformers_version >= Version("4.50.3"): #TODO: Update when transformers is updated |
| 33 | + raise ImportError( |
| 34 | + f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3 and Qwen3Moe.\n"\ |
| 35 | + f"The minimum required version is 4.50.3.\n"\ |
| 36 | + f'Try `pip install --upgrade "transformers>=4.50.3"`\n'\ |
| 37 | + f"to obtain the latest transformers build, then restart this session."\ |
| 38 | + ) |
| 39 | + pass |
| 40 | + |
| 41 | +# For Pytorch 2.1.1 |
| 42 | +try: |
| 43 | + from transformers.models.qwen3.modeling_qwen3 import ( |
| 44 | + Qwen3SdpaAttention, |
| 45 | + Qwen3FlashAttention2, |
| 46 | + ) |
| 47 | +except: |
| 48 | + Qwen3SdpaAttention = Qwen3Attention |
| 49 | + Qwen3FlashAttention2 = Qwen3Attention |
| 50 | +pass |
| 51 | +from unsloth_zoo.utils import Version, _get_dtype |
| 52 | + |
| 53 | + |
| 54 | +def Qwen3Attention_fast_forward( |
| 55 | + self, |
| 56 | + hidden_states: torch.Tensor, |
| 57 | + causal_mask: Optional[BlockDiagonalCausalMask] = None, |
| 58 | + attention_mask: Optional[torch.Tensor] = None, |
| 59 | + position_ids: Optional[torch.LongTensor] = None, |
| 60 | + past_key_value: Optional[Tuple[torch.Tensor]] = None, |
| 61 | + output_attentions: bool = False, |
| 62 | + use_cache: bool = False, |
| 63 | + padding_mask: Optional[torch.LongTensor] = None, |
| 64 | + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| 65 | + *args, **kwargs, |
| 66 | +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| 67 | + |
| 68 | + # Clear inference |
| 69 | + if hasattr(self, "paged_attention"): |
| 70 | + del self.paged_attention_K |
| 71 | + del self.paged_attention_V |
| 72 | + del self.paged_attention |
| 73 | + del self.temp_QA |
| 74 | + del self.temp_KV |
| 75 | + del self.RH_Q |
| 76 | + del self.attention |
| 77 | + pass |
| 78 | + |
| 79 | + bsz, q_len, _ = hidden_states.size() |
| 80 | + |
| 81 | + n_heads = self.config.num_attention_heads |
| 82 | + n_groups = self.num_key_value_groups |
| 83 | + n_kv_heads = self.config.num_key_value_heads |
| 84 | + head_dim = self.head_dim |
| 85 | + assert(n_kv_heads * n_groups == n_heads) |
| 86 | + |
| 87 | + Q, K, V = self.apply_qkv(self, hidden_states) |
| 88 | + Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) |
| 89 | + K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) |
| 90 | + V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) |
| 91 | + |
| 92 | + #Qwen3 has QKNorm. This seems to be the only difference from Qwen2. |
| 93 | + Q = fast_layernorm_compiled(self.q_norm, Q) |
| 94 | + K = fast_layernorm_compiled(self.k_norm, K) |
| 95 | + |
| 96 | + kv_seq_len = K.shape[-2] |
| 97 | + if past_key_value is not None: |
| 98 | + kv_seq_len += past_key_value[0].shape[-2] |
| 99 | + |
| 100 | + # Extend RoPE dynamically to fit in VRAM |
| 101 | + self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) |
| 102 | + |
| 103 | + if position_ids is None: |
| 104 | + cos = self.rotary_emb.cos_cached |
| 105 | + sin = self.rotary_emb.sin_cached |
| 106 | + Q, K = fast_rope_embedding(Q, K, cos, sin) |
| 107 | + else: |
| 108 | + cos, sin = self.rotary_emb(V, seq_len = kv_seq_len) |
| 109 | + Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) |
| 110 | + pass |
| 111 | + |
| 112 | + if past_key_value is not None: |
| 113 | + K = torch.cat([past_key_value[0], K], dim = 2) |
| 114 | + V = torch.cat([past_key_value[1], V], dim = 2) |
| 115 | + pass |
| 116 | + past_key_value = (K, V) if use_cache else None |
| 117 | + |
| 118 | + # Attention module |
| 119 | + if (not HAS_FLASH_ATTENTION and attention_mask is None): |
| 120 | + # Xformers memory efficient attention |
| 121 | + Q = Q.transpose(1, 2) |
| 122 | + K = K.transpose(1, 2) |
| 123 | + V = V.transpose(1, 2) |
| 124 | + K_M = V_M = bsz * kv_seq_len |
| 125 | + Q_M = bsz * q_len |
| 126 | + |
| 127 | + has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask) |
| 128 | + |
| 129 | + # Group query attention |
| 130 | + K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) |
| 131 | + V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) |
| 132 | + K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) |
| 133 | + V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) |
| 134 | + if hidden_states.requires_grad: |
| 135 | + K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) |
| 136 | + V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) |
| 137 | + |
| 138 | + if has_swa: |
| 139 | + Q = Q.view(1, Q_M, n_heads, head_dim) |
| 140 | + K = K.view(1, K_M, n_heads, head_dim) |
| 141 | + V = V.view(1, V_M, n_heads, head_dim) |
| 142 | + pass |
| 143 | + else: |
| 144 | + # Xformers does support the forward pass though |
| 145 | + Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) |
| 146 | + |
| 147 | + if has_swa: |
| 148 | + Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim) |
| 149 | + K = K.view(1, K_M, n_kv_heads, n_groups, head_dim) |
| 150 | + V = V.view(1, V_M, n_kv_heads, n_groups, head_dim) |
| 151 | + pass |
| 152 | + pass |
| 153 | + |
| 154 | + A = xformers_attention(Q, K, V, attn_bias = causal_mask) |
| 155 | + A = A.view(bsz, q_len, n_heads, head_dim) |
| 156 | + |
| 157 | + elif HAS_FLASH_ATTENTION and attention_mask is None: |
| 158 | + Q = Q.transpose(1, 2) |
| 159 | + K = K.transpose(1, 2) |
| 160 | + V = V.transpose(1, 2) |
| 161 | + sw = getattr(self.config, "sliding_window", None) |
| 162 | + sw = kv_seq_len if (sw is None or sw == "null") else sw |
| 163 | + window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw) |
| 164 | + A = flash_attn_func(Q, K, V, causal = True, window_size = window) |
| 165 | + else: |
| 166 | + # Grouped query attention |
| 167 | + # if n_groups != 1: |
| 168 | + K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) |
| 169 | + V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) |
| 170 | + K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) |
| 171 | + V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) |
| 172 | + # pass |
| 173 | + # Must be contiguous or else results are False! |
| 174 | + # https://github.com/pytorch/pytorch/issues/112577 |
| 175 | + Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() |
| 176 | + # Needs (batch_size, n_heads, seq_len, head_dim) |
| 177 | + # is_casual and attention_mask must not be both set! |
| 178 | + A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False) |
| 179 | + # Go back to (batch_size, seq_len, n_heads, head_dim) |
| 180 | + A = A.transpose(1, 2).contiguous() |
| 181 | + pass |
| 182 | + |
| 183 | + attn_output = A.reshape(bsz, q_len, n_heads*head_dim) |
| 184 | + attn_output = self.apply_o(self, attn_output) |
| 185 | + attn_weights = None |
| 186 | + return attn_output, attn_weights, past_key_value |
| 187 | +pass |
| 188 | + |
| 189 | + |
| 190 | +class FastQwen3Model(FastLlamaModel): |
| 191 | + |
| 192 | + @staticmethod |
| 193 | + def pre_patch(): |
| 194 | + init_name, function = patch_linear_scaling( |
| 195 | + model_name = "Qwen3", |
| 196 | + rope_module = LlamaRotaryEmbedding, |
| 197 | + scaled_rope_module = LlamaLinearScalingRotaryEmbedding, |
| 198 | + attention_module = Qwen3Attention, |
| 199 | + ) |
| 200 | + if init_name is not None: |
| 201 | + exec(function, globals()) |
| 202 | + Qwen3Attention.__init__ = eval(init_name) |
| 203 | + pass |
| 204 | + Qwen3Attention .forward = Qwen3Attention_fast_forward |
| 205 | + Qwen3SdpaAttention .forward = Qwen3Attention_fast_forward |
| 206 | + Qwen3FlashAttention2.forward = Qwen3Attention_fast_forward |
| 207 | + Qwen3DecoderLayer .forward = LlamaDecoderLayer_fast_forward |
| 208 | + Qwen3Model .forward = LlamaModel_fast_forward |
| 209 | + Qwen3ForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference) |
| 210 | + PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward |
| 211 | + fix_prepare_inputs_for_generation(Qwen3ForCausalLM) |
| 212 | + |
| 213 | + # Solves https://github.com/unslothai/unsloth/issues/168 |
| 214 | + # Static KV Cache was introduced in 4.38.0, causing training to be much slower. |
| 215 | + # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings. |
| 216 | + # https://github.com/huggingface/transformers/pull/27931 |
| 217 | + # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py |
| 218 | + import transformers.models.qwen3.modeling_qwen3 |
| 219 | + transformers.models.Qwen3.modeling_qwen3.Qwen3RotaryEmbedding = LlamaRotaryEmbedding |
| 220 | + return |
| 221 | + pass |
| 222 | + |
| 223 | + |
| 224 | + @staticmethod |
| 225 | + def from_pretrained( #TODO: Change after release |
| 226 | + model_name = "Qwen/Qwen3-7B", |
| 227 | + max_seq_length = 4096, |
| 228 | + dtype = None, |
| 229 | + load_in_4bit = True, |
| 230 | + token = None, |
| 231 | + device_map = "sequential", |
| 232 | + rope_scaling = None, |
| 233 | + fix_tokenizer = True, |
| 234 | + model_patcher = None, |
| 235 | + tokenizer_name = None, |
| 236 | + trust_remote_code = False, |
| 237 | + **kwargs, |
| 238 | + ): |
| 239 | + return FastLlamaModel.from_pretrained( |
| 240 | + model_name = model_name, |
| 241 | + max_seq_length = max_seq_length, |
| 242 | + dtype = dtype, |
| 243 | + load_in_4bit = load_in_4bit, |
| 244 | + token = token, |
| 245 | + device_map = device_map, |
| 246 | + rope_scaling = rope_scaling, |
| 247 | + fix_tokenizer = fix_tokenizer, |
| 248 | + model_patcher = FastQwen3Model, |
| 249 | + tokenizer_name = tokenizer_name, |
| 250 | + trust_remote_code = trust_remote_code, |
| 251 | + **kwargs, |
| 252 | + ) |
| 253 | + pass |
| 254 | +pass |
0 commit comments