
Commit b29bacf

committed May 2, 2025
1. Fix the failing call to FusedMoE.select_experts
2. Fix a potential overflow and remove debug cruft, per tlrmchlsmth's review
3. Add a benchmark for performance

Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
1 parent 840cd41 commit b29bacf

File tree

4 files changed: +351, −5 lines

 
Lines changed: 349 additions & 0 deletions
@@ -0,0 +1,349 @@
# SPDX-License-Identifier: Apache-2.0


import argparse
from typing import Any, TypedDict

import ray
import torch
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
    _moe_permute, _moe_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser

FP8_DTYPE = current_platform.fp8_dtype()


class BenchmarkConfig(TypedDict):
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


def benchmark_permute(num_tokens: int,
                      num_experts: int,
                      hidden_size: int,
                      topk: int,
                      dtype: torch.dtype,
                      use_fp8_w8a8: bool,
                      use_int8_w8a16: bool,
                      num_iters: int = 100,
                      use_customized_permute: bool = False) -> float:
    # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    # output_hidden_states = torch.empty_like(hidden_states)
    if use_fp8_w8a8:
        align_block_size = 128  # deepgemm needs 128 m aligned block
        qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
    else:
        align_block_size = None
        qhidden_states = hidden_states

    gating_output = torch.randn(num_iters,
                                num_tokens,
                                num_experts,
                                dtype=torch.float32)

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    topk_weights, topk_ids, token_expert_indices = fused_topk(
        qhidden_states, input_gating, topk, False)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        if use_customized_permute:
            (permuted_hidden_states, first_token_off, inv_perm_idx,
             m_indices) = moe_permute(
                 qhidden_states,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 token_expert_indices=token_expert_indices,
                 topk=topk,
                 n_expert=num_experts,
                 n_local_expert=num_experts,
                 expert_map=None,
                 align_block_size=align_block_size,
             )
        else:
            (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
             inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
                                      num_experts, None, align_block_size)

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg


def benchmark_unpermute(num_tokens: int,
                        num_experts: int,
                        hidden_size: int,
                        topk: int,
                        dtype: torch.dtype,
                        use_fp8_w8a8: bool,
                        use_int8_w8a16: bool,
                        num_iters: int = 100,
                        use_customized_permute: bool = False) -> float:
    # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    output_hidden_states = torch.empty_like(hidden_states)
    if use_fp8_w8a8:
        align_block_size = 128  # deepgemm needs 128 m aligned block
        qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
    else:
        align_block_size = None
        qhidden_states = hidden_states

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    topk_weights, topk_ids, token_expert_indices = fused_topk(
        qhidden_states, input_gating, topk, False)

    def prepare():
        if use_customized_permute:
            (permuted_hidden_states, first_token_off, inv_perm_idx,
             m_indices) = moe_permute(
                 qhidden_states,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 token_expert_indices=token_expert_indices,
                 topk=topk,
                 n_expert=num_experts,
                 n_local_expert=num_experts,
                 expert_map=None,
                 align_block_size=align_block_size,
             )
            # convert to fp16/bf16 as gemm output
            return (permuted_hidden_states.to(dtype), first_token_off,
                    inv_perm_idx, m_indices)
        else:
            (permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids,
             inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
                                      num_experts, None, align_block_size)
            # convert to fp16/bf16 as gemm output
            return (permuted_qhidden_states.to(dtype), a1q_scale,
                    sorted_token_ids, expert_ids, inv_perm)

    def run(input: tuple):
        if use_customized_permute:
            (permuted_hidden_states, first_token_off, inv_perm_idx,
             m_indices) = input
            moe_unpermute(permuted_hidden_states, topk_weights, topk_ids,
                          inv_perm_idx, first_token_off, topk, num_experts,
                          num_experts)
        else:
            (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
             inv_perm) = input
            _moe_unpermute_and_reduce(output_hidden_states,
                                      permuted_hidden_states, inv_perm,
                                      topk_weights)

    # JIT compilation & warmup
    input = prepare()
    run(input)
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run(input)
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        torch.cuda.synchronize()
        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg


@ray.remote(num_gpus=1)
class BenchmarkWorker:

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_customized_permute: bool = False,
    ) -> tuple[dict[str, int], float]:
        current_platform.seed_everything(self.seed)

        permute_time = benchmark_permute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            use_customized_permute=use_customized_permute)
        unpermute_time = benchmark_unpermute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            use_customized_permute=use_customized_permute)
        return permute_time, unpermute_time


def get_weight_block_size_safety(config, default_value=None):

    quantization_config = getattr(config, 'quantization_config', {})
    if isinstance(quantization_config, dict):
        return quantization_config.get('weight_block_size', default_value)
    return default_value


def main(args: argparse.Namespace):
    print(args)

    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code)
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
    elif (config.architectures[0] == "DeepseekV3ForCausalLM"
          or config.architectures[0] == "DeepseekV2ForCausalLM"):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
    elif config.architectures[0] in [
            "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
    ]:
        E = config.num_experts
        topk = config.num_experts_per_tok

    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok

    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    use_customized_permute = args.use_customized_permute

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    outputs = _distribute(
        "benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8,
                       use_int8_w8a16, use_customized_permute)
                      for batch_size in batch_sizes])

    for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
        print(f"Batch size: {batch_size}")
        print(f"Permute time: {permute:.2f} us")
        print(f"Unpermute time: {unpermute:.2f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument("--model",
                        type=str,
                        default="mistralai/Mixtral-8x7B-Instruct-v0.1")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8_w8a8", "int8_w8a16"],
                        default="auto")
    parser.add_argument("--use-customized-permute", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()

    main(args)
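
The harness above uses a common CUDA-graph timing pattern: warm the kernel up eagerly, capture ten invocations in a single torch.cuda.CUDAGraph, replay the graph, and time each replay with CUDA events, dividing by ten to amortize launch overhead. The following is a minimal, self-contained sketch of that pattern in isolation, applied to a stand-in workload (a plain matmul, chosen only for illustration); it assumes a CUDA-capable device and is not part of the committed benchmark.

import torch

def time_with_cuda_graph(fn, num_iters: int = 100, replays_per_graph: int = 10) -> float:
    """Average latency of fn() in microseconds, measured via CUDA-graph replay."""
    # Eager warmup so any lazy compilation happens outside graph capture.
    fn()
    torch.cuda.synchronize()

    # Capture several invocations in one graph to amortize launch overhead.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(replays_per_graph):
            fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    latencies = []
    for _ in range(num_iters):
        torch.cuda.synchronize()
        start.record()
        graph.replay()
        end.record()
        end.synchronize()
        latencies.append(start.elapsed_time(end))  # milliseconds per replay
    graph.reset()
    # ms -> us, divided by the invocations captured per replay.
    return sum(latencies) / (num_iters * replays_per_graph) * 1000

# Stand-in workload (hypothetical; the real benchmark times moe_permute / _moe_permute).
a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
print(f"matmul: {time_with_cuda_graph(lambda: a @ b):.2f} us")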

csrc/moe/moe_permute_unpermute_op.cu

Lines changed: 0 additions & 3 deletions
@@ -113,9 +113,6 @@ void moe_unpermute(
   TORCH_CHECK(
       permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
       "topk_ids dtype must be same as src_row_id2dst_row_id_map");
-  // TORCH_CHECK(permuted_hidden_states.size(0) == hidden_states.size(0) * topk,
-  //             "permuted_hidden_states must be [n_token * topk, n_hidden],"
-  //             "hidden_states must be [n_token, n_hidden]");
   auto n_token = hidden_states.size(0);
   auto n_hidden = hidden_states.size(1);
   auto stream = at::cuda::getCurrentCUDAStream().stream();

csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
                        int num_experts, int num_experts_per_node, int k,
                        CubKeyValueSorter& sorter, void* sorter_ws,
                        cudaStream_t stream) {
-  int64_t const expanded_num_rows = k * num_rows;
+  int64_t const expanded_num_rows = static_cast<int64_t>(k) * num_rows;
   // We need to use the full num_experts because that is the sentinel value used
   // by topk for disabled experts
   sorter.updateNumExperts(num_experts);
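
The cast above matters because, in the original line, both k and num_rows are 32-bit ints, so the product k * num_rows is computed in 32-bit arithmetic and can wrap around before it is assigned to the int64_t on the left-hand side. A small illustration of the same wraparound using torch int32 tensors (the values are made up; only the widen-before-multiply pattern mirrors the fix):

import torch

k = 8
num_rows = torch.tensor(300_000_000, dtype=torch.int32)  # hypothetical token count

# 32-bit product: 8 * 300_000_000 = 2.4e9 > 2**31 - 1, so the result wraps negative.
print((num_rows * k).item())
# Widening one operand first, as static_cast<int64_t>(k) * num_rows does in the fix:
print((num_rows.to(torch.int64) * k).item())  # 2400000000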

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 1 addition & 1 deletion
@@ -191,7 +191,7 @@ def forward_cuda(
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
-        topk_weights, topk_ids, token_expert_indices = FusedMoE.select_experts(
+        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
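
This addresses item 1 of the commit message: on this code path FusedMoE.select_experts returns only the top-k weights and ids, so the old three-target unpacking raised a ValueError before the MoE kernel was ever reached. A minimal, self-contained illustration with a hypothetical stub standing in for the real method:

def select_experts_stub():
    # Stand-in for FusedMoE.select_experts here: it returns two values.
    return "topk_weights", "topk_ids"

# Old call site: three targets for a two-element result.
try:
    topk_weights, topk_ids, token_expert_indices = select_experts_stub()
except ValueError as err:
    print(err)  # not enough values to unpack (expected 3, got 2)

# Fixed call site, matching the diff above.
topk_weights, topk_ids = select_experts_stub()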
