 import functools
 import logging
 import os
+import tempfile
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
@@ -117,9 +118,27 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
             cfg.pretraining_dataset[0]["type"] or "pretrain",
         )
 
-        iter_ds = load_dataset(
-            path, streaming=True, split=split, name=name, data_files=data_files
-        )
+        # when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
+        # other ranks, we just need to present a fake dataset
+        if (
+            cfg.accelerator_config
+            and cfg.accelerator_config.dispatch_batches
+            and not is_local_main_process()
+        ):
+            with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
+                f.write("text\n")
+                f.write("lorem ipsum dolor sit amet\n")
+                # rewind the file pointer to the beginning so we can read it again
+                f.seek(0)
+                iter_ds = load_dataset(
+                    "csv", data_files=f.name, split="train", streaming=True
+                )
+        else:
+            if is_local_main_process():
+                iter_ds = load_dataset(
+                    path, streaming=True, split=split, name=name, data_files=data_files
+                )
+
         if skip:
             LOG.info(f"Skipping {skip} samples from the dataset")
             iter_ds = iter_ds.skip(skip)
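
The trick in the added branch can be exercised on its own: any streaming dataset with a plausible schema will do on the non-main ranks, because the dispatching rank is the only one that reads real data. A minimal standalone sketch, directly mirroring the diff (the names here are illustrative, not part of the commit):

import tempfile

from datasets import load_dataset

# build a one-row CSV with a single "text" column, as the commit does
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as f:
    f.write("text\n")                        # header row
    f.write("lorem ipsum dolor sit amet\n")  # one throwaway sample
    f.seek(0)  # flush and rewind so the data is on disk before reading
    fake_ds = load_dataset("csv", data_files=f.name, split="train", streaming=True)

print(next(iter(fake_ds)))  # {'text': 'lorem ipsum dolor sit amet'}

Note that delete=False is load-bearing: with streaming=True the CSV is read lazily at iteration time, so the temp file has to outlive the with block.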
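
For context on the condition being tested: in Hugging Face accelerate, dispatch_batches means only the main process iterates the real DataLoader, and the resulting batch slices are broadcast to the other ranks, which is why those ranks can present a fake dataset. A hedged sketch of how the flag is enabled, assuming a recent accelerate release with DataLoaderConfiguration (the wiring through cfg.accelerator_config is axolotl's own and is not shown here):

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# with dispatch_batches=True, rank 0 fetches each batch and broadcasts
# shards to the other processes when iterating the prepared DataLoader
accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(dispatch_batches=True)
)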