
Commit 9f55b39

comaniac authored and LeiWang1999 committed
[Model] Pipeline parallel support for Mixtral (vllm-project#6516)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 9bca381 commit 9f55b39
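
For context, a minimal usage sketch (not part of this commit): with this change, a Mixtral checkpoint should be launchable with a pipeline-parallel size greater than 1 through the OpenAI-compatible server. The checkpoint name, GPU split, and port below are illustrative assumptions; the flags mirror those exercised by the test in this diff.

# Sketch only: serve Mixtral with pipeline parallelism and query it.
# Model name, GPU counts, and port are assumptions for illustration.
import subprocess

from openai import OpenAI

MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # assumed checkpoint
server = subprocess.Popen([
    "python", "-m", "vllm.entrypoints.openai.api_server",
    "--model", MODEL,
    "--dtype", "bfloat16",
    "--pipeline-parallel-size", "2",  # split decoder layers across 2 ranks
    "--tensor-parallel-size", "2",    # shard each layer across 2 GPUs
    "--distributed-executor-backend", "mp",
    "--port", "8000",
])

# Wait until the server reports readiness before querying (omitted here).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
out = client.completions.create(model=MODEL,
                                prompt="Hello, my name is",
                                max_tokens=5,
                                temperature=0.0)
print(out.choices[0].text)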

File tree

3 files changed (+60 −19 lines):
tests/distributed/test_pipeline_parallel.py
vllm/config.py
vllm/model_executor/models/mixtral.py

tests/distributed/test_pipeline_parallel.py

Lines changed: 11 additions & 6 deletions
@@ -1,4 +1,5 @@
 import pytest
+from transformers import AutoTokenizer
 
 from ..utils import RemoteOpenAIServer
 
@@ -12,6 +13,8 @@
     (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
 ])
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
     pp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -34,7 +37,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
         "--dtype",
         "bfloat16",
         "--tensor-parallel-size",
-        str(max(TP_SIZE, 2)),  # use at least TP_SIZE=2 to hold the model
+        str(max(TP_SIZE, 2)),  # We only use 2 GPUs in the CI.
         "--distributed-executor-backend",
         "mp",
     ]
@@ -45,8 +48,10 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
         pp_args.append("--enforce-eager")
         tp_args.append("--enforce-eager")
 
+    prompt = "Hello, my name is"
+    token_ids = tokenizer(prompt)["input_ids"]
     results = []
-    for args in [pp_args, tp_args]:
+    for args in (pp_args, tp_args):
         with RemoteOpenAIServer(MODEL_NAME, args) as server:
             client = server.get_client()
 
@@ -62,7 +67,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
 
             # test with text prompt
             completion = client.completions.create(model=MODEL_NAME,
-                                                   prompt="Hello, my name is",
+                                                   prompt=prompt,
                                                    max_tokens=5,
                                                    temperature=0.0)
 
@@ -76,7 +81,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test using token IDs
             completion = client.completions.create(
                 model=MODEL_NAME,
-                prompt=[0, 0, 0, 0, 0],
+                prompt=token_ids,
                 max_tokens=5,
                 temperature=0.0,
             )
@@ -91,7 +96,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test simple list
             batch = client.completions.create(
                 model=MODEL_NAME,
-                prompt=["Hello, my name is", "Hello, my name is"],
+                prompt=[prompt, prompt],
                 max_tokens=5,
                 temperature=0.0,
             )
@@ -105,7 +110,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test streaming
             batch = client.completions.create(
                 model=MODEL_NAME,
-                prompt=["Hello, my name is", "Hello, my name is"],
+                prompt=[prompt, prompt],
                 max_tokens=5,
                 temperature=0.0,
                 stream=True,
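
The test now derives the token-ID prompt from the text prompt instead of sending a dummy [0, 0, 0, 0, 0], so the text path and the token-ID path exercise the same input. A standalone sketch of that equivalence (server address and model name below are assumptions, and exact equality assumes the server tokenizes the text prompt the same way):

# Sketch: with temperature=0.0, sending the text prompt and sending its
# tokenization should yield the same continuation. Assumes a vLLM
# OpenAI-compatible server is already running at the URL below.
from openai import OpenAI
from transformers import AutoTokenizer

MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

prompt = "Hello, my name is"
token_ids = tokenizer(prompt)["input_ids"]

text_out = client.completions.create(
    model=MODEL_NAME, prompt=prompt, max_tokens=5, temperature=0.0)
ids_out = client.completions.create(
    model=MODEL_NAME, prompt=token_ids, max_tokens=5, temperature=0.0)

# Greedy decoding over the same token sequence should give the same text.
assert text_out.choices[0].text == ids_out.choices[0].text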

vllm/config.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@
     "MistralForCausalLM",
     "Phi3ForCausalLM",
     "GPT2LMHeadModel",
+    "MixtralForCausalLM",
 ]
 
 
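
This hunk adds MixtralForCausalLM to the list of architectures the engine treats as supporting pipeline parallelism. Roughly, a check of the following shape gates it; the variable and function names below are assumptions, since the diff only shows the list entry being added:

# Hypothetical sketch of how a pipeline-parallel allow-list is consulted.
# The names _PP_SUPPORTED_MODELS and verify_pp_support are assumptions.
_PP_SUPPORTED_MODELS = [
    "MistralForCausalLM",
    "Phi3ForCausalLM",
    "GPT2LMHeadModel",
    "MixtralForCausalLM",  # added by this commit
]


def verify_pp_support(architectures: list, pipeline_parallel_size: int) -> None:
    if pipeline_parallel_size > 1 and not any(
            arch in _PP_SUPPORTED_MODELS for arch in architectures):
        raise NotImplementedError(
            f"Pipeline parallelism is not supported for {architectures}.")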

vllm/model_executor/models/mixtral.py

Lines changed: 48 additions & 13 deletions
@@ -29,7 +29,7 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -48,6 +48,7 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .interfaces import SupportsLoRA
+from .utils import is_pp_missing_parameter, make_layers
 
 
 class MixtralMoE(nn.Module):
@@ -255,12 +256,11 @@ def __init__(
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.layers = nn.ModuleList([
-            MixtralDecoderLayer(config,
-                                cache_config,
-                                quant_config=quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers, lambda: MixtralDecoderLayer(
+                config, cache_config, quant_config=quant_config))
+
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
@@ -269,14 +269,25 @@
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i], attn_metadata,
-                                            residual)
+                                            kv_caches[i - self.start_layer],
+                                            attn_metadata, residual)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
@@ -347,7 +358,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
        return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
@@ -356,6 +367,20 @@ def compute_logits(self, hidden_states: torch.Tensor,
                                        sampling_metadata)
         return logits
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def sample(
         self,
         logits: Optional[torch.Tensor],
@@ -392,6 +417,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -402,6 +431,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param,
@@ -414,6 +446,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
                 # Remapping the name of FP8 kv-scale.
                 name = maybe_remap_kv_scale_name(name, params_dict)
                 if name is None:
