
Commit be61d20

Create PlacementGroup for steps using vLLM (#842)
* Create placement group for `vLLM`
* Use `SPREAD` if `pipeline_parallel_size>1`
* Fix bundle initialization
* Fix wrong dictionary
* Remove using `SPMD` from ray docs
* Refactor creating `PlacementGroup` for `vLLM`

1 parent 2aa977f

2 files changed (+68 -16 lines)

docs/sections/how_to_guides/advanced/scaling_with_ray.md (-8 lines)

````diff
@@ -232,12 +232,4 @@ with Pipeline(name="text-generation-ray-pipeline") as pipeline:
     load_data_from_hub >> text_generation
 ```
 
-Finally, we need to define two environment variables in our `runtime_env.yaml` file:
-
-```yaml
-env_vars:
-  VLLM_USE_RAY_COMPILED_DAG: "1"
-  VLLM_USE_RAY_SPMD_WORKER: "1"
-```
-
 More information about distributed inference with `vLLM` can be found here: [vLLM - Distributed Serving](https://docs.vllm.ai/en/latest/serving/distributed_serving.html)
````
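With the `VLLM_USE_RAY_COMPILED_DAG` / `VLLM_USE_RAY_SPMD_WORKER` variables dropped from the guide, distributing `vLLM` over Ray now rests on the placement-group logic added to `ray.py` below. As an illustrative sketch (not part of this commit) of a step definition that would take that path, assuming the guide's `TextGeneration` example and that the parallelism settings are passed through the `vLLM` wrapper's `extra_kwargs`, which is where the new code reads them:

```python
from distilabel.llms.vllm import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="text-generation-ray-pipeline") as pipeline:
    text_generation = TextGeneration(
        llm=vLLM(
            model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model
            # `_create_vllm_placement_group` reads these keys to size the GPU
            # bundles and to choose STRICT_PACK vs SPREAD.
            extra_kwargs={"tensor_parallel_size": 2},
        ),
    )
```

With `tensor_parallel_size=2` and no `pipeline_parallel_size`, the step would be scheduled into a placement group of one CPU bundle plus two GPU bundles packed on a single node.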

src/distilabel/pipeline/ray.py (+68 -8 lines)

```diff
@@ -16,6 +16,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 from distilabel.distiset import create_distiset
+from distilabel.llms.vllm import vLLM
 from distilabel.pipeline.base import BasePipeline
 from distilabel.pipeline.constants import INPUT_QUEUE_ATTR_NAME
 from distilabel.pipeline.step_wrapper import _StepWrapper
@@ -26,6 +27,8 @@
     from os import PathLike
     from queue import Queue
 
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
     from distilabel.distiset import Distiset
     from distilabel.pipeline.typing import InputDataset
     from distilabel.steps.base import _Step
@@ -69,6 +72,7 @@ def __init__(
 
         self._ray_head_node_url = ray_head_node_url
         self._ray_init_kwargs = ray_init_kwargs or {}
+        self._ray_node_ids = {}
 
     def run(
         self,
@@ -171,6 +175,8 @@ def _init_ray(self) -> None:
         else:
             ray.init(**self._ray_init_kwargs)
 
+        self._ray_node_ids = {node["NodeID"]: False for node in ray.nodes()}
+
     @property
     def QueueClass(self) -> Callable:
         from ray.util.queue import Queue
@@ -218,17 +224,20 @@ def run(self) -> str:
             "name": f"distilabel-{self.name}-{step.name}-{replica}"
         }
 
-        if step.resources.cpus is not None:
-            resources["num_cpus"] = step.resources.cpus
+        if hasattr(step, "llm") and isinstance(step.llm, vLLM):  # type: ignore
+            resources["scheduling_strategy"] = self._create_vllm_placement_group(step)
+        else:
+            if step.resources.cpus is not None:
+                resources["num_cpus"] = step.resources.cpus
 
-        if step.resources.gpus is not None:
-            resources["num_gpus"] = step.resources.gpus
+            if step.resources.gpus is not None:
+                resources["num_gpus"] = step.resources.gpus
 
-        if step.resources.memory is not None:
-            resources["memory"] = step.resources.memory
+            if step.resources.memory is not None:
+                resources["memory"] = step.resources.memory
 
-        if step.resources.resources is not None:
-            resources["resources"] = step.resources.resources
+            if step.resources.resources is not None:
+                resources["resources"] = step.resources.resources
 
         _StepWrapperRay = _StepWrapperRay.options(**resources)  # type: ignore
 
@@ -255,6 +264,57 @@ def run(self) -> str:
         )
         step_wrapper.run.remote()
 
+    def _create_vllm_placement_group(
+        self, step: "_Step"
+    ) -> "PlacementGroupSchedulingStrategy":
+        """Creates a Ray placement group with as many GPU bundles as `tensor_parallel_size`
+        specified in the `vLLM` initialisation. The created placement group uses the `STRICT_PACK`
+        strategy if the `pipeline_parallel_size` is less or equal to 1, otherwise it uses
+        `SPREAD` (placement group with GPU bundles in several nodes). In addition, the created
+        placement group is targeted to be created in a specific node. This avoids having
+        `vLLM` raising the exception `Ray does not allocate any GPUs on the driver node...`,
+        as it assures that the driver `_StepWrapperRay` actor created resides in the same
+        node as the ray actors created by `vLLM` for the distributed inference.
+
+        Args:
+            step: the step which uses `vLLM`.
+
+        Returns:
+            A `PlacementGroupSchedulingStrategy` using the created `PlacementGroup`.
+        """
+        import ray
+
+        llm = step.llm  # type: ignore
+        tensor_parallel_size = llm.extra_kwargs.get("tensor_parallel_size", 1)  # type: ignore
+        pipeline_parallel_size = llm.extra_kwargs.get(  # type: ignore
+            "pipeline_parallel_size", 1
+        )
+
+        node_id = next(
+            node_id for node_id, used in self._ray_node_ids.items() if not used
+        )
+
+        self._ray_node_ids[node_id] = True
+
+        # Create a placement group
+        pg = ray.util.placement_group(
+            # Create `tensor_parallel_size` GPU bundles and at least one CPU bundle
+            # so the actors can be scheduled and executed (1 CPU bundle can have infinite actors):
+            # https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html#schedule-tasks-and-actors-to-placement-groups-use-reserved-resources
+            bundles=[{"CPU": 1}] + [{"GPU": 1}] * tensor_parallel_size,
+            strategy="SPREAD" if pipeline_parallel_size > 1 else "STRICT_PACK",
+            _soft_target_node_id=node_id if pipeline_parallel_size is None else None,
+        )
+
+        self._logger.info(
+            f"Step '{step.name}' uses `vLLM`. Created a Ray placement group with bundle"
+            f" specs: {pg.bundle_specs}"
+        )
+
+        return ray.util.scheduling_strategies.PlacementGroupSchedulingStrategy(  # type: ignore
+            placement_group=pg,
+        )
+
     def _teardown(self) -> None:
         """Clean/release/stop resources reserved to run the pipeline."""
         if self._write_buffer:
```
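For readers less familiar with the Ray primitives the new `_create_vllm_placement_group` method builds on, here is a self-contained sketch of the same pattern, independent of distilabel and with illustrative names only: reserve one CPU bundle plus GPU bundles under `STRICT_PACK`, then schedule an actor with a `PlacementGroupSchedulingStrategy` so it lands on the node holding those bundles, which is what keeps the driver actor next to the `vLLM` workers. It assumes a cluster node with at least 2 free GPUs.

```python
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()

# One CPU bundle for the driver actor plus two GPU bundles (as if
# tensor_parallel_size=2), all forced onto a single node by STRICT_PACK.
pg = placement_group(
    bundles=[{"CPU": 1}, {"GPU": 1}, {"GPU": 1}],
    strategy="STRICT_PACK",
)
ray.get(pg.ready())  # block until the cluster has reserved the resources


@ray.remote(num_cpus=1)
class Driver:  # illustrative stand-in for the `_StepWrapperRay` actor
    def where(self) -> str:
        # Report the node this actor was scheduled on.
        return ray.get_runtime_context().get_node_id()


# Scheduling with the placement group strategy places the actor inside the
# reserved bundles, i.e. on the same node as the GPU bundles.
driver = Driver.options(
    scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg)
).remote()
print(ray.get(driver.where.remote()))
```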
