
Commit 9312847

liz-badada authored and committed
[Feature] Support DeepEP Low Latency (sgl-project#4767)
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: ch-wan <cwan39@gatech.edu>
1 parent 6381e17 commit 9312847

File tree: 8 files changed, +438 -228 lines

docs/backend/server_arguments.md (+1)

@@ -91,6 +91,7 @@ Please consult the documentation below to learn more about the parameters you ma
 * `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models.
 * `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`, for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`.
 * `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP.
+* `deepep_mode`: Selects the mode used when DeepEP MoE is enabled; one of `normal`, `low_latency`, or `auto`. The default is `auto`, which uses `low_latency` for decode batches and `normal` for prefill batches.

 ## Memory and scheduling
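As a usage illustration only: the flag spellings below are assumptions inferred from the argument names above (check `python -m sglang.launch_server --help` for the authoritative list), and the model path and parallelism sizes are placeholders, not values from this commit.

```python
# Hypothetical launch sketch combining the DeepEP-related arguments documented above.
import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "deepseek-ai/DeepSeek-V3",  # placeholder model
        "--tp-size", "8",                           # placeholder parallelism
        "--enable-deepep-moe",
        "--deepep-mode", "auto",                    # normal | low_latency | auto
    ],
    check=True,
)
```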

python/sglang/srt/layers/moe/ep_moe/kernels.py (+142)

@@ -244,6 +244,148 @@ def silu_and_mul_triton_kernel(
     tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)


+# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
+@triton.jit
+def _silu_and_mul_post_quant_kernel(
+    input_ptr,
+    stride_input_0,
+    stride_input_1,
+    stride_input_2,
+    output_ptr,
+    stride_output_0,
+    stride_output_1,
+    stride_output_2,
+    output_scale_ptr,
+    stride_output_scale_0,
+    stride_output_scale_1,
+    stride_output_scale_2,
+    masked_m_ptr,
+    size_n,
+    fp8_max,
+    fp8_min,
+    BLOCK_N: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    expert_id = tl.program_id(2)
+    token_id = tl.program_id(1)
+    hidden_dim_block_index = tl.program_id(0)
+
+    block_num_per_expert = tl.num_programs(1)
+
+    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
+
+    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
+    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
+    stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
+    stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
+
+    offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
+    input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
+    output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
+    output_scale_offs = (
+        output_scale_ptr
+        + expert_id * stride_output_scale_0
+        + hidden_dim_block_index * stride_output_scale_2
+    )
+
+    for token_index in tl.range(
+        token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
+    ):
+        gate = tl.load(
+            input_ptr_offs + token_index * stride_input_1,
+            mask=offs_in_d < size_n,
+            other=0.0,
+        ).to(tl.float32)
+        up = tl.load(
+            input_ptr_offs + token_index * stride_input_1 + size_n,
+            mask=offs_in_d < size_n,
+            other=0.0,
+        )
+        gate = gate / (1 + tl.exp(-gate))
+        gate = gate.to(input_ptr.dtype.element_ty)
+        gate_up = up * gate
+        _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
+        output_s = _absmax / fp8_max
+        output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
+            output_ptr.dtype.element_ty
+        )
+        tl.store(
+            output_ptr_offs + token_index * stride_output_1,
+            output_q,
+            mask=offs_in_d < size_n,
+        )
+        tl.store(
+            output_scale_offs + token_index * stride_output_scale_1,
+            output_s,
+        )
+
+
+def silu_and_mul_masked_post_quant_fwd(
+    input: torch.Tensor,
+    output: torch.Tensor,
+    output_scale: torch.Tensor,
+    quant_group_size: int,
+    masked_m: torch.Tensor,
+):
+    """
+    input shape [expert_num, token_num_padded, hidden_dim]
+    output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
+    output_scale shape [expert_num, token_num_padded, hidden_dim // 2 // 128], dtype float32
+    quant_group_size int,
+    masked_m shape [expert_num],
+    """
+
+    assert input.is_contiguous()
+    assert output.dtype == torch.float8_e4m3fn
+    assert output.is_contiguous()
+    assert len(input.shape) == 3
+    assert input.shape[0] == masked_m.shape[0]
+    assert input.shape[-1] % 2 == 0
+
+    size_n = input.shape[-1] // 2
+    assert size_n % quant_group_size == 0
+
+    expert_num = len(masked_m)
+
+    if expert_num < 4:
+        BLOCK_NUM_PER_EXPERT = 64
+    else:
+        BLOCK_NUM_PER_EXPERT = 32
+
+    BLOCK_N = quant_group_size
+    num_warps = 1
+    NUM_STAGES = 6
+    hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
+    assert BLOCK_N % quant_group_size == 0
+
+    grid = (
+        hidden_dim_split_block_num,
+        BLOCK_NUM_PER_EXPERT,
+        expert_num,
+    )
+
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = finfo.max
+    fp8_min = -fp8_max
+
+    _silu_and_mul_post_quant_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        output_scale,
+        *output_scale.stride(),
+        masked_m,
+        size_n,
+        fp8_max,
+        fp8_min,
+        BLOCK_N=BLOCK_N,
+        NUM_STAGE=NUM_STAGES,
+        num_warps=num_warps,
+    )
+    return
+
+
 @triton.jit
 def tanh(x):
     return 2 * tl.sigmoid(2 * x) - 1
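To make the kernel's contract concrete, here is a plain-PyTorch reference sketch of the same computation (my own illustration, not code from this commit): SiLU on the gate half of the last dimension, multiply by the up half, then per-`quant_group_size` absmax quantization to `float8_e4m3fn` with one float32 scale per group.

```python
# Plain-PyTorch reference sketch of what silu_and_mul_masked_post_quant_fwd produces.
import torch
import torch.nn.functional as F


def silu_and_mul_masked_post_quant_ref(
    x: torch.Tensor,        # [expert_num, token_num_padded, hidden_dim], bf16/fp16
    quant_group_size: int,  # e.g. 128
):
    e, t, h = x.shape
    n = h // 2
    gate, up = x[..., :n], x[..., n:]
    # Multiply in the input dtype, as the Triton kernel does after the fp32 sigmoid.
    act = up * F.silu(gate.float()).to(x.dtype)

    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    groups = act.reshape(e, t, n // quant_group_size, quant_group_size).float()
    # One scale per (expert, token, group): absmax / fp8_max, floored at 1e-10.
    scale = groups.abs().amax(dim=-1).clamp_min(1e-10) / fp8_max
    q = (groups / scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
    out = q.reshape(e, t, n).to(torch.float8_e4m3fn)
    return out, scale  # [e, t, n] fp8 and [e, t, n // quant_group_size] fp32
```

The Triton kernel additionally skips rows beyond `masked_m[i]` for each expert, so padded rows keep whatever was in the output buffers; comparisons against this sketch should mask them out and allow a small tolerance, since the rounding points differ slightly.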

python/sglang/srt/layers/moe/ep_moe/layer.py (+81 -78)

@@ -3,12 +3,16 @@

 import torch

-# TODO: use deep_gemm masked kernel after low latency dispatch
-# import deep_gemm
-# from deep_gemm import (
-#     get_col_major_tma_aligned_tensor,
-#     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-# )
+try:
+    from deep_gemm import (
+        get_col_major_tma_aligned_tensor,
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+    )
+
+    use_deep_gemm = True
+except ImportError:
+    use_deep_gemm = False
+
 from torch.nn import Module

 from sglang.srt.custom_op import CustomOp

@@ -22,6 +26,7 @@
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
+    silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

@@ -809,6 +814,7 @@ def __init__(
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        deepep_mode: str = "auto",
     ):
         super().__init__(
             num_experts,

@@ -827,21 +833,41 @@ def __init__(
             custom_routing_function,
             activation,
         )
+        self.deepep_mode = deepep_mode
+        if self.deepep_mode in ["low_latency", "auto"]:
+            assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )

     def forward(
         self,
         hidden_states: torch.Tensor,
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
+        masked_m: torch.Tensor,
+        expected_m: int,
         forward_mode: ForwardMode,
     ):
-        # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
-        if True:  # not forward_mode.is_decode():
+        if self.deepep_mode == "normal" or (
+            self.deepep_mode == "auto" and not forward_mode.is_decode()
+        ):
             return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+        elif self.deepep_mode == "low_latency" or (
+            self.deepep_mode == "auto" and forward_mode.is_decode()
+        ):
+            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
         else:
-            return self.forward_deepgemm_masked(
-                hidden_states, reorder_topk_ids, seg_indptr
-            )
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

     def forward_normal(
         self,

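A brief editorial aside between hunks: the new `forward` body above boils down to a small mode-resolution rule, shown here as a standalone sketch (the helper name is my own, purely for illustration and not part of the commit):

```python
# Sketch of the dispatch rule encoded in DeepEPMoE.forward above.
def resolve_deepep_mode(deepep_mode: str, is_decode: bool) -> str:
    if deepep_mode == "normal" or (deepep_mode == "auto" and not is_decode):
        return "normal"        # -> forward_normal(): DeepEP normal dispatch
    elif deepep_mode == "low_latency" or (deepep_mode == "auto" and is_decode):
        return "low_latency"   # -> forward_deepgemm_masked(): masked grouped GEMM
    else:
        raise ValueError(f"Invalid deepep_mode: {deepep_mode}")


assert resolve_deepep_mode("auto", is_decode=False) == "normal"       # prefill batch
assert resolve_deepep_mode("auto", is_decode=True) == "low_latency"   # decode batch
```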
@@ -958,89 +984,66 @@ def forward_normal(

     def forward_deepgemm_masked(
         self,
-        hidden_states: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
+        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
+        masked_m: torch.Tensor,
+        expected_m: int,
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+        assert (
+            hidden_states_fp8[0].size(0) % 4 == 0
+        ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"

         # GroupGemm-0
+        num_groups, m, k = hidden_states_fp8[0].size()
+        n = self.w13_weight.size(1)
+        expected_m = min(expected_m, m)
         gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
         )
-        if hidden_states.shape[0] > 0:
-            # Transpose earlier so that the testing will not trigger transposing kernels
-            hidden_states = (
-                hidden_states[0],
-                get_col_major_tma_aligned_tensor(hidden_states[1]),
-            )
-        """
-        gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            hidden_states, self.w13_weight, out, masked_m, expected_m
-        )
-        """

         # Act
         down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
             ),
+            device=gateup_output.device,
+            dtype=self.fp8_dtype,
         )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
-                dtype=torch.float32,
-                device=hidden_states.device,
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
                 gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                0,
-                self.num_experts_per_partition - 1,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=gateup_output.device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+        )

         # GroupGemm-1
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            get_col_major_tma_aligned_tensor(down_input_scale),
+        )
         down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
         )
-        if down_input.shape[0] > 0:
-            # Transpose earlier so that the testing will not trigger transposing kernels
-            down_input = (
-                down_input[0],
-                get_col_major_tma_aligned_tensor(down_input[1]),
-            )
-        """
-        down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            down_input, self.w2_weight, out, masked_m, expected_m
-        )
-        """

         return down_output

0 commit comments
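To summarize the low-latency path, here is a shape sketch of `forward_deepgemm_masked` as reconstructed from the hunk above. It is a reading aid, not code from this commit; `E`, `m`, `k`, and `n` follow the names used in the hunk.

```python
# Reading aid: tensor shapes through forward_deepgemm_masked.
# E = local experts (num_groups), m = padded tokens per expert,
# k = hidden size, n = w13 output size (gate + up), n // 2 = intermediate size.
#
#   hidden_states_fp8 : (fp8[E, m, k], fp32 scales)   # from DeepEP low-latency dispatch
#   w13_weight_fp8    : (fp8[E, n, k], scales)        # gate + up projection
#   gateup_output     : bf16[E, m, n]                 # GroupGemm-0, masked by masked_m
#   down_input        : fp8[E, m, n // 2]             # SiLU(gate) * up, re-quantized
#   down_input_scale  : fp32[E, m, n // 2 // 128]     # per-128-group scales, made
#                                                     # column-major via get_col_major_tma_aligned_tensor
#   w2_weight_fp8     : (fp8[E, k, n // 2], scales)   # down projection
#   down_output       : bf16[E, m, k]                 # GroupGemm-1, masked by masked_m
#
# Only the first masked_m[i] rows of expert i are valid; expected_m is clamped to m
# before being handed to the masked grouped GEMM.
```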