
Commit ca5cc26

Eviann (nevian) authored and Yuqi Zhang committed

[Kernel] Use fused rmsnorm for some models like qwen3 series (vllm-project#17735)

Signed-off-by: evian <eviantai@u.nus.edu>
Co-authored-by: evian <eviantai@u.nus.edu>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
1 parent 3867895 commit ca5cc26
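
The change below replaces explicit RMSNorm.forward_native calls with direct module calls, so vLLM's custom-op dispatch can route them to the fused rms_norm CUDA kernel on GPU. A minimal sketch of the difference between the two paths, assuming a CUDA device and vLLM's RMSNorm layer (shapes and tolerances here are illustrative, not part of the commit):

import torch
from vllm.model_executor.layers.layernorm import RMSNorm

hidden_size = 128  # illustrative size only
norm = RMSNorm(hidden_size).to("cuda", dtype=torch.float16)
x = torch.randn(4, hidden_size, dtype=torch.float16, device="cuda")

y_native = norm.forward_native(x)  # pure-PyTorch reference implementation
y_fused = norm(x)                  # dispatches to the fused CUDA kernel on GPU

# The two paths should agree within fp16 tolerance.
torch.testing.assert_close(y_fused, y_native, atol=1e-2, rtol=1e-2)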

File tree

7 files changed: +19 −15 lines

csrc/layernorm_kernels.cu

Lines changed: 4 additions & 0 deletions

@@ -140,6 +140,10 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
               torch::Tensor& input,  // [..., hidden_size]
               torch::Tensor& weight, // [hidden_size]
               double epsilon) {
+  TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(weight.is_contiguous());
+
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;

vllm/_custom_ops.py

Lines changed: 3 additions & 1 deletion

@@ -186,7 +186,9 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
 # layer norm ops
 def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
              epsilon: float) -> None:
-    torch.ops._C.rms_norm(out, input, weight, epsilon)
+    # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
+    input_contiguous = input.contiguous()
+    torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)


 def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
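
Since the kernel now asserts contiguity, the Python wrapper calls input.contiguous() before dispatching. A hedged usage sketch of the wrapper as defined in this file (tensor shapes and the epsilon value are assumptions for illustration):

import torch
from vllm import _custom_ops as ops

hidden_size = 64
x = torch.randn(8, hidden_size, dtype=torch.float16, device="cuda")
weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)

# The wrapper makes `input` contiguous internally, so a strided view
# (e.g. a sliced tensor) no longer trips the new TORCH_CHECKs.
ops.rms_norm(out, x, weight, 1e-6)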

vllm/model_executor/models/intern_vit.py

Lines changed: 4 additions & 6 deletions

@@ -190,8 +190,8 @@ def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
             k = tensor_model_parallel_all_gather(k.contiguous())
-        q = self.q_norm.forward_native(q)
-        k = self.k_norm.forward_native(k)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
         if self.tp_size > 1:
             splitter = partial(split_tensor_along_last_dim,
                                num_partitions=self.tp_size)
@@ -264,10 +264,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         if self.qk_normalization:
             B_, N_, H_, D_ = q.shape
-            q = self.q_norm.forward_native(q.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-            k = self.k_norm.forward_native(k.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
+            q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
+            k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

vllm/model_executor/models/molmo.py

Lines changed: 2 additions & 2 deletions

@@ -438,8 +438,8 @@ def _apply_qk_norm(self, q: torch.Tensor,
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
             k = tensor_model_parallel_all_gather(k.contiguous())
-        q = self.q_norm.forward_native(q)
-        k = self.k_norm.forward_native(k)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
         if self.tp_size > 1:
             splitter = partial(split_tensor_along_last_dim,
                                num_partitions=self.tp_size)

vllm/model_executor/models/olmo2.py

Lines changed: 2 additions & 2 deletions

@@ -139,8 +139,8 @@ def _apply_qk_norm(self, q: torch.Tensor,
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
             k = tensor_model_parallel_all_gather(k.contiguous())
-        q = self.q_norm.forward_native(q)
-        k = self.k_norm.forward_native(k)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
         if self.tp_size > 1:
             splitter = partial(split_tensor_along_last_dim,
                                num_partitions=self.tp_size)

vllm/model_executor/models/qwen3.py

Lines changed: 2 additions & 2 deletions

@@ -133,11 +133,11 @@ def forward(
         # Add qk-norm
         q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                            self.head_dim)
-        q_by_head = self.q_norm.forward_native(q_by_head)
+        q_by_head = self.q_norm(q_by_head)
         q = q_by_head.view(q.shape)
         k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                            self.head_dim)
-        k_by_head = self.k_norm.forward_native(k_by_head)
+        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
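
Because the kernel normalizes over the last dimension and flattens everything else into tokens (num_tokens = input.numel() / hidden_size in the first hunk above), the per-head views used here normalize each head independently. A small sketch of that reshape pattern with hypothetical sizes, assuming vLLM's RMSNorm:

import torch
from vllm.model_executor.layers.layernorm import RMSNorm

num_tokens, num_heads, head_dim = 4, 8, 128  # hypothetical sizes
q = torch.randn(num_tokens, num_heads * head_dim,
                dtype=torch.float16, device="cuda")
q_norm = RMSNorm(head_dim).to("cuda", dtype=torch.float16)

# Same pattern as the forward() above: view into heads, normalize over
# head_dim, then view back to the flat [num_tokens, num_heads * head_dim] shape.
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
q_by_head = q_norm(q_by_head)
q = q_by_head.view(q.shape)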

vllm/model_executor/models/qwen3_moe.py

Lines changed: 2 additions & 2 deletions

@@ -225,12 +225,12 @@ def forward(
         # Add qk-norm
         q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                            self.head_dim)
-        q_by_head = self.q_norm.forward_native(q_by_head)
+        q_by_head = self.q_norm(q_by_head)
         q = q_by_head.view(q.shape)

         k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                            self.head_dim)
-        k_by_head = self.k_norm.forward_native(k_by_head)
+        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
