
Commit f9d626c

Merge pull request #2211 from Datta0/qwen3_support
[WIP] Initial support for Qwen3. Will update when the model is released
2 parents 7a8f99e + fa11441 commit f9d626c

File tree

5 files changed: +502 -9 lines changed

unsloth/models/__init__.py

Lines changed: 10 additions & 8 deletions
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .llama import FastLlamaModel
-from .loader import FastLanguageModel, FastVisionModel, FastTextModel, FastModel
-from .mistral import FastMistralModel
-from .qwen2 import FastQwen2Model
-from .granite import FastGraniteModel
-from .dpo import PatchDPOTrainer, PatchKTOTrainer
-from ._utils import is_bfloat16_supported, __version__
-from .rl import PatchFastRL, vLLMSamplingParams
+from .llama import FastLlamaModel
+from .loader import FastLanguageModel, FastVisionModel, FastTextModel, FastModel
+from .mistral import FastMistralModel
+from .qwen2 import FastQwen2Model
+from .qwen3 import FastQwen3Model
+from .qwen3_moe import FastQwen3MoeModel
+from .granite import FastGraniteModel
+from .dpo import PatchDPOTrainer, PatchKTOTrainer
+from ._utils import is_bfloat16_supported, __version__
+from .rl import PatchFastRL, vLLMSamplingParams

unsloth/models/_utils.py

Lines changed: 1 addition & 1 deletion
@@ -243,7 +243,7 @@ def patch_mistral_nemo_config(config):
 
 from transformers import __version__ as transformers_version
 from transformers import PretrainedConfig
-model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2", "granite"]
+model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2", "granite", "qwen3", "qwen3_moe"]
 
 for model_name in model_architectures:
     config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
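For context, each entry in model_architectures is expanded by the f-string above into a transformers configuration-module path. A minimal sketch of what the two new entries resolve to (illustrative only, using plain importlib and assuming a transformers build that already ships the Qwen3 model code):

import importlib

# Illustrative only: show which configuration modules the new entries map to.
for model_name in ("qwen3", "qwen3_moe"):
    config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
    config_module = importlib.import_module(config_filepath)
    print(config_filepath, "->", config_module.__name__)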

unsloth/models/loader.py

Lines changed: 13 additions & 0 deletions
@@ -23,6 +23,8 @@
 from .llama import FastLlamaModel, logger
 from .mistral import FastMistralModel
 from .qwen2 import FastQwen2Model
+from .qwen3 import FastQwen3Model
+from .qwen3_moe import FastQwen3MoeModel
 from .cohere import FastCohereModel
 from transformers import AutoConfig
 from transformers import __version__ as transformers_version
@@ -51,6 +53,8 @@
 SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.2")
 SUPPORTS_LLAMA32 = transformers_version >  Version("4.45.0")
 SUPPORTS_GRANITE = transformers_version >= Version("4.46.0")
+SUPPORTS_QWEN3     = transformers_version >= Version("4.50.3")
+SUPPORTS_QWEN3_MOE = transformers_version >= Version("4.50.3")
 if SUPPORTS_GEMMA:
     from .gemma import FastGemmaModel
 if SUPPORTS_GEMMA2:
@@ -298,6 +302,15 @@ def from_pretrained(
             dispatch_model = FastGemma2Model
         elif model_type == "qwen2":
             dispatch_model = FastQwen2Model
+        elif model_type == "qwen3" or model_type == "qwen3_moe":
+            if not SUPPORTS_QWEN3 or not SUPPORTS_QWEN3_MOE:
+                raise ImportError(
+                    f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3.\n"\
+                    f"The minimum required version is 4.50.3.\n"\
+                    f'Try `pip install --upgrade "transformers>=4.50.3"`\n'\
+                    f"to obtain the latest transformers build, then restart this session."\
+                )
+            dispatch_model = FastQwen3Model if model_type == "qwen3" else FastQwen3MoeModel
         # Temporary disable optimized Cohere until errors match
         # elif model_type == "cohere":
         #     dispatch_model = FastCohereModel
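With this dispatch in place, Qwen3 checkpoints load through the usual entry point. A minimal usage sketch (the Qwen/Qwen3-7B name is the placeholder used in qwen3.py below and may differ from the released checkpoints):

from unsloth import FastLanguageModel

# A config.json with model_type "qwen3" now dispatches to FastQwen3Model,
# and "qwen3_moe" to FastQwen3MoeModel; older transformers builds raise the
# ImportError added above.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "Qwen/Qwen3-7B",  # placeholder name from this PR
    max_seq_length = 4096,
    load_in_4bit   = True,
)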

unsloth/models/qwen3.py

Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .llama import *
import os
from ._utils import __version__
from .llama import (
    LlamaRotaryEmbedding,
    LlamaLinearScalingRotaryEmbedding,
)
try:
    from transformers.models.qwen3.modeling_qwen3 import (
        Qwen3Attention,
        Qwen3DecoderLayer,
        Qwen3Model,
        Qwen3ForCausalLM,
    )
except:
    from packaging.version import Version
    transformers_version = Version(transformers_version)
    if not transformers_version >= Version("4.50.3"): # TODO: Update when transformers is updated
        raise ImportError(
            f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3 and Qwen3Moe.\n"\
            f"The minimum required version is 4.50.3.\n"\
            f'Try `pip install --upgrade "transformers>=4.50.3"`\n'\
            f"to obtain the latest transformers build, then restart this session."\
        )
    pass

# For Pytorch 2.1.1
try:
    from transformers.models.qwen3.modeling_qwen3 import (
        Qwen3SdpaAttention,
        Qwen3FlashAttention2,
    )
except:
    Qwen3SdpaAttention   = Qwen3Attention
    Qwen3FlashAttention2 = Qwen3Attention
pass
from unsloth_zoo.utils import Version, _get_dtype


def Qwen3Attention_fast_forward(
    self,
    hidden_states:       torch.Tensor,
    causal_mask:         Optional[BlockDiagonalCausalMask] = None,
    attention_mask:      Optional[torch.Tensor] = None,
    position_ids:        Optional[torch.LongTensor] = None,
    past_key_value:      Optional[Tuple[torch.Tensor]] = None,
    output_attentions:   bool = False,
    use_cache:           bool = False,
    padding_mask:        Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    *args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

    # Clear inference
    if hasattr(self, "paged_attention"):
        del self.paged_attention_K
        del self.paged_attention_V
        del self.paged_attention
        del self.temp_QA
        del self.temp_KV
        del self.RH_Q
        del self.attention
    pass

    bsz, q_len, _ = hidden_states.size()

    n_heads    = self.config.num_attention_heads
    n_groups   = self.num_key_value_groups
    n_kv_heads = self.config.num_key_value_heads
    head_dim   = self.head_dim
    assert(n_kv_heads * n_groups == n_heads)

    Q, K, V = self.apply_qkv(self, hidden_states)
    Q = Q.view(bsz, q_len, n_heads,    head_dim).transpose(1, 2)
    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

    # Qwen3 has QK-Norm. This seems to be the only difference from Qwen2.
    Q = fast_layernorm_compiled(self.q_norm, Q)
    K = fast_layernorm_compiled(self.k_norm, K)

    kv_seq_len = K.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    # Extend RoPE dynamically to fit in VRAM
    self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)

    if position_ids is None:
        cos = self.rotary_emb.cos_cached
        sin = self.rotary_emb.sin_cached
        Q, K = fast_rope_embedding(Q, K, cos, sin)
    else:
        cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
        Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
    pass

    if past_key_value is not None:
        K = torch.cat([past_key_value[0], K], dim = 2)
        V = torch.cat([past_key_value[1], V], dim = 2)
    pass
    past_key_value = (K, V) if use_cache else None

    # Attention module
    if (not HAS_FLASH_ATTENTION and attention_mask is None):
        # Xformers memory efficient attention
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        K_M = V_M = bsz * kv_seq_len
        Q_M = bsz * q_len

        has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)

        # Group query attention
        K = K .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
        V = V .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
        K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
        V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
        if hidden_states.requires_grad:
            K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
            V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)

            if has_swa:
                Q = Q.view(1, Q_M, n_heads, head_dim)
                K = K.view(1, K_M, n_heads, head_dim)
                V = V.view(1, V_M, n_heads, head_dim)
            pass
        else:
            # Xformers does support the forward pass though
            Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)

            if has_swa:
                Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim)
                K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
                V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
            pass
        pass

        A = xformers_attention(Q, K, V, attn_bias = causal_mask)
        A = A.view(bsz, q_len, n_heads, head_dim)

    elif HAS_FLASH_ATTENTION and attention_mask is None:
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        sw = getattr(self.config, "sliding_window", None)
        sw = kv_seq_len if (sw is None or sw == "null") else sw
        window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
        A = flash_attn_func(Q, K, V, causal = True, window_size = window)
    else:
        # Grouped query attention
        # if n_groups != 1:
        K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
        V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
        K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
        V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
        # pass
        # Must be contiguous or else results are False!
        # https://github.com/pytorch/pytorch/issues/112577
        Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
        # Needs (batch_size, n_heads, seq_len, head_dim)
        # is_causal and attention_mask must not be both set!
        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
        # Go back to (batch_size, seq_len, n_heads, head_dim)
        A = A.transpose(1, 2).contiguous()
    pass

    attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
    attn_output = self.apply_o(self, attn_output)
    attn_weights = None
    return attn_output, attn_weights, past_key_value
pass


class FastQwen3Model(FastLlamaModel):

    @staticmethod
    def pre_patch():
        init_name, function = patch_linear_scaling(
            model_name         = "Qwen3",
            rope_module        = LlamaRotaryEmbedding,
            scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
            attention_module   = Qwen3Attention,
        )
        if init_name is not None:
            exec(function, globals())
            Qwen3Attention.__init__ = eval(init_name)
        pass
        Qwen3Attention      .forward = Qwen3Attention_fast_forward
        Qwen3SdpaAttention  .forward = Qwen3Attention_fast_forward
        Qwen3FlashAttention2.forward = Qwen3Attention_fast_forward
        Qwen3DecoderLayer   .forward = LlamaDecoderLayer_fast_forward
        Qwen3Model          .forward = LlamaModel_fast_forward
        Qwen3ForCausalLM    .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
        PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
        fix_prepare_inputs_for_generation(Qwen3ForCausalLM)

        # Solves https://github.com/unslothai/unsloth/issues/168
        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
        # https://github.com/huggingface/transformers/pull/27931
        # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
        import transformers.models.qwen3.modeling_qwen3
        transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding = LlamaRotaryEmbedding
        return
    pass


    @staticmethod
    def from_pretrained( # TODO: Change after release
        model_name        = "Qwen/Qwen3-7B",
        max_seq_length    = 4096,
        dtype             = None,
        load_in_4bit      = True,
        token             = None,
        device_map        = "sequential",
        rope_scaling      = None,
        fix_tokenizer     = True,
        model_patcher     = None,
        tokenizer_name    = None,
        trust_remote_code = False,
        **kwargs,
    ):
        return FastLlamaModel.from_pretrained(
            model_name        = model_name,
            max_seq_length    = max_seq_length,
            dtype             = dtype,
            load_in_4bit      = load_in_4bit,
            token             = token,
            device_map        = device_map,
            rope_scaling      = rope_scaling,
            fix_tokenizer     = fix_tokenizer,
            model_patcher     = FastQwen3Model,
            tokenizer_name    = tokenizer_name,
            trust_remote_code = trust_remote_code,
            **kwargs,
        )
    pass
pass
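The QK-Norm step flagged in the comment inside Qwen3Attention_fast_forward is the main architectural delta from Qwen2: q_norm and k_norm are RMSNorms over head_dim, applied per head to Q and K before RoPE. A rough standalone illustration of that step (not the compiled Unsloth kernel; shapes, weights, and eps are made up for the example):

import torch

def rms_norm(x, weight, eps = 1e-6):
    # Standard RMSNorm over the last dimension (head_dim here).
    variance = x.float().pow(2).mean(-1, keepdim = True)
    return (weight * (x.float() * torch.rsqrt(variance + eps))).to(x.dtype)

bsz, n_heads, q_len, head_dim = 2, 4, 8, 16
Q = torch.randn(bsz, n_heads, q_len, head_dim)
q_norm_weight = torch.ones(head_dim)   # stands in for self.q_norm.weight
Q = rms_norm(Q, q_norm_weight)         # roughly what fast_layernorm_compiled(self.q_norm, Q) applies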
