
Commit 1552d98

created a helper function to convert the existing huggingface dataset to an iterable dataset (#64)
1 parent fa6d0ae commit 1552d98

3 files changed (+66 / -18 lines)


src/cehrbert/runners/hf_cehrbert_finetune_runner.py

Lines changed: 7 additions & 8 deletions
@@ -27,6 +27,7 @@
 from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
 from cehrbert.runners.hf_runner_argument_dataclass import FineTuneModelType, ModelArguments
 from cehrbert.runners.runner_util import (
+    convert_dataset_to_iterable_dataset,
     generate_prepared_ds_path,
     get_last_hf_checkpoint,
     get_meds_extension_path,
@@ -99,7 +100,9 @@ def main():
         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         processed_dataset = load_from_disk(str(prepared_ds_path))
         if data_args.streaming:
-            processed_dataset = processed_dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
+            processed_dataset = convert_dataset_to_iterable_dataset(
+                processed_dataset, num_shards=training_args.dataloader_num_workers
+            )
         LOG.info("Prepared dataset loaded from disk...")
     else:
         # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
@@ -112,13 +115,9 @@ def main():
                 LOG.info(f"Trying to load the MEDS extension from disk at {meds_extension_path}...")
                 dataset = load_from_disk(meds_extension_path)
                 if data_args.streaming:
-                    if isinstance(dataset, DatasetDict):
-                        dataset = {
-                            k: v.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
-                            for k, v in dataset.items()
-                        }
-                    else:
-                        dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
+                    dataset = convert_dataset_to_iterable_dataset(
+                        dataset, num_shards=training_args.dataloader_num_workers
+                    )
             except Exception as e:
                 LOG.exception(e)
                 dataset = create_dataset_from_meds_reader(data_args, is_pretraining=False)
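Both runners now delegate the streaming conversion to the shared helper and pass num_shards=training_args.dataloader_num_workers. A minimal sketch of why the shard count is tied to the DataLoader worker count, using toy data with a made-up column name (not part of this commit):

    from datasets import Dataset

    # Toy stand-in for the prepared CEHR-BERT dataset; the column name is
    # illustrative only and does not come from this commit.
    toy = Dataset.from_dict({"concept_ids": [[1, 2], [3], [4, 5, 6], [7]]})

    # Matching num_shards to the number of DataLoader workers lets each worker
    # stream its own shard instead of every worker re-reading the full dataset.
    num_workers = 2
    iterable_ds = toy.to_iterable_dataset(num_shards=num_workers)

    print(iterable_ds.n_shards)      # 2
    print(next(iter(iterable_ds)))   # {'concept_ids': [1, 2]}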

src/cehrbert/runners/hf_cehrbert_pretrain_runner.py

Lines changed: 7 additions & 8 deletions
@@ -14,6 +14,7 @@
 from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, ModelArguments
 from cehrbert.runners.runner_util import (
+    convert_dataset_to_iterable_dataset,
     generate_prepared_ds_path,
     get_last_hf_checkpoint,
     get_meds_extension_path,
@@ -160,7 +161,9 @@ def main():
         LOG.info("Loading prepared dataset from disk at %s...", prepared_ds_path)
         processed_dataset = load_from_disk(str(prepared_ds_path))
         if data_args.streaming:
-            processed_dataset = processed_dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
+            processed_dataset = convert_dataset_to_iterable_dataset(
+                processed_dataset, num_shards=training_args.dataloader_num_workers
+            )
         LOG.info("Prepared dataset loaded from disk...")
         # If the data has been processed in the past, it's assume the tokenizer has been created
         # before. We load the CEHR-BERT tokenizer from the output folder.
@@ -179,13 +182,9 @@ def main():
                 )
                 dataset = load_from_disk(meds_extension_path)
                 if data_args.streaming:
-                    if isinstance(dataset, DatasetDict):
-                        dataset = {
-                            k: v.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
-                            for k, v in dataset.items()
-                        }
-                    else:
-                        dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
+                    dataset = convert_dataset_to_iterable_dataset(
+                        dataset, num_shards=training_args.dataloader_num_workers
+                    )
             except FileNotFoundError as e:
                 LOG.exception(e)
                 dataset = create_dataset_from_meds_reader(data_args, is_pretraining=True)

src/cehrbert/runners/runner_util.py

Lines changed: 52 additions & 2 deletions
@@ -4,10 +4,10 @@
 import re
 import sys
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Dict, Tuple, Union

 import torch
-from datasets import Dataset, IterableDataset, load_dataset
+from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, load_dataset
 from torch.nn import functional as F
 from transformers import EvalPrediction, HfArgumentParser, TrainingArguments
 from transformers.trainer_utils import get_last_checkpoint
@@ -337,3 +337,53 @@ def get_meds_extension_path(data_folder: str, dataset_prepared_path: str):
     basename = os.path.basename(data_folder)
     meds_extension_path = os.path.join(dataset_prepared_path, f"{basename}_meds_extension")
     return meds_extension_path
+
+
+def convert_dataset_to_iterable_dataset(
+    dataset: Union[Dataset, DatasetDict], num_shards: int = 1
+) -> Union[IterableDataset, Dict[str, IterableDataset]]:
+    """
+    Convert a Hugging Face `Dataset` or `DatasetDict` into an `IterableDataset`.
+
+    Returns a single `IterableDataset` or a dictionary of `IterableDataset` objects,
+    enabling efficient parallel processing using multiple workers in a data loader.
+
+    Parameters
+    ----------
+    dataset : Union[Dataset, DatasetDict]
+        The input dataset, which can be either:
+        - A single `Dataset` object
+        - A `DatasetDict` (containing multiple datasets, such as train, validation, and test splits)
+
+    num_shards : int
+        The number of workers (shards) to split the dataset into for parallel data loading.
+        This allows efficient sharding of the dataset across multiple workers.
+
+    Returns
+    -------
+    Union[IterableDataset, Dict[str, IterableDataset]]
+        The converted dataset, either as:
+        - A single `IterableDataset` if the input was a `Dataset`
+        - A dictionary of `IterableDataset` objects if the input was a `DatasetDict` or `IterableDatasetDict`
+
+    Notes
+    -----
+    - If the input `dataset` is a `DatasetDict` (or `IterableDatasetDict`), each dataset split
+      (e.g., train, validation, test) is converted into an `IterableDataset`.
+    - If the input `dataset` is a single `Dataset`, it is directly converted into an `IterableDataset`.
+    - The `num_shards` parameter in `to_iterable_dataset` allows splitting the dataset for parallel
+      data loading with multiple workers.
+
+    Example
+    -------
+    # Convert a standard Dataset to an IterableDataset for parallel processing
+    iterable_dataset = convert_dataset_to_iterable_dataset(my_dataset, num_shards=4)
+
+    # Convert a DatasetDict (e.g., train, validation splits) into IterableDataset objects
+    iterable_dataset_dict = convert_dataset_to_iterable_dataset(my_dataset_dict, num_shards=4)
+    """
+    if isinstance(dataset, DatasetDict) or isinstance(dataset, IterableDatasetDict):
+        dataset = {k: v.to_iterable_dataset(num_shards=num_shards) for k, v in dataset.items()}
+    else:
+        dataset = dataset.to_iterable_dataset(num_shards=num_shards)
+    return dataset
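A minimal usage sketch of the new helper, mirroring the docstring's Example section; the split names and the concept_ids column below are toy data assumed for illustration, not part of the commit:

    from datasets import Dataset, DatasetDict

    from cehrbert.runners.runner_util import convert_dataset_to_iterable_dataset

    # Toy splits standing in for a prepared dataset; column names are illustrative only.
    splits = DatasetDict(
        {
            "train": Dataset.from_dict({"concept_ids": [[1, 2], [3, 4], [5], [6, 7]]}),
            "validation": Dataset.from_dict({"concept_ids": [[8], [9, 10]]}),
        }
    )

    # A DatasetDict comes back as a plain dict of IterableDataset objects, one per split.
    iterable_splits = convert_dataset_to_iterable_dataset(splits, num_shards=2)
    print({name: type(ds).__name__ for name, ds in iterable_splits.items()})

    # A single Dataset comes back as a single IterableDataset.
    iterable_train = convert_dataset_to_iterable_dataset(splits["train"], num_shards=2)
    print(next(iter(iterable_train)))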
