Skip to content

Commit 8d9f15f

Browse files
author
Muralidhar Andoorveedu
committed
Add docs
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
1 parent 4f0e0ea commit 8d9f15f

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

docs/source/serving/distributed_serving.rst

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Distributed Inference and Serving
44
=================================
55

6-
vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with either `Ray <https://github.com/ray-project/ray>`_ or python native multiprocessing. Multiprocessing can be used when deploying on a single node, multi-node inferencing currently requires Ray.
6+
vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We also support pipeline parallel as a beta feature for online serving. We manage the distributed runtime with either `Ray <https://github.com/ray-project/ray>`_ or python native multiprocessing. Multiprocessing can be used when deploying on a single node, multi-node inferencing currently requires Ray.
77

88
Multiprocessing will be used by default when not running in a Ray placement group and if there are sufficient GPUs available on the same node for the configured :code:`tensor_parallel_size`, otherwise Ray will be used. This default can be overridden via the :code:`LLM` class :code:`distributed-executor-backend` argument or :code:`--distributed-executor-backend` API server argument. Set it to :code:`mp` for multiprocessing or :code:`ray` for Ray. It's not required for Ray to be installed for the multiprocessing case.
99

@@ -21,7 +21,20 @@ To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument wh
2121
2222
$ python -m vllm.entrypoints.openai.api_server \
2323
$ --model facebook/opt-13b \
24-
$ --tensor-parallel-size 4
24+
$ --tensor-parallel-size 4 \
25+
26+
You can also additionally specify :code:`--pipeline-parallel-size` to enable pipeline parallelism. For example, to run API server on 8 GPUs with pipeline parallelism and tensor parallelism:
27+
28+
.. code-block:: console
29+
30+
$ python -m vllm.entrypoints.openai.api_server \
31+
$ --model gpt2 \
32+
$ --tensor-parallel-size 4 \
33+
$ --pipeline-parallel-size 2 \
34+
$ --distributed-executor-backend ray \
35+
36+
.. note::
37+
Pipeline parallel is a beta feature. It is only supported for online serving and the ray backend for now, as well as LLaMa and GPT2 style models.
2538

2639
To scale vLLM beyond a single machine, install and start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:
2740

@@ -35,7 +48,7 @@ To scale vLLM beyond a single machine, install and start a `Ray runtime <https:/
3548
$ # On worker nodes
3649
$ ray start --address=<ray-head-address>
3750
38-
After that, you can run inference and serving on multiple machines by launching the vLLM process on the head node by setting :code:`tensor_parallel_size` to the number of GPUs to be the total number of GPUs across all machines.
51+
After that, you can run inference and serving on multiple machines by launching the vLLM process on the head node by setting :code:`tensor_parallel_size` multiplied by :code:`pipeline_parallel_size` to the number of GPUs to be the total number of GPUs across all machines.
3952

4053
.. warning::
4154
Please make sure you downloaded the model to all the nodes, or the model is downloaded to some distributed file system that is accessible by all nodes.

0 commit comments

Comments
 (0)