 import re
 import sys
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Dict, Tuple, Union

 import torch
-from datasets import Dataset, IterableDataset, load_dataset
+from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, load_dataset
 from torch.nn import functional as F
 from transformers import EvalPrediction, HfArgumentParser, TrainingArguments
 from transformers.trainer_utils import get_last_checkpoint
@@ -337,3 +337,53 @@ def get_meds_extension_path(data_folder: str, dataset_prepared_path: str):
     basename = os.path.basename(data_folder)
     meds_extension_path = os.path.join(dataset_prepared_path, f"{basename}_meds_extension")
     return meds_extension_path
+
+
+def convert_dataset_to_iterable_dataset(
+    dataset: Union[Dataset, DatasetDict, IterableDatasetDict], num_shards: int = 1
+) -> Union[IterableDataset, Dict[str, IterableDataset]]:
+    """
+    Converts a Hugging Face `Dataset` or `DatasetDict` into an `IterableDataset` or
+    a dictionary of `IterableDataset` objects, enabling efficient parallel processing
+    with multiple workers in a data loader.
+
+    Parameters
+    ----------
+    dataset : Union[Dataset, DatasetDict, IterableDatasetDict]
+        The input dataset, which can be either:
+        - A single `Dataset` object
+        - A `DatasetDict` (containing multiple splits, such as train, validation, and test)
+
+    num_shards : int
+        The number of shards to split the dataset into, so it can be distributed
+        across multiple data-loader workers for parallel loading.
+
+    Returns
+    -------
+    Union[IterableDataset, Dict[str, IterableDataset]]
+        The converted dataset, either as:
+        - A single `IterableDataset` if the input was a `Dataset`
+        - A dictionary of `IterableDataset` objects if the input was a `DatasetDict` or `IterableDatasetDict`
+
+    Notes
+    -----
+    - If the input `dataset` is a `DatasetDict` (or `IterableDatasetDict`), each split
+      (e.g., train, validation, test) is converted into an `IterableDataset`.
+    - If the input `dataset` is a single `Dataset`, it is converted directly.
+    - The `num_shards` argument is forwarded to `Dataset.to_iterable_dataset`, which
+      shards the dataset so that multiple workers can load it in parallel.
+
+    Example
+    -------
+    # Convert a standard Dataset to an IterableDataset for parallel processing
+    iterable_dataset = convert_dataset_to_iterable_dataset(my_dataset, num_shards=4)
+
+    # Convert a DatasetDict (e.g., train, validation splits) into IterableDataset objects
+    iterable_dataset_dict = convert_dataset_to_iterable_dataset(my_dataset_dict, num_shards=4)
+    """
+    if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
+        # Convert each split; splits that are already iterable are left as-is,
+        # since IterableDataset has no to_iterable_dataset method.
+        dataset = {
+            k: v.to_iterable_dataset(num_shards=num_shards) if isinstance(v, Dataset) else v
+            for k, v in dataset.items()
+        }
+    else:
+        dataset = dataset.to_iterable_dataset(num_shards=num_shards)
+    return dataset