Skip to content

Commit c2d4696

Browse files
WoosukKwon authored and fialhocoelho committed
[Hardware][TPU] Support MoE with Pallas GMM kernel (vllm-project#6457)
1 parent 75275b3 commit c2d4696

File tree

5 files changed

+89
-8
lines changed

5 files changed

+89
-8
lines changed

Dockerfile.tpu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
ARG NIGHTLY_DATE="20240601"
1+
ARG NIGHTLY_DATE="20240713"
22
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
33

44
FROM $BASE_IMAGE
55
WORKDIR /workspace
66

77
# Install aiohttp separately to avoid build errors.
88
RUN pip install aiohttp
9+
# Install NumPy 1 instead of NumPy 2.
10+
RUN pip install "numpy<2"
911
# Install the TPU and Pallas dependencies.
1012
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
1113
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

docs/source/getting_started/tpu-installation.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ First, install the dependencies:
5656
$ pip uninstall torch torch-xla -y
5757
5858
$ # Install PyTorch and PyTorch XLA.
59-
$ export DATE="+20240601"
59+
$ export DATE="+20240713"
6060
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
6161
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
6262
@@ -85,7 +85,7 @@ Next, build vLLM from source. This will only take a few seconds:
8585
ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory
8686
8787
88-
You can install OpenBLAS with the following command:
88+
Please install OpenBLAS with the following command:
8989

9090
.. code-block:: console
9191

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,24 @@ def forward_cpu(self, *args, **kwargs):
104104
raise NotImplementedError(
105105
"The CPU backend currently does not support MoE.")
106106

def forward_tpu(
    self,
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
) -> torch.Tensor:
    """Run the MoE forward pass on TPU by delegating to the Pallas ``fused_moe``.

    Args:
        x: Input activations, forwarded as ``hidden_states``.
        w1: First expert weight tensor, forwarded unchanged.
        w2: Second expert weight tensor, forwarded unchanged.
        router_logits: Router output, forwarded as ``gating_output``.
        top_k: Number of experts selected per token.
        renormalize: Forwarded to ``fused_moe`` as its ``renormalize`` flag.
        use_grouped_topk: Must be False; grouped top-k is rejected here.
        num_expert_group: Must be None on this path.
        topk_group: Must be None on this path.

    Returns:
        The tensor returned by ``fused_moe``.

    Raises:
        AssertionError: If any grouped top-k option is set.
    """
    # Imported inside the method so the Pallas implementation is only loaded
    # when the TPU path is actually taken.
    from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe

    # fused_moe has no parameters for grouped top-k routing, so reject it.
    assert not use_grouped_topk, (
        "Grouped top-k is not supported by the TPU MoE path.")
    assert num_expert_group is None, (
        "num_expert_group must be None on the TPU MoE path.")
    assert topk_group is None, (
        "topk_group must be None on the TPU MoE path.")
    return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
124+
107125

108126
class FusedMoE(torch.nn.Module):
109127
"""FusedMoE layer for MoE models.
vllm/model_executor/layers/fused_moe/moe_pallas.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from torch_xla.experimental.custom_kernel import _histogram
4+
5+
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> torch.Tensor:
    """Apply a top-k routed mixture-of-experts MLP using the XLA GMM op.

    Args:
        hidden_states: [*, hidden_size]
        w1: [num_experts, intermediate_size * 2, hidden_size]
        w2: [num_experts, hidden_size, intermediate_size]
        gating_output: [*, num_experts]
        topk: Number of experts selected for each token.
        renormalize: If True, the selected routing weights are divided by
            their per-token sum so that they add up to 1.

    Returns:
        A tensor reshaped to the original shape of ``hidden_states``.

    Raises:
        AssertionError: If ``num_tokens * topk`` is not a multiple of 16.
    """
    orig_shape = hidden_states.shape
    hidden_size = hidden_states.shape[-1]
    num_tokens = hidden_states.shape[:-1].numel()
    num_experts = w1.shape[0]
    intermediate_size = w2.shape[-1]
    device = hidden_states.device
    dtype = hidden_states.dtype
    assert (num_tokens * topk) % 16 == 0, (
        "The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
        f"16 but got {num_tokens * topk}")

    # Collapse all leading dimensions into a single token dimension.
    hidden_states = hidden_states.view(num_tokens, hidden_size)
    gating_output = gating_output.view(num_tokens, num_experts)

    # Routing: softmax in float32, keep the top-k experts per token, then
    # cast the weights back to the activation dtype.
    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
    topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(dtype)

    # Sort the (token, expert) pairs by expert id so that rows routed to the
    # same expert are contiguous. The argsort of the permutation is its
    # inverse, used below to restore the original order.
    topk_indices = topk_indices.flatten()
    topk_argsort_indices = topk_indices.argsort()
    topk_argsort_revert_indices = topk_argsort_indices.argsort()
    # Flattened pair i belongs to token i // topk.
    token_indices = torch.arange(num_tokens,
                                 device=device).repeat_interleave(topk)
    token_indices = token_indices[topk_argsort_indices]
    # NOTE(review): presumably the number of rows assigned to each expert id
    # in [0, num_experts - 1] -- confirm against torch_xla's _histogram.
    group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)

    # NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
    # from HF Transformers.
    w1 = w1.transpose(1, 2)
    w2 = w2.transpose(1, 2)

    # Gather one input row per (token, expert) pair, in expert-sorted order.
    x = hidden_states[token_indices]
    x = torch.ops.xla.gmm(x, w1, group_sizes)
    # The first half of the last dimension goes through SiLU and gates the
    # second half.
    x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
    x = torch.ops.xla.gmm(x, w2, group_sizes)
    # Undo the expert sort and regroup the rows as [num_tokens, topk, hidden].
    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)

    # Weighted sum over the selected experts. The out-of-place unsqueeze
    # avoids mutating topk_weights.
    x = x * topk_weights.unsqueeze(dim=-1)
    x = x.sum(dim=-2)
    x = x.reshape(orig_shape)
    return x

vllm/worker/tpu_model_runner.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -598,11 +598,10 @@ def _get_padded_prefill_len(x: int) -> int:
598598

599599

600600
def _get_padded_batch_size(batch_size: int) -> int:
601-
if batch_size <= 2:
602-
return batch_size
603-
elif batch_size <= 4:
604-
return 4
605-
elif batch_size <= 8:
601+
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
602+
# To meet this requirement in the simplest way, we set the minimal batch
603+
# size to 8.
604+
if batch_size <= 8:
606605
return 8
607606
else:
608607
return ((batch_size + 15) // 16) * 16

0 commit comments

Comments
 (0)