Skip to content

Commit abd6d88

Browse files
committed
Merge branch 'develop' into cuda-device-placement-mixin-file-per-host
2 parents d96c21e + 90909ab commit abd6d88

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

src/distilabel/pipeline/base.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,9 @@ def _manage_batch_flow(self, batch: "_Batch") -> None:
11181118
self._send_batch_to_step(new_batch)
11191119
else:
11201120
self._request_more_batches_if_needed(step)
1121+
else:
1122+
if len(self.dag) == 1:
1123+
self._request_batch_from_generator(step.name) # type: ignore
11211124

11221125
self._cache()
11231126

@@ -1225,6 +1228,19 @@ def _request_initial_batches(self) -> None:
12251228
)
12261229
self._send_batch_to_step(batch)
12271230

1231+
def _request_batch_from_generator(self, step_name: str) -> None:
1232+
"""Request a new batch to a `GeneratorStep`.
1233+
1234+
Args:
1235+
step_name: the name of the `GeneratorStep` to which a batch has to be requested.
1236+
"""
1237+
# Get the last batch that the previous step sent to generate the next batch
1238+
# (next `seq_no`).
1239+
last_batch = self._batch_manager.get_last_batch_sent(step_name) # type: ignore
1240+
if last_batch is None:
1241+
return
1242+
self._send_batch_to_step(last_batch.next_batch())
1243+
12281244
def _request_more_batches_if_needed(self, step: "Step") -> None:
12291245
"""Request more batches to the predecessors steps of `step` if needed.
12301246
@@ -1239,17 +1255,7 @@ def _request_more_batches_if_needed(self, step: "Step") -> None:
12391255
if previous_step_name not in self.dag.root_steps:
12401256
continue
12411257

1242-
# Get the last batch that the previous step sent to generate the next batch
1243-
# (next `seq_no`).
1244-
last_batch = self._batch_manager.get_last_batch_sent(previous_step_name) # type: ignore
1245-
if last_batch is None:
1246-
continue
1247-
1248-
self._logger.debug(
1249-
f"Step '{step.name}' input buffer for step '{previous_step_name}' is"
1250-
" empty. Requesting new batch..."
1251-
)
1252-
self._send_batch_to_step(last_batch.next_batch())
1258+
self._request_batch_from_generator(previous_step_name)
12531259

12541260
def _handle_batch_on_stop(self, batch: "_Batch") -> None:
12551261
"""Handles a batch that was received from the output queue when the pipeline was

0 commit comments

Comments
 (0)