Skip to content

Commit a810b5b

Browse files
authored
[BugFix] [ROCm]: Fix rocm_aiter_rms_norm and handle an additional input case (inputs with more than two dimensions) (#17857)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
1 parent 009b3d5 commit a810b5b

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

tests/models/language/generation/test_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
"Qwen/Qwen-7B-Chat",
2929
"Qwen/Qwen2.5-0.5B-Instruct",
3030
"TitanML/tiny-mixtral",
31+
"Qwen/Qwen3-8B",
3132
]
3233

3334

@@ -78,6 +79,9 @@
7879
"Qwen/Qwen2.5-0.5B-Instruct", # qwen2
7980
marks=[pytest.mark.core_model],
8081
),
82+
pytest.param(
83+
"Qwen/Qwen3-8B", # qwen (text-only)
84+
),
8185
pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm
8286
pytest.param("bigcode/starcoder2-3b"), # starcoder2
8387
pytest.param(

vllm/model_executor/layers/layernorm.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
4646
variance_epsilon: float) -> torch.Tensor:
4747

4848
import aiter as rocm_aiter
49+
if x.dim() > 2:
50+
x_original_shape = x.shape
51+
x = x.reshape(-1, x_original_shape[-1])
52+
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
53+
return x.reshape(x_original_shape)
54+
4955
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
5056

5157

@@ -55,16 +61,17 @@ def rocm_aiter_fused_add_rms_norm(
5561

5662
import aiter as rocm_aiter
5763

58-
# Assuming the correct signature for rmsnorm2d_fwd_with_add
64+
residual_out = torch.empty_like(residual)
65+
output = torch.empty_like(x)
5966
rocm_aiter.rmsnorm2d_fwd_with_add(
60-
x, # output
67+
output, # output
6168
x, # input
6269
residual, # residual input
63-
residual, # residual output
70+
residual_out, # residual output
6471
weight,
6572
variance_epsilon,
6673
)
67-
return x, residual
74+
return output, residual_out
6875

6976

7077
def dispatch_cuda_rmsnorm_func(add_residual: bool):

0 commit comments

Comments
 (0)