File tree 1 file changed +11
-12
lines changed
1 file changed +11
-12
lines changed Original file line number Diff line number Diff line change 1
1
"""Module containing Dataset functionality"""
2
2
3
3
import logging
4
+ import os
4
5
from typing import List
5
6
6
7
import torch
7
8
from datasets import IterableDataset
8
9
9
- from .prompt_tokenizers import InvalidDataException , PromptTokenizingStrategy
10
+ from .prompt_tokenizers import PromptTokenizingStrategy
10
11
11
12
# We want this to be a wrapper for an existing dataset that we have loaded
12
13
# lets use the concept of middlewares to wrap each dataset, for example
@@ -34,17 +35,15 @@ def __init__( # pylint: disable=super-init-not-called
34
35
self .dataset = dataset
35
36
36
37
def __iter__ (self ):
37
- iterator = iter (self .dataset )
38
- count = 0
39
- # Loop through the entire dataset
40
- for example in iterator :
41
- try :
42
- yield self .prompt_tokenizer .tokenize_prompt (example )
43
- count += 1
44
- except InvalidDataException :
45
- pass
46
- if count == 0 :
47
- raise RuntimeError ("Expected at least one datapoint in dataset." )
38
+ features = self .dataset .features .keys ()
39
+ num_proc = os .cpu_count ()
40
+ return iter (
41
+ self .dataset .map (
42
+ self .prompt_tokenizer .tokenize_prompt ,
43
+ num_proc = num_proc ,
44
+ remove_columns = features ,
45
+ )
46
+ )
48
47
49
48
50
49
# TODO this isn't the best since it can't interleave datasets
You can’t perform that action at this time.
0 commit comments