Skip to content

Commit ca2973c

Browse files
tomvdwThe TensorFlow Datasets Authors
authored and
The TensorFlow Datasets Authors
committed
Restructure logic to minimize the number of file system accesses
This also introduces a method that uses a glob to find all version folders instead of listing everything in a dir and then doing is_dir on all of them. PiperOrigin-RevId: 697562330
1 parent 0ca4911 commit ca2973c

File tree

7 files changed

+271
-231
lines changed

7 files changed

+271
-231
lines changed

tensorflow_datasets/core/dataset_builders/conll/conll_dataset_builder_test.py

Lines changed: 46 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""Tests for conll_dataset_builder."""
1716
import textwrap
18-
from unittest import mock
1917

2018
from etils import epath
2119
import pytest
@@ -25,28 +23,22 @@
2523

2624
_FOLDER_PATH = "mock/path"
2725

28-
_VALID_INPUT = textwrap.dedent(
29-
"""
26+
_VALID_INPUT = textwrap.dedent("""
3027
-DOCSTART- -X- -X- O
3128
Winter NN B-NP O
3229
is VBZ B-VP O
3330
3431
Air NN I-NP O
3532
. . O O
36-
"""
37-
)
33+
""")
3834

39-
_INVALID_INPUT = textwrap.dedent(
40-
"""
35+
_INVALID_INPUT = textwrap.dedent("""
4136
Winter NN B-NP
4237
is VBZ B-VP O
4338
4439
Air NN I-NP O
4540
. . O O
46-
"""
47-
)
48-
49-
_INPUT_PATH = epath.Path(_FOLDER_PATH, "input_path.txt")
41+
""")
5042

5143

5244
class DummyConllDataset(conll_dataset_builder.ConllDatasetBuilder):
@@ -63,60 +55,56 @@ def _info(self) -> tfds.core.DatasetInfo:
6355
def _split_generators(self, dl_manager: tfds.download.DownloadManager):
6456
"""Returns SplitGenerators."""
6557
del dl_manager
66-
return {"train": self._generate_examples(_INPUT_PATH)}
67-
68-
69-
def test_generate_example():
70-
tf_mock = mock.Mock()
71-
tf_mock.gfile.GFile.return_value = _VALID_INPUT
72-
expected_examples = []
73-
74-
dataset = DummyConllDataset()
75-
76-
with tfds.testing.MockFs() as fs:
77-
fs.add_file(path=_INPUT_PATH, content=_VALID_INPUT)
78-
examples = list(dataset._generate_examples(_INPUT_PATH))
79-
80-
expected_examples = [
81-
(
82-
0,
83-
{
84-
"tokens": ["Winter", "is"],
85-
"pos": ["NN", "VBZ"],
86-
"chunks": ["B-NP", "B-VP"],
87-
"ner": ["O", "O"],
88-
},
89-
),
90-
(
91-
1,
92-
{
93-
"tokens": ["Air", "."],
94-
"pos": ["NN", "."],
95-
"chunks": ["I-NP", "O"],
96-
"ner": ["O", "O"],
97-
},
98-
),
99-
]
100-
101-
assert examples == expected_examples
102-
103-
for _, example in examples:
104-
assert len(example) == len(conll_lib.CONLL_2003_ORDERED_FEATURES)
58+
return {"train": self._generate_examples("/tmp/input.txt")}
59+
60+
61+
def test_generate_example(tmpdir):
62+
tmpdir = epath.Path(tmpdir)
63+
input_path = tmpdir / "input_path.txt"
64+
input_path.write_text(_VALID_INPUT)
65+
66+
dataset = DummyConllDataset(data_dir=tmpdir)
67+
examples = list(dataset._generate_examples(input_path))
68+
69+
expected_examples = [
70+
(
71+
0,
72+
{
73+
"tokens": ["Winter", "is"],
74+
"pos": ["NN", "VBZ"],
75+
"chunks": ["B-NP", "B-VP"],
76+
"ner": ["O", "O"],
77+
},
78+
),
79+
(
80+
1,
81+
{
82+
"tokens": ["Air", "."],
83+
"pos": ["NN", "."],
84+
"chunks": ["I-NP", "O"],
85+
"ner": ["O", "O"],
86+
},
87+
),
88+
]
89+
90+
assert examples == expected_examples
91+
92+
for _, example in examples:
93+
assert len(example) == len(conll_lib.CONLL_2003_ORDERED_FEATURES)
10594

10695
assert len(examples) == 2
10796

10897

109-
def test_generate_corrupted_example():
110-
tf_mock = mock.Mock()
111-
tf_mock.gfile.GFile.return_value = _VALID_INPUT
112-
dataset = DummyConllDataset()
98+
def test_generate_corrupted_example(tmpdir):
99+
tmpdir = epath.Path(tmpdir)
100+
input_path = tmpdir / "input_path.txt"
101+
input_path.write_text(_INVALID_INPUT)
102+
dataset = DummyConllDataset(data_dir=tmpdir)
113103

114104
error_line = "Winter NN B-NP"
115105
error_msg = (
116106
f"Mismatch in the number of features found in line: {error_line}\n\n"
117107
"Should be 4, but found 3"
118108
)
119109
with pytest.raises(ValueError, match=error_msg):
120-
with tfds.testing.MockFs() as fs:
121-
fs.add_file(path=_INPUT_PATH, content=_INVALID_INPUT)
122-
list(dataset._generate_examples(_INPUT_PATH))
110+
list(dataset._generate_examples(input_path))

0 commit comments

Comments
 (0)