[WIP] Initial support for Qwen3. Will update when the model is released #2211


Merged: 4 commits, Apr 29, 2025
18 changes: 10 additions & 8 deletions unsloth/models/__init__.py
@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from .llama import FastLlamaModel
-from .loader import FastLanguageModel, FastVisionModel, FastTextModel, FastModel
-from .mistral import FastMistralModel
-from .qwen2 import FastQwen2Model
-from .granite import FastGraniteModel
-from .dpo import PatchDPOTrainer, PatchKTOTrainer
-from ._utils import is_bfloat16_supported, __version__
-from .rl import PatchFastRL, vLLMSamplingParams
+from .llama import FastLlamaModel
+from .loader import FastLanguageModel, FastVisionModel, FastTextModel, FastModel
+from .mistral import FastMistralModel
+from .qwen2 import FastQwen2Model
+from .qwen3 import FastQwen3Model
+from .qwen3_moe import FastQwen3MoeModel
+from .granite import FastGraniteModel
+from .dpo import PatchDPOTrainer, PatchKTOTrainer
+from ._utils import is_bfloat16_supported, __version__
+from .rl import PatchFastRL, vLLMSamplingParams
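
With these exports in place, both new classes are importable from the models package, e.g.:

from unsloth.models import FastQwen3Model, FastQwen3MoeModel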
2 changes: 1 addition & 1 deletion unsloth/models/_utils.py
@@ -243,7 +243,7 @@ def patch_mistral_nemo_config(config):

from transformers import __version__ as transformers_version
from transformers import PretrainedConfig
-model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2", "granite"]
+model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2", "granite", "qwen3", "qwen3_moe"]

for model_name in model_architectures:
    config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
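
For the two new entries, the loop above resolves these module paths (a sketch assuming a transformers build that ships Qwen3):

import importlib

# Equivalent to what the f-string produces for "qwen3" and "qwen3_moe":
importlib.import_module("transformers.models.qwen3.configuration_qwen3")
importlib.import_module("transformers.models.qwen3_moe.configuration_qwen3_moe")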
13 changes: 13 additions & 0 deletions unsloth/models/loader.py
@@ -23,6 +23,8 @@
from .llama import FastLlamaModel, logger
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
+from .qwen3 import FastQwen3Model
+from .qwen3_moe import FastQwen3MoeModel
from .cohere import FastCohereModel
from transformers import AutoConfig
from transformers import __version__ as transformers_version
@@ -51,6 +53,8 @@
SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.2")
SUPPORTS_LLAMA32 = transformers_version > Version("4.45.0")
SUPPORTS_GRANITE = transformers_version >= Version("4.46.0")
+SUPPORTS_QWEN3 = transformers_version >= Version("4.50.3")
+SUPPORTS_QWEN3_MOE = transformers_version >= Version("4.50.3")
if SUPPORTS_GEMMA:
    from .gemma import FastGemmaModel
if SUPPORTS_GEMMA2:
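
An aside on these version gates (a sketch, not part of the diff): packaging's Version compares release components numerically, which plain string comparison would get wrong here:

from packaging.version import Version

print("4.50.3" >= "4.9.0")                    # False: lexicographic, '5' < '9'
print(Version("4.50.3") >= Version("4.9.0"))  # True: 50 > 9 numerically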
@@ -298,6 +302,15 @@ def from_pretrained(
            dispatch_model = FastGemma2Model
        elif model_type == "qwen2":
            dispatch_model = FastQwen2Model
+        elif model_type == "qwen3" or model_type == "qwen3_moe":
+            if not SUPPORTS_QWEN3 or not SUPPORTS_QWEN3_MOE:
+                raise ImportError(
+                    f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3.\n"\
+                    f"The minimum required version is 4.50.3.\n"\
+                    f'Try `pip install --upgrade "transformers>=4.50.3"`\n'\
+                    f"to obtain the latest transformers build, then restart this session."\
+                )
+            dispatch_model = FastQwen3Model if model_type == "qwen3" else FastQwen3MoeModel
        # Temporarily disable optimized Cohere until errors match
        # elif model_type == "cohere":
        #     dispatch_model = FastCohereModel
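
A minimal sketch (illustrative, not part of the diff; the repo id is the placeholder used elsewhere in this PR) of what feeds this dispatch:

from transformers import AutoConfig

# model_type is read from the checkpoint's config.json; "qwen3" and
# "qwen3_moe" now route to FastQwen3Model / FastQwen3MoeModel above.
config = AutoConfig.from_pretrained("Qwen/Qwen3-7B")  # placeholder repo id
print(config.model_type)  # -> "qwen3"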
254 changes: 254 additions & 0 deletions unsloth/models/qwen3.py
@@ -0,0 +1,254 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .llama import *
import os
from ._utils import __version__
from .llama import (
    LlamaRotaryEmbedding,
    LlamaLinearScalingRotaryEmbedding,
)
try:
    from transformers.models.qwen3.modeling_qwen3 import (
        Qwen3Attention,
        Qwen3DecoderLayer,
        Qwen3Model,
        Qwen3ForCausalLM,
    )
except:
    from packaging.version import Version
    transformers_version = Version(transformers_version)
    if not transformers_version >= Version("4.50.3"):  # TODO: Update when transformers is updated
        raise ImportError(
            f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3 and Qwen3Moe.\n"\
            f"The minimum required version is 4.50.3.\n"\
            f'Try `pip install --upgrade "transformers>=4.50.3"`\n'\
            f"to obtain the latest transformers build, then restart this session."\
        )
    pass

# For Pytorch 2.1.1
try:
    from transformers.models.qwen3.modeling_qwen3 import (
        Qwen3SdpaAttention,
        Qwen3FlashAttention2,
    )
except:
    Qwen3SdpaAttention = Qwen3Attention
    Qwen3FlashAttention2 = Qwen3Attention
pass
from unsloth_zoo.utils import Version, _get_dtype


def Qwen3Attention_fast_forward(
    self,
    hidden_states: torch.Tensor,
    causal_mask: Optional[BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    *args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

    # Clear inference
    if hasattr(self, "paged_attention"):
        del self.paged_attention_K
        del self.paged_attention_V
        del self.paged_attention
        del self.temp_QA
        del self.temp_KV
        del self.RH_Q
        del self.attention
    pass

    bsz, q_len, _ = hidden_states.size()

    n_heads    = self.config.num_attention_heads
    n_groups   = self.num_key_value_groups
    n_kv_heads = self.config.num_key_value_heads
    head_dim   = self.head_dim
    assert(n_kv_heads * n_groups == n_heads)

    Q, K, V = self.apply_qkv(self, hidden_states)
    Q = Q.view(bsz, q_len, n_heads,    head_dim).transpose(1, 2)
    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

    # Qwen3 has QK-Norm. This seems to be the only difference from Qwen2.
    Q = fast_layernorm_compiled(self.q_norm, Q)
    K = fast_layernorm_compiled(self.k_norm, K)
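    # q_norm and k_norm are RMSNorm modules of size head_dim, applied per
    # attention head to the projected Q and K before the rotary embedding below.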

    kv_seq_len = K.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    # Extend RoPE dynamically to fit in VRAM
    self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)

    if position_ids is None:
        cos = self.rotary_emb.cos_cached
        sin = self.rotary_emb.sin_cached
        Q, K = fast_rope_embedding(Q, K, cos, sin)
    else:
        cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
        Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
    pass

    if past_key_value is not None:
        K = torch.cat([past_key_value[0], K], dim = 2)
        V = torch.cat([past_key_value[1], V], dim = 2)
    pass
    past_key_value = (K, V) if use_cache else None

    # Attention module
    if (not HAS_FLASH_ATTENTION and attention_mask is None):
        # Xformers memory efficient attention
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        K_M = V_M = bsz * kv_seq_len
        Q_M = bsz * q_len

        has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)

        # Group query attention
        K = K.view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
        V = V.view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
        K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
        V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
        if hidden_states.requires_grad:
            K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
            V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)

            if has_swa:
                Q = Q.view(1, Q_M, n_heads, head_dim)
                K = K.view(1, K_M, n_heads, head_dim)
                V = V.view(1, V_M, n_heads, head_dim)
            pass
        else:
            # Xformers does support the forward pass though
            Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)

            if has_swa:
                Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim)
                K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
                V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
            pass
        pass

        A = xformers_attention(Q, K, V, attn_bias = causal_mask)
        A = A.view(bsz, q_len, n_heads, head_dim)

    elif HAS_FLASH_ATTENTION and attention_mask is None:
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        sw = getattr(self.config, "sliding_window", None)
        sw = kv_seq_len if (sw is None or sw == "null") else sw
        window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
        A = flash_attn_func(Q, K, V, causal = True, window_size = window)
    else:
        # Grouped query attention
        # if n_groups != 1:
        K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
        V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
        K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
        V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
        # pass
        # Must be contiguous or else results are False!
        # https://github.com/pytorch/pytorch/issues/112577
        Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
        # Needs (batch_size, n_heads, seq_len, head_dim)
        # is_causal and attention_mask must not both be set!
        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
        # Go back to (batch_size, seq_len, n_heads, head_dim)
        A = A.transpose(1, 2).contiguous()
    pass

    attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
    attn_output = self.apply_o(self, attn_output)
    attn_weights = None
    return attn_output, attn_weights, past_key_value
pass


class FastQwen3Model(FastLlamaModel):

    @staticmethod
    def pre_patch():
        init_name, function = patch_linear_scaling(
            model_name         = "Qwen3",
            rope_module        = LlamaRotaryEmbedding,
            scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
            attention_module   = Qwen3Attention,
        )
        if init_name is not None:
            exec(function, globals())
            Qwen3Attention.__init__ = eval(init_name)
        pass
        Qwen3Attention      .forward = Qwen3Attention_fast_forward
        Qwen3SdpaAttention  .forward = Qwen3Attention_fast_forward
        Qwen3FlashAttention2.forward = Qwen3Attention_fast_forward
        Qwen3DecoderLayer   .forward = LlamaDecoderLayer_fast_forward
        Qwen3Model          .forward = LlamaModel_fast_forward
        Qwen3ForCausalLM    .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
        PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
        fix_prepare_inputs_for_generation(Qwen3ForCausalLM)

        # Solves https://github.com/unslothai/unsloth/issues/168
        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
        # https://github.com/huggingface/transformers/pull/27931
        # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
        import transformers.models.qwen3.modeling_qwen3
        transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding = LlamaRotaryEmbedding
        return
    pass


    @staticmethod
    def from_pretrained(  # TODO: Change after release
        model_name        = "Qwen/Qwen3-7B",
        max_seq_length    = 4096,
        dtype             = None,
        load_in_4bit      = True,
        token             = None,
        device_map        = "sequential",
        rope_scaling      = None,
        fix_tokenizer     = True,
        model_patcher     = None,
        tokenizer_name    = None,
        trust_remote_code = False,
        **kwargs,
    ):
        return FastLlamaModel.from_pretrained(
            model_name        = model_name,
            max_seq_length    = max_seq_length,
            dtype             = dtype,
            load_in_4bit      = load_in_4bit,
            token             = token,
            device_map        = device_map,
            rope_scaling      = rope_scaling,
            fix_tokenizer     = fix_tokenizer,
            model_patcher     = FastQwen3Model,
            tokenizer_name    = tokenizer_name,
            trust_remote_code = trust_remote_code,
            **kwargs,
        )
    pass
pass
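
For reference, a minimal fine-tuning sketch (illustrative only: the repo id mirrors this PR's placeholder default, and the LoRA hyperparameters are assumptions, not part of the diff):

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "Qwen/Qwen3-7B",  # placeholder default from this PR
    max_seq_length = 4096,
    load_in_4bit   = True,
)

# The usual Unsloth LoRA setup applies unchanged, since FastQwen3Model
# delegates to FastLlamaModel under the hood:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
)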