|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +import argparse |
| 4 | +from typing import Any, TypedDict |
| 5 | + |
| 6 | +import ray |
| 7 | +import torch |
| 8 | +from transformers import AutoConfig |
| 9 | + |
| 10 | +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( |
| 11 | + _moe_permute, _moe_unpermute_and_reduce) |
| 12 | +from vllm.model_executor.layers.fused_moe.fused_moe import * |
| 13 | +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import * |
| 14 | +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize |
| 15 | +from vllm.platforms import current_platform |
| 16 | +from vllm.utils import FlexibleArgumentParser |
| 17 | + |
| 18 | +FP8_DTYPE = current_platform.fp8_dtype() |
| 19 | + |
| 20 | + |
| 21 | +class BenchmarkConfig(TypedDict): |
| 22 | + BLOCK_SIZE_M: int |
| 23 | + BLOCK_SIZE_N: int |
| 24 | + BLOCK_SIZE_K: int |
| 25 | + GROUP_SIZE_M: int |
| 26 | + num_warps: int |
| 27 | + num_stages: int |
| 28 | + |
| 29 | + |
| 30 | +def benchmark_permute(num_tokens: int, |
| 31 | + num_experts: int, |
| 32 | + hidden_size: int, |
| 33 | + topk: int, |
| 34 | + dtype: torch.dtype, |
| 35 | + use_fp8_w8a8: bool, |
| 36 | + use_int8_w8a16: bool, |
| 37 | + num_iters: int = 100, |
| 38 | + use_customized_permute: bool = False) -> float: |
| 39 | + # init_dtype = torch.float16 if use_fp8_w8a8 else dtype |
| 40 | + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) |
| 41 | + # output_hidden_states = torch.empty_like(hidden_states) |
| 42 | + if use_fp8_w8a8: |
| 43 | + align_block_size = 128 # deepgemm needs 128 m aligned block |
| 44 | + qhidden_states, scale = _fp8_quantize(hidden_states, None, None) |
| 45 | + else: |
| 46 | + align_block_size = None |
| 47 | + qhidden_states = hidden_states |
| 48 | + |
| 49 | + gating_output = torch.randn(num_iters, |
| 50 | + num_tokens, |
| 51 | + num_experts, |
| 52 | + dtype=torch.float32) |
| 53 | + |
| 54 | + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) |
| 55 | + topk_weights, topk_ids, token_expert_indices = fused_topk( |
| 56 | + qhidden_states, input_gating, topk, False) |
| 57 | + |
| 58 | + def prepare(i: int): |
| 59 | + input_gating.copy_(gating_output[i]) |
| 60 | + |
| 61 | + def run(): |
| 62 | + if use_customized_permute: |
| 63 | + (permuted_hidden_states, first_token_off, inv_perm_idx, |
| 64 | + m_indices) = moe_permute( |
| 65 | + qhidden_states, |
| 66 | + topk_weights=topk_weights, |
| 67 | + topk_ids=topk_ids, |
| 68 | + token_expert_indices=token_expert_indices, |
| 69 | + topk=topk, |
| 70 | + n_expert=num_experts, |
| 71 | + n_local_expert=num_experts, |
| 72 | + expert_map=None, |
| 73 | + align_block_size=align_block_size, |
| 74 | + ) |
| 75 | + else: |
| 76 | + (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids, |
| 77 | + inv_perm) = _moe_permute(qhidden_states, None, topk_ids, |
| 78 | + num_experts, None, align_block_size) |
| 79 | + |
| 80 | + # JIT compilation & warmup |
| 81 | + run() |
| 82 | + torch.cuda.synchronize() |
| 83 | + |
| 84 | + # Capture 10 invocations with CUDA graph |
| 85 | + graph = torch.cuda.CUDAGraph() |
| 86 | + with torch.cuda.graph(graph): |
| 87 | + for _ in range(10): |
| 88 | + run() |
| 89 | + torch.cuda.synchronize() |
| 90 | + |
| 91 | + # Warmup |
| 92 | + for _ in range(5): |
| 93 | + graph.replay() |
| 94 | + torch.cuda.synchronize() |
| 95 | + |
| 96 | + start_event = torch.cuda.Event(enable_timing=True) |
| 97 | + end_event = torch.cuda.Event(enable_timing=True) |
| 98 | + |
| 99 | + latencies: list[float] = [] |
| 100 | + for i in range(num_iters): |
| 101 | + prepare(i) |
| 102 | + torch.cuda.synchronize() |
| 103 | + |
| 104 | + start_event.record() |
| 105 | + graph.replay() |
| 106 | + end_event.record() |
| 107 | + end_event.synchronize() |
| 108 | + latencies.append(start_event.elapsed_time(end_event)) |
| 109 | + avg = sum(latencies) / (num_iters * 10) * 1000 # us |
| 110 | + graph.reset() |
| 111 | + return avg |
| 112 | + |
| 113 | + |
| 114 | +def benchmark_unpermute(num_tokens: int, |
| 115 | + num_experts: int, |
| 116 | + hidden_size: int, |
| 117 | + topk: int, |
| 118 | + dtype: torch.dtype, |
| 119 | + use_fp8_w8a8: bool, |
| 120 | + use_int8_w8a16: bool, |
| 121 | + num_iters: int = 100, |
| 122 | + use_customized_permute: bool = False) -> float: |
| 123 | + # init_dtype = torch.float16 if use_fp8_w8a8 else dtype |
| 124 | + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) |
| 125 | + output_hidden_states = torch.empty_like(hidden_states) |
| 126 | + if use_fp8_w8a8: |
| 127 | + align_block_size = 128 # deepgemm needs 128 m aligned block |
| 128 | + qhidden_states, scale = _fp8_quantize(hidden_states, None, None) |
| 129 | + else: |
| 130 | + align_block_size = None |
| 131 | + qhidden_states = hidden_states |
| 132 | + |
| 133 | + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) |
| 134 | + |
| 135 | + topk_weights, topk_ids, token_expert_indices = fused_topk( |
| 136 | + qhidden_states, input_gating, topk, False) |
| 137 | + |
| 138 | + def prepare(): |
| 139 | + if use_customized_permute: |
| 140 | + (permuted_hidden_states, first_token_off, inv_perm_idx, |
| 141 | + m_indices) = moe_permute( |
| 142 | + qhidden_states, |
| 143 | + topk_weights=topk_weights, |
| 144 | + topk_ids=topk_ids, |
| 145 | + token_expert_indices=token_expert_indices, |
| 146 | + topk=topk, |
| 147 | + n_expert=num_experts, |
| 148 | + n_local_expert=num_experts, |
| 149 | + expert_map=None, |
| 150 | + align_block_size=align_block_size, |
| 151 | + ) |
| 152 | + # convert to fp16/bf16 as gemm output |
| 153 | + return (permuted_hidden_states.to(dtype), first_token_off, |
| 154 | + inv_perm_idx, m_indices) |
| 155 | + else: |
| 156 | + (permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids, |
| 157 | + inv_perm) = _moe_permute(qhidden_states, None, topk_ids, |
| 158 | + num_experts, None, align_block_size) |
| 159 | + # convert to fp16/bf16 as gemm output |
| 160 | + return (permuted_qhidden_states.to(dtype), a1q_scale, |
| 161 | + sorted_token_ids, expert_ids, inv_perm) |
| 162 | + |
| 163 | + def run(input: tuple): |
| 164 | + if use_customized_permute: |
| 165 | + (permuted_hidden_states, first_token_off, inv_perm_idx, |
| 166 | + m_indices) = input |
| 167 | + moe_unpermute(permuted_hidden_states, topk_weights, topk_ids, |
| 168 | + inv_perm_idx, first_token_off, topk, num_experts, |
| 169 | + num_experts) |
| 170 | + else: |
| 171 | + (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids, |
| 172 | + inv_perm) = input |
| 173 | + _moe_unpermute_and_reduce(output_hidden_states, |
| 174 | + permuted_hidden_states, inv_perm, |
| 175 | + topk_weights) |
| 176 | + |
| 177 | + # JIT compilation & warmup |
| 178 | + input = prepare() |
| 179 | + run(input) |
| 180 | + torch.cuda.synchronize() |
| 181 | + |
| 182 | + # Capture 10 invocations with CUDA graph |
| 183 | + graph = torch.cuda.CUDAGraph() |
| 184 | + with torch.cuda.graph(graph): |
| 185 | + for _ in range(10): |
| 186 | + run(input) |
| 187 | + torch.cuda.synchronize() |
| 188 | + |
| 189 | + # Warmup |
| 190 | + for _ in range(5): |
| 191 | + graph.replay() |
| 192 | + torch.cuda.synchronize() |
| 193 | + |
| 194 | + start_event = torch.cuda.Event(enable_timing=True) |
| 195 | + end_event = torch.cuda.Event(enable_timing=True) |
| 196 | + |
| 197 | + latencies: list[float] = [] |
| 198 | + for i in range(num_iters): |
| 199 | + torch.cuda.synchronize() |
| 200 | + start_event.record() |
| 201 | + graph.replay() |
| 202 | + end_event.record() |
| 203 | + end_event.synchronize() |
| 204 | + latencies.append(start_event.elapsed_time(end_event)) |
| 205 | + avg = sum(latencies) / (num_iters * 10) * 1000 # us |
| 206 | + graph.reset() |
| 207 | + return avg |
| 208 | + |
| 209 | + |
| 210 | +@ray.remote(num_gpus=1) |
| 211 | +class BenchmarkWorker: |
| 212 | + |
| 213 | + def __init__(self, seed: int) -> None: |
| 214 | + torch.set_default_device("cuda") |
| 215 | + current_platform.seed_everything(seed) |
| 216 | + self.seed = seed |
| 217 | + # Get the device ID to allocate tensors and kernels |
| 218 | + # on the respective GPU. This is required for Ray to work |
| 219 | + # correctly with multi-GPU tuning on the ROCm platform. |
| 220 | + self.device_id = int(ray.get_gpu_ids()[0]) |
| 221 | + |
| 222 | + def benchmark( |
| 223 | + self, |
| 224 | + num_tokens: int, |
| 225 | + num_experts: int, |
| 226 | + hidden_size: int, |
| 227 | + topk: int, |
| 228 | + dtype: torch.dtype, |
| 229 | + use_fp8_w8a8: bool, |
| 230 | + use_int8_w8a16: bool, |
| 231 | + use_customized_permute: bool = False, |
| 232 | + ) -> tuple[dict[str, int], float]: |
| 233 | + current_platform.seed_everything(self.seed) |
| 234 | + |
| 235 | + permute_time = benchmark_permute( |
| 236 | + num_tokens, |
| 237 | + num_experts, |
| 238 | + hidden_size, |
| 239 | + topk, |
| 240 | + dtype, |
| 241 | + use_fp8_w8a8, |
| 242 | + use_int8_w8a16, |
| 243 | + num_iters=100, |
| 244 | + use_customized_permute=use_customized_permute) |
| 245 | + unpermute_time = benchmark_unpermute( |
| 246 | + num_tokens, |
| 247 | + num_experts, |
| 248 | + hidden_size, |
| 249 | + topk, |
| 250 | + dtype, |
| 251 | + use_fp8_w8a8, |
| 252 | + use_int8_w8a16, |
| 253 | + num_iters=100, |
| 254 | + use_customized_permute=use_customized_permute) |
| 255 | + return permute_time, unpermute_time |
| 256 | + |
| 257 | + |
| 258 | +def get_weight_block_size_safety(config, default_value=None): |
| 259 | + |
| 260 | + quantization_config = getattr(config, 'quantization_config', {}) |
| 261 | + if isinstance(quantization_config, dict): |
| 262 | + return quantization_config.get('weight_block_size', default_value) |
| 263 | + return default_value |
| 264 | + |
| 265 | + |
| 266 | +def main(args: argparse.Namespace): |
| 267 | + print(args) |
| 268 | + |
| 269 | + config = AutoConfig.from_pretrained( |
| 270 | + args.model, trust_remote_code=args.trust_remote_code) |
| 271 | + if config.architectures[0] == "DbrxForCausalLM": |
| 272 | + E = config.ffn_config.moe_num_experts |
| 273 | + topk = config.ffn_config.moe_top_k |
| 274 | + elif config.architectures[0] == "JambaForCausalLM": |
| 275 | + E = config.num_experts |
| 276 | + topk = config.num_experts_per_tok |
| 277 | + elif (config.architectures[0] == "DeepseekV3ForCausalLM" |
| 278 | + or config.architectures[0] == "DeepseekV2ForCausalLM"): |
| 279 | + E = config.n_routed_experts |
| 280 | + topk = config.num_experts_per_tok |
| 281 | + elif config.architectures[0] in [ |
| 282 | + "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM" |
| 283 | + ]: |
| 284 | + E = config.num_experts |
| 285 | + topk = config.num_experts_per_tok |
| 286 | + |
| 287 | + else: |
| 288 | + # Support for llama4 |
| 289 | + config = config.get_text_config() |
| 290 | + # Default: Mixtral. |
| 291 | + E = config.num_local_experts |
| 292 | + topk = config.num_experts_per_tok |
| 293 | + |
| 294 | + hidden_size = config.hidden_size |
| 295 | + dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype |
| 296 | + use_fp8_w8a8 = args.dtype == "fp8_w8a8" |
| 297 | + use_int8_w8a16 = args.dtype == "int8_w8a16" |
| 298 | + use_customized_permute = args.use_customized_permute |
| 299 | + |
| 300 | + if args.batch_size is None: |
| 301 | + batch_sizes = [ |
| 302 | + 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, |
| 303 | + 2048, 3072, 4096 |
| 304 | + ] |
| 305 | + else: |
| 306 | + batch_sizes = [args.batch_size] |
| 307 | + |
| 308 | + ray.init() |
| 309 | + num_gpus = int(ray.available_resources()["GPU"]) |
| 310 | + workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] |
| 311 | + |
| 312 | + def _distribute(method: str, inputs: list[Any]) -> list[Any]: |
| 313 | + outputs = [] |
| 314 | + worker_idx = 0 |
| 315 | + for input_args in inputs: |
| 316 | + worker = workers[worker_idx] |
| 317 | + worker_method = getattr(worker, method) |
| 318 | + output = worker_method.remote(*input_args) |
| 319 | + outputs.append(output) |
| 320 | + worker_idx = (worker_idx + 1) % num_gpus |
| 321 | + return ray.get(outputs) |
| 322 | + |
| 323 | + outputs = _distribute( |
| 324 | + "benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8, |
| 325 | + use_int8_w8a16, use_customized_permute) |
| 326 | + for batch_size in batch_sizes]) |
| 327 | + |
| 328 | + for batch_size, (permute, unpermute) in zip(batch_sizes, outputs): |
| 329 | + print(f"Batch size: {batch_size}") |
| 330 | + print(f"Permute time: {permute:.2f} us") |
| 331 | + print(f"Unpermute time: {unpermute:.2f} us") |
| 332 | + |
| 333 | + |
| 334 | +if __name__ == "__main__": |
| 335 | + parser = FlexibleArgumentParser() |
| 336 | + parser.add_argument("--model", |
| 337 | + type=str, |
| 338 | + default="mistralai/Mixtral-8x7B-Instruct-v0.1") |
| 339 | + parser.add_argument("--dtype", |
| 340 | + type=str, |
| 341 | + choices=["auto", "fp8_w8a8", "int8_w8a16"], |
| 342 | + default="auto") |
| 343 | + parser.add_argument("--use-customized-permute", action="store_true") |
| 344 | + parser.add_argument("--seed", type=int, default=0) |
| 345 | + parser.add_argument("--batch-size", type=int, required=False) |
| 346 | + parser.add_argument("--trust-remote-code", action="store_true") |
| 347 | + args = parser.parse_args() |
| 348 | + |
| 349 | + main(args) |
0 commit comments