
[Feature] Support DeepEP Low Latency #4767


Merged: 14 commits, Apr 1, 2025
1 change: 1 addition & 0 deletions docs/backend/server_arguments.md
@@ -91,6 +91,7 @@ Please consult the documentation below to learn more about the parameters you ma
* `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models.
* `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`, for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`.
* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP.
* `deepep_mode`: Selects the dispatch mode when DeepEP MoE is enabled; it can be `normal`, `low_latency`, or `auto`. The default is `auto`, which uses `low_latency` for decode batches and `normal` for prefill batches (see the sketch below).
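The `auto` behavior can be summarized by this minimal sketch (an editorial illustration of the documented semantics, not code from this PR; `resolve_deepep_mode` and `is_decode_batch` are hypothetical names):

```python
def resolve_deepep_mode(deepep_mode: str, is_decode_batch: bool) -> str:
    """Resolve the effective DeepEP dispatch mode for one batch."""
    if deepep_mode == "auto":
        # auto: low-latency dispatch for decode, normal dispatch for prefill
        return "low_latency" if is_decode_batch else "normal"
    return deepep_mode  # "normal" or "low_latency", as configured
```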

## Memory and scheduling

142 changes: 142 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -244,6 +244,148 @@ def silu_and_mul_triton_kernel(
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)


# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
    input_ptr,
    stride_input_0,
    stride_input_1,
    stride_input_2,
    output_ptr,
    stride_output_0,
    stride_output_1,
    stride_output_2,
    output_scale_ptr,
    stride_output_scale_0,
    stride_output_scale_1,
    stride_output_scale_2,
    masked_m_ptr,
    size_n,
    fp8_max,
    fp8_min,
    BLOCK_N: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    expert_id = tl.program_id(2)
    token_id = tl.program_id(1)
    hidden_dim_block_index = tl.program_id(0)

    block_num_per_expert = tl.num_programs(1)

    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)

    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
    stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
    stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)

    offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
    input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
    output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
    output_scale_offs = (
        output_scale_ptr
        + expert_id * stride_output_scale_0
        + hidden_dim_block_index * stride_output_scale_2
    )

    for token_index in tl.range(
        token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
    ):
        gate = tl.load(
            input_ptr_offs + token_index * stride_input_1,
            mask=offs_in_d < size_n,
            other=0.0,
        ).to(tl.float32)
        up = tl.load(
            input_ptr_offs + token_index * stride_input_1 + size_n,
            mask=offs_in_d < size_n,
            other=0.0,
        )
        gate = gate / (1 + tl.exp(-gate))
        gate = gate.to(input_ptr.dtype.element_ty)
        gate_up = up * gate
        _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
        output_s = _absmax / fp8_max
        output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
            output_ptr.dtype.element_ty
        )
        tl.store(
            output_ptr_offs + token_index * stride_output_1,
            output_q,
            mask=offs_in_d < size_n,
        )
        tl.store(
            output_scale_offs + token_index * stride_output_scale_1,
            output_s,
        )
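For readers less familiar with Triton, the following plain PyTorch sketch reproduces what the kernel above computes for a single quantization group of one token (an editorial reference, not part of this diff; it assumes a PyTorch build with `torch.float8_e4m3fn`):

```python
import torch


def silu_mul_fp8_group_reference(gate: torch.Tensor, up: torch.Tensor):
    """Fused SiLU(gate) * up followed by per-group FP8 (e4m3) quantization."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # SiLU is computed in float32, then cast back to the activation dtype, as in the kernel.
    act = up * (gate.float() / (1 + torch.exp(-gate.float()))).to(gate.dtype)
    # One scale per group: absmax (floored at 1e-10) divided by the FP8 max value.
    scale = torch.clamp(act.abs().amax(), min=1e-10).float() / finfo.max
    q = torch.clamp(act.float() / scale, -finfo.max, finfo.max).to(torch.float8_e4m3fn)
    return q, scale  # quantized group and its float32 scale
```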


def silu_and_mul_masked_post_quant_fwd(
    input: torch.Tensor,
    output: torch.Tensor,
    output_scale: torch.Tensor,
    quant_group_size: int,
    masked_m: torch.Tensor,
):
    """
    input shape [expert_num, token_num_padded, hidden_dim]
    output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
    output_scale shape [expert_num, token_num_padded, hidden_dim // 2 // quant_group_size], dtype float32
    quant_group_size int
    masked_m shape [expert_num]
    """

    assert input.is_contiguous()
    assert output.dtype == torch.float8_e4m3fn
    assert output.is_contiguous()
    assert len(input.shape) == 3
    assert input.shape[0] == masked_m.shape[0]
    assert input.shape[-1] % 2 == 0

    size_n = input.shape[-1] // 2
    assert size_n % quant_group_size == 0

    expert_num = len(masked_m)

    # Launch more blocks per expert when only a few experts are resident on this rank.
    if expert_num < 4:
        BLOCK_NUM_PER_EXPERT = 64
    else:
        BLOCK_NUM_PER_EXPERT = 32

    BLOCK_N = quant_group_size
    num_warps = 1
    NUM_STAGES = 6
    hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
    assert BLOCK_N % quant_group_size == 0

    grid = (
        hidden_dim_split_block_num,
        BLOCK_NUM_PER_EXPERT,
        expert_num,
    )

    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    _silu_and_mul_post_quant_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        output_scale,
        *output_scale.stride(),
        masked_m,
        size_n,
        fp8_max,
        fp8_min,
        BLOCK_N=BLOCK_N,
        NUM_STAGE=NUM_STAGES,
        num_warps=num_warps,
    )
    return
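A minimal usage sketch for the wrapper above, with shapes taken from its docstring and asserts; the concrete sizes are illustrative assumptions:

```python
import torch

from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd

E, M, H = 8, 256, 4096  # experts per rank, padded tokens per expert, hidden_dim (gate + up halves)
group = 128             # quant_group_size

x = torch.randn(E, M, H, device="cuda", dtype=torch.bfloat16)
y = torch.empty(E, M, H // 2, device="cuda", dtype=torch.float8_e4m3fn)
y_scale = torch.empty(E, M, H // 2 // group, device="cuda", dtype=torch.float32)
masked_m = torch.randint(0, M, (E,), device="cuda", dtype=torch.int32)  # valid tokens per expert

silu_and_mul_masked_post_quant_fwd(x, y, y_scale, group, masked_m)
# Only the first masked_m[e] rows of y[e] and y_scale[e] hold meaningful results.
```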


@triton.jit
def tanh(x):
return 2 * tl.sigmoid(2 * x) - 1
159 changes: 81 additions & 78 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -3,12 +3,16 @@

import torch

# TODO: use deep_gemm masked kernel after low latency dispatch
# import deep_gemm
# from deep_gemm import (
# get_col_major_tma_aligned_tensor,
# m_grouped_gemm_fp8_fp8_bf16_nt_masked,
# )
try:
from deep_gemm import (
get_col_major_tma_aligned_tensor,
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
)

use_deep_gemm = True
except ImportError:
use_deep_gemm = False

from torch.nn import Module

from sglang.srt.custom_op import CustomOp
@@ -22,6 +26,7 @@
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
silu_and_mul_masked_post_quant_fwd,
silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
@@ -809,6 +814,7 @@ def __init__(
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
deepep_mode: str = "auto",
):
super().__init__(
num_experts,
@@ -827,21 +833,41 @@
custom_routing_function,
activation,
)
self.deepep_mode = deepep_mode
if self.deepep_mode in ["low_latency", "auto"]:
assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)

def forward(
self,
hidden_states: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
forward_mode: ForwardMode,
):
# Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
if True: # not forward_mode.is_decode():
if self.deepep_mode == "normal" or (
self.deepep_mode == "auto" and not forward_mode.is_decode()
):
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
elif self.deepep_mode == "low_latency" or (
self.deepep_mode == "auto" and forward_mode.is_decode()
):
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
else:
return self.forward_deepgemm_masked(
hidden_states, reorder_topk_ids, seg_indptr
)
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

def forward_normal(
self,
@@ -958,89 +984,66 @@ def forward_normal(

def forward_deepgemm_masked(
self,
hidden_states: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
masked_m: torch.Tensor,
expected_m: int,
):
assert self.quant_method is not None
assert self.activation == "silu"

if self.activation_scheme == "dynamic" and not self.use_block_quant:
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
assert (
hidden_states_fp8[0].size(0) % 4 == 0
), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"

# GroupGemm-0
num_groups, m, k = hidden_states_fp8[0].size()
n = self.w13_weight.size(1)
expected_m = min(expected_m, m)
gateup_output = torch.empty(
hidden_states.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
(num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
)
if hidden_states.shape[0] > 0:
# Transpose earlier so that the testing will not trigger transposing kernels
hidden_states = (
hidden_states[0],
get_col_major_tma_aligned_tensor(hidden_states[1]),
)
"""
gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
hidden_states, self.w13_weight, out, masked_m, expected_m
)
"""

# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states.dtype
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=gateup_output.device,
dtype=self.fp8_dtype,
)
if self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
dtype=torch.float32,
device=hidden_states.device,
)

if self.activation == "silu":
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
0,
self.num_experts_per_partition - 1,
BLOCK_SIZE=512,
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")
gateup_output.shape[2] // 2 // scale_block_size,
),
device=gateup_output.device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
)

# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
get_col_major_tma_aligned_tensor(down_input_scale),
)
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
)
if down_input.shape[0] > 0:
# Transpose earlier so that the testing will not trigger transposing kernels
down_input = (
down_input[0],
get_col_major_tma_aligned_tensor(down_input[1]),
)
"""
down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
down_input, self.w2_weight, out, masked_m, expected_m
)
"""

return down_output
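Because the hunk above interleaves removed and added lines, here is a hedged shape-flow sketch of the new masked (low-latency) path in `forward_deepgemm_masked`, using dummy tensors only; the dimensions and variable names are illustrative assumptions, not values from the PR:

```python
import torch

E, M, K, N = 8, 128, 7168, 4096  # experts per rank, padded tokens, model dim, w13 output dim
fp8 = torch.float8_e4m3fn

# From DeepEP low-latency dispatch: FP8 activations plus per-128-group scales;
# only the first masked_m[e] rows of expert slot e contain real tokens.
hidden_states_fp8 = (
    torch.empty(E, M, K, device="cuda", dtype=fp8),
    torch.empty(E, M, K // 128, device="cuda", dtype=torch.float32),
)
masked_m = torch.randint(0, M, (E,), device="cuda", dtype=torch.int32)

# GroupGemm-0 (m_grouped_gemm_fp8_fp8_bf16_nt_masked): activations @ w13 -> bf16 [E, M, N]
gateup_output = torch.empty(E, M, N, device="cuda", dtype=torch.bfloat16)
# Fused SiLU-and-mul + per-group FP8 requantization -> [E, M, N // 2] plus float32 scales
down_input = torch.empty(E, M, N // 2, device="cuda", dtype=fp8)
down_input_scale = torch.empty(E, M, N // 2 // 128, device="cuda", dtype=torch.float32)
# GroupGemm-1: (down_input, TMA-aligned scales) @ w2 -> bf16 [E, M, K], returned for combine
down_output = torch.empty(E, M, K, device="cuda", dtype=torch.bfloat16)
```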