Commit 25601bb

Create file per hostname in CudaDevicePlacementMixin (#814)
* Create file per hostname
* Set default `_desired_num_gpus` to `1`
* Fix `GeneratorTask`s not getting assigned GPUs and name
* Add `_init_cuda_device_placement` method
* Remove info message
* Disable `CudaDevicePlacementMixin` when running a `RayPipeline`
* Fix unit test
1 parent 04b86f5 commit 25601bb
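
Since `disable_cuda_device_placement` is exposed as a `RuntimeParameter`, it should also be settable by users like any other LLM runtime parameter at run time. A minimal usage sketch, not from this commit: the pipeline and the "text_generation" step name are made up for illustration.

from distilabel.pipeline import Pipeline

# Hypothetical pipeline: "text_generation" is an example step name.
with Pipeline(name="example-pipeline") as pipeline:
    ...  # steps defined here

# Runtime parameters are nested per step, so the new flag is reachable
# under the step's `llm` key.
distiset = pipeline.run(
    parameters={
        "text_generation": {
            "llm": {"disable_cuda_device_placement": True},
        },
    },
)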

File tree

6 files changed: +45 -7 lines changed

src/distilabel/llms/mixins/cuda_device_placement.py
src/distilabel/pipeline/local.py
src/distilabel/pipeline/ray.py
src/distilabel/pipeline/step_wrapper.py
tests/unit/pipeline/test_local.py
tests/unit/steps/tasks/structured_outputs/test_outlines.py

src/distilabel/llms/mixins/cuda_device_placement.py

Lines changed: 19 additions & 1 deletion
@@ -15,6 +15,7 @@
 import json
 import logging
 import os
+import socket
 import tempfile
 from contextlib import contextmanager
 from pathlib import Path
@@ -26,7 +27,11 @@
 from distilabel.mixins.runtime_parameters import RuntimeParameter
 
 _CUDA_DEVICE_PLACEMENT_MIXIN_FILE = (
-    Path(tempfile.gettempdir()) / "distilabel_cuda_device_placement_mixin.json"
+    Path(tempfile.gettempdir())
+    / "distilabel"
+    / "cuda_device_placement"
+    / socket.gethostname()
+    / "distilabel_cuda_device_placement_mixin.json"
 )
 
 
@@ -43,6 +48,8 @@ class CudaDevicePlacementMixin(BaseModel):
         placement information provided in `_device_llm_placement_map`. If set to a list
         of devices, it will be checked if the devices are available to be used by the
         `LLM`. If not, a warning will be logged.
+        disable_cuda_device_placement: Whether to disable the CUDA device placement logic
+            or not. Defaults to `False`.
         _llm_identifier: the identifier of the `LLM` to be used as key in `_device_llm_placement_map`.
         _device_llm_placement_map: a dictionary with the device placement information for each
             `LLM`.
@@ -51,6 +58,10 @@ class CudaDevicePlacementMixin(BaseModel):
     cuda_devices: RuntimeParameter[Union[List[int], Literal["auto"]]] = Field(
         default="auto", description="A list with the ID of the CUDA devices to be used."
     )
+    disable_cuda_device_placement: RuntimeParameter[bool] = Field(
+        default=False,
+        description="Whether to disable the CUDA device placement logic or not.",
+    )
 
     _llm_identifier: Union[str, None] = PrivateAttr(default=None)
     _desired_num_gpus: PositiveInt = PrivateAttr(default=1)
@@ -63,6 +74,9 @@ def load(self) -> None:
         """Assign CUDA devices to the LLM based on the device placement information provided
         in `_device_llm_placement_map`."""
 
+        if self.disable_cuda_device_placement:
+            return
+
         try:
             import pynvml
 
@@ -88,6 +102,9 @@ def load(self) -> None:
     def unload(self) -> None:
         """Unloads the LLM and removes the CUDA devices assigned to it from the device
         placement information provided in `_device_llm_placement_map`."""
+        if self.disable_cuda_device_placement:
+            return
+
         with self._device_llm_placement_map() as device_map:
             if self._llm_identifier in device_map:
                 self._logger.debug(  # type: ignore
@@ -105,6 +122,7 @@ def _device_llm_placement_map(self) -> Generator[Dict[str, List[int]], None, Non
         Yields:
             The content of the device placement file.
         """
+        _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.parent.mkdir(parents=True, exist_ok=True)
         _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.touch()
         with portalocker.Lock(
             _CUDA_DEVICE_PLACEMENT_MIXIN_FILE,
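
The net effect is that the placement file is namespaced per hostname, so nodes that share a temporary directory (e.g. over NFS) no longer read and write the same coordination file. A minimal sketch of the resulting path, mirroring the constant in the diff above; the hostname "node-1" in the comment is just an example:

import socket
import tempfile
from pathlib import Path

# Mirrors the constant defined in the diff above.
placement_file = (
    Path(tempfile.gettempdir())
    / "distilabel"
    / "cuda_device_placement"
    / socket.gethostname()
    / "distilabel_cuda_device_placement_mixin.json"
)

# The parent directories must exist before the file can be touched,
# hence the added `mkdir(parents=True, exist_ok=True)` call.
placement_file.parent.mkdir(parents=True, exist_ok=True)
placement_file.touch()
print(placement_file)
# e.g. /tmp/distilabel/cuda_device_placement/node-1/distilabel_cuda_device_placement_mixin.json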

src/distilabel/pipeline/local.py

Lines changed: 1 addition & 0 deletions
@@ -233,6 +233,7 @@ def _run_step(self, step: "_Step", input_queue: "Queue[Any]", replica: int) -> N
             output_queue=self._output_queue,
             load_queue=self._load_queue,
             dry_run=self._dry_run,
+            ray_pipeline=False,
         )
 
         self._pool.apply_async(step_wrapper.run, error_callback=self._error_callback)

src/distilabel/pipeline/ray.py

Lines changed: 1 addition & 0 deletions
@@ -235,6 +235,7 @@ def run(self) -> str:
                 output_queue=self._output_queue,
                 load_queue=self._load_queue,
                 dry_run=self._dry_run,
+                ray_pipeline=True,
             ),
             log_queue=self._log_queue,
         )

src/distilabel/pipeline/step_wrapper.py

Lines changed: 18 additions & 6 deletions
@@ -21,7 +21,7 @@
 from distilabel.pipeline.constants import LAST_BATCH_SENT_FLAG
 from distilabel.pipeline.typing import StepLoadStatus
 from distilabel.steps.base import GeneratorStep, Step, _Step
-from distilabel.steps.tasks.base import Task
+from distilabel.steps.tasks.base import _Task
 
 
 class _StepWrapper:
@@ -44,6 +44,7 @@ def __init__(
         output_queue: "Queue[_Batch]",
         load_queue: "Queue[Union[StepLoadStatus, None]]",
         dry_run: bool = False,
+        ray_pipeline: bool = False,
     ) -> None:
         """Initializes the `_ProcessWrapper`.
 
@@ -54,21 +55,32 @@
             load_queue: The queue used to notify the main process that the step has been
                 loaded, has been unloaded or has failed to load.
             dry_run: Flag to ensure we are forcing to run the last batch.
+            ray_pipeline: Whether the step is running a `RayPipeline` or not.
         """
         self.step = step
         self.replica = replica
         self.input_queue = input_queue
         self.output_queue = output_queue
         self.load_queue = load_queue
-        self._dry_run = dry_run
+        self.dry_run = dry_run
+        self.ray_pipeline = ray_pipeline
 
+        self._init_cuda_device_placement()
+
+    def _init_cuda_device_placement(self) -> None:
+        """Sets the LLM identifier and the number of desired GPUs of the `CudaDevicePlacementMixin`
+        if the step is a `_Task` that uses an `LLM` with CUDA capabilities."""
         if (
-            isinstance(self.step, Task)
+            isinstance(self.step, _Task)
             and hasattr(self.step, "llm")
            and isinstance(self.step.llm, CudaDevicePlacementMixin)
         ):
-            self.step.llm._llm_identifier = self.step.name
-            self.step.llm._desired_num_gpus = self.step.resources.gpus
+            if self.ray_pipeline:
+                self.step.llm.disable_cuda_device_placement = True
+            else:
+                desired_num_gpus = self.step.resources.gpus or 1
+                self.step.llm._llm_identifier = self.step.name
+                self.step.llm._desired_num_gpus = desired_num_gpus
 
     def run(self) -> str:
         """The target function executed by the process. This function will also handle
@@ -156,7 +168,7 @@ def _generator_step_process_loop(self) -> None:
 
         for data, last_batch in step.process_applying_mappings(offset=offset):
             batch.set_data([data])
-            batch.last_batch = self._dry_run or last_batch
+            batch.last_batch = self.dry_run or last_batch
             self._send_batch(batch)
 
             if batch.last_batch:
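
Two details in this file carry the commit's main fixes. The `isinstance` check now uses the `_Task` base class, so `GeneratorTask`s (which subclass `_Task` but not `Task`) also get an LLM identifier and GPU count assigned. Under a `RayPipeline` the mixin is disabled outright, presumably because Ray schedules GPU resources for its workers itself, making the file-based coordination redundant. Finally, `step.resources.gpus` can be `None` when the user did not request GPUs, so the new `or 1` fallback keeps the default of one GPU. A tiny self-contained sketch of that fallback; the helper name is hypothetical, not distilabel API:

from typing import Optional

def desired_num_gpus(requested: Optional[int]) -> int:
    # `None or 1` evaluates to 1, so an unspecified request falls back
    # to a single GPU, matching the new `_desired_num_gpus` default.
    return requested or 1

assert desired_num_gpus(None) == 1
assert desired_num_gpus(2) == 2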

tests/unit/pipeline/test_local.py

Lines changed: 4 additions & 0 deletions
@@ -55,6 +55,7 @@ def test_run_steps(self, step_wrapper_mock: mock.MagicMock) -> None:
                     output_queue=pipeline._output_queue,
                     load_queue=pipeline._load_queue,
                     dry_run=False,
+                    ray_pipeline=False,
                 ),
                 mock.call(
                     step=dummy_step_1,
@@ -63,6 +64,7 @@ def test_run_steps(self, step_wrapper_mock: mock.MagicMock) -> None:
                     output_queue=pipeline._output_queue,
                     load_queue=pipeline._load_queue,
                     dry_run=False,
+                    ray_pipeline=False,
                 ),
                 mock.call(
                     step=dummy_step_1,
@@ -71,6 +73,7 @@ def test_run_steps(self, step_wrapper_mock: mock.MagicMock) -> None:
                     output_queue=pipeline._output_queue,
                     load_queue=pipeline._load_queue,
                     dry_run=False,
+                    ray_pipeline=False,
                 ),
                 mock.call(
                     step=dummy_step_2,
@@ -79,6 +82,7 @@ def test_run_steps(self, step_wrapper_mock: mock.MagicMock) -> None:
                     output_queue=pipeline._output_queue,
                     load_queue=pipeline._load_queue,
                     dry_run=False,
+                    ray_pipeline=False,
                ),
             ],
         )

tests/unit/steps/tasks/structured_outputs/test_outlines.py

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,7 @@ class DummyUserTest(BaseModel):
             "device_map": None,
             "token": None,
             "use_magpie_template": False,
+            "disable_cuda_device_placement": False,
             "type_info": {
                 "module": "distilabel.llms.huggingface.transformers",
                 "name": "TransformersLLM",
@@ -85,6 +86,7 @@ class DummyUserTest(BaseModel):
             "device_map": None,
             "token": None,
             "use_magpie_template": False,
+            "disable_cuda_device_placement": False,
             "type_info": {
                 "module": "distilabel.llms.huggingface.transformers",
                 "name": "TransformersLLM",