Commit d8ec4df

Merge pull request #17 from wellcometrust/feature/ivyleavedtoadflax/refactor_load_tsv

Refactor load_tsv to cover multitask case

2 parents 085f0fb + 87c098c
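For orientation, a minimal sketch of the call shape this refactor enables (the variable names and file path here are illustrative, not part of the commit): load_tsv now reads the whole TSV with pandas and returns one list per column, so a multitask file with a token column and two label columns unpacks as:

```python
from deep_reference_parser import load_tsv

# Illustrative path; any TSV with one token column and N label columns works.
# load_tsv returns a tuple holding one list of sequences per column.
tokens, task_1_labels, task_2_labels = load_tsv("data/multitask_train.tsv")
```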

5 files changed: +139 −132 lines changed

deep_reference_parser/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -24,7 +24,6 @@
 from .reference_utils import (
     break_into_chunks,
     labels_to_prodigy,
-    load_data,
     load_tsv,
     prodigy_to_conll,
     prodigy_to_lists,

deep_reference_parser/reference_utils.py

Lines changed: 42 additions & 129 deletions
@@ -8,161 +8,74 @@
 import json
 import os
 import pickle
+import pandas as pd

 import spacy

 from .logger import logger


-def load_data(filepath):
+def split_list_by_linebreaks(tokens):
+    """Cycle through a list of tokens (or labels) and split them into lists
+    based on the presence of Nones or more likely math.nan caused by converting
+    pd.DataFrame columns to lists.
     """
-    Load and return the data stored in the given path.
-
-    Adapted from: https://github.com/dhlab-epfl/LinkedBooksDeepReferenceParsing
-
-    The data is structured as follows:
-    * Each line contains four columns separated by a single space.
-    * Each word has been put on a separate line and there is an empty line
-      after each sentence.
-    * The first item on each line is a word, the second, third and fourth are
-      tags related to the word.
-
-    Example:
-
-    The sentence "L. Antonielli, Iprefetti dell' Italia napoleonica, Bologna
-    1983." is represented in the dataset as:
-
-    ```
-    L author b-secondary b-r
-    . author i-secondary i-r
-    Antonielli author i-secondary i-r
-    , author i-secondary i-r
-    Iprefetti title i-secondary i-r
-    dell title i-secondary i-r
-    ’ title i-secondary i-r
-    Italia title i-secondary i-r
-    napoleonica title i-secondary i-r
-    , title i-secondary i-r
-    Bologna publicationplace i-secondary i-r
-    1983 year e-secondary i-r
-    . year e-secondary e-r
-    ```
-
-    Args:
-        filepath (str): Path to the data.
-
-    Returns:
-        four lists: The first contains tokens, the next three contain
-        corresponding labels.
-
-    """
-
-    # Arrays to return
-    words = []
-    tags_1 = []
-    tags_2 = []
-    tags_3 = []
-
-    word = tags1 = tags2 = tags3 = []
-    with open(filepath, "r") as file:
-        for line in file:
-            # Do not take the first line into consideration
-
-            if "DOCSTART" not in line:
-                # Check if empty line
-
-                if line in ["\n", "\r\n"]:
-                    # Append line
-
-                    words.append(word)
-                    tags_1.append(tags1)
-                    tags_2.append(tags2)
-                    tags_3.append(tags3)
-
-                    # Reset
-                    word = []
-                    tags1 = []
-                    tags2 = []
-                    tags3 = []
-
-                else:
-                    # Split the line into words, tag #1
-                    w = line[:-1].split(" ")
-
-                    word.append(w[0])
-                    tags1.append(w[1])
-                    tags2.append(w[2])
-                    tags3.append(w[3])
-
-    logger.info("Loaded %s training examples", len(words))
-
-    return words, tags_1, tags_2, tags_3
-
+    out = []
+    tokens_gen = iter(tokens)
+    while True:
+        try:
+            token = next(tokens_gen)
+            if isinstance(token, str) and token:
+                out.append(token)
+            else:
+                yield out
+                out = []
+        except StopIteration:
+            if out:
+                yield out
+            break

 def load_tsv(filepath, split_char="\t"):
     """
     Load and return the data stored in the given path.

-    Adapted from: https://github.com/dhlab-epfl/LinkedBooksDeepReferenceParsing
+    Expects data in the following format (tab separations).
+
+    References o o
+    o o
+    1 o o
+    . o o
+    o o
+    WHO title b-r
+    treatment title i-r
+    guidelines title i-r
+    for title i-r
+    drug title i-r
+    - title i-r
+    resistant title i-r
+    tuberculosis title i-r
+    , title i-r
+    2016 title i-r

-    NOTE: In the current implementation in deep_reference_parser, only one set
-    of tags is used. The others will be used in a later PR.

-    The data is structured as follows:
-    * Each line contains four columns separated by a single space.
-    * Each word has been put on a separate line and there is an empty line
-      after each sentence.
-    * The first item on each line is a word, the second, third and fourth are
-      tags related to the word.

     Args:
         filepath (str): Path to the data.
         split_char(str): Character to be used to split each line of the
             document.

     Returns:
-        two lists: The first contains tokens, the second contains corresponding
-        labels.
+        a series of lists depending on the number of label columns provided in
+        filepath.

     """

-    # Arrays to return
-    words = []
-    tags_1 = []
-
-    word = []
-    tags1 = []
-
-    with open(filepath, "r") as file:
-        for line in file:
-            # Check if empty line
-
-            if line in ["\n", "\r\n", "\t\n"]:
-                # Append line
-
-                words.append(word)
-                tags_1.append(tags1)
-
-                # Reset
-                word = []
-                tags1 = []
-
-            else:
-
-                # Split the line into words, tag #1
-
-                w = line[:-1].split(split_char)
-                word.append(w[0])
-
-                # If tags are passed, (for training) then also add
-
-                if len(w) == 2:
-
-                    tags1.append(w[1])
+    df = pd.read_csv(filepath, delimiter=split_char, header=None, skip_blank_lines=False)
+    out = [list(split_list_by_linebreaks(column)) for _, column in df.iteritems()]

-    logger.info("Loaded %s training examples", len(words))
+    logger.info("Loaded %s training examples", len(out[0]))

-    return words, tags_1
+    return tuple(out)


 def prodigy_to_conll(docs):
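For reference, a minimal sketch of how the new helper behaves once a pd.DataFrame column has been turned into a list (the sample data is illustrative; blank TSV rows surface as NaN because load_tsv reads with skip_blank_lines=False):

```python
import math

from deep_reference_parser.reference_utils import split_list_by_linebreaks

# NaN marks where a blank line sat in the TSV; the generator splits on it.
tokens = ["WHO", "treatment", math.nan, "Bulletin", "de"]

print(list(split_list_by_linebreaks(tokens)))
# [['WHO', 'treatment'], ['Bulletin', 'de']]
```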

tests/common.py

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@ def get_path(p):
 TEST_REFERENCES = get_path("test_data/test_references.txt")
 TEST_TSV_PREDICT = get_path("test_data/test_tsv_predict.tsv")
 TEST_TSV_TRAIN = get_path("test_data/test_tsv_train.tsv")
+TEST_LOAD_TSV = get_path("test_data/test_load_tsv.tsv")

tests/test_data/test_load_tsv.tsv

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+the i-r a
+focus i-r a
+in i-r a
+Daloa i-r a
+, i-r a
+Côte i-r a
+d’Ivoire]. i-r a
+
+Bulletin i-r a
+de i-r a
+la i-r a
+Société i-r a
+de i-r a
+Pathologie i-r a
+
+Exotique i-r a
+et i-r a
+

tests/test_reference_utils.py

Lines changed: 78 additions & 2 deletions
@@ -12,9 +12,10 @@
     prodigy_to_conll,
     write_tsv,
     yield_token_label_pairs,
+    split_list_by_linebreaks,
 )

-from .common import TEST_TSV_PREDICT, TEST_TSV_TRAIN
+from .common import TEST_TSV_PREDICT, TEST_TSV_TRAIN, TEST_LOAD_TSV


 def test_prodigy_to_conll():

@@ -75,6 +76,14 @@ def test_load_tsv_train():

     actual = load_tsv(TEST_TSV_TRAIN)

+    assert len(actual[0][0]) == len(expected[0][0])
+    assert len(actual[0][1]) == len(expected[0][1])
+    assert len(actual[0][2]) == len(expected[0][2])
+
+    assert len(actual[1][0]) == len(expected[1][0])
+    assert len(actual[1][1]) == len(expected[1][1])
+    assert len(actual[1][2]) == len(expected[1][2])
+
     assert actual == expected


@@ -109,13 +118,59 @@ def test_load_tsv_predict():
             ["Bulletin", "de", "la", "Société", "de", "Pathologie"],
             ["Exotique", "et"],
         ],
-        [[], [], [],],
     )

     actual = load_tsv(TEST_TSV_PREDICT)

     assert actual == expected

+def test_load_tsv_train_multiple_labels():
+    """
+    Text of TEST_TSV_TRAIN:
+
+    ```
+    the i-r
+    focus i-r
+    in i-r
+    Daloa i-r
+    , i-r
+    Côte i-r
+    d’Ivoire]. i-r
+
+    Bulletin i-r
+    de i-r
+    la i-r
+    Société i-r
+    de i-r
+    Pathologie i-r
+
+    Exotique i-r
+    et i-r
+    ```
+    """
+
+    expected = (
+        [
+            ["the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."],
+            ["Bulletin", "de", "la", "Société", "de", "Pathologie"],
+            ["Exotique", "et"],
+        ],
+        [
+            ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
+            ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r"],
+            ["i-r", "i-r"],
+        ],
+        [
+            ["a", "a", "a", "a", "a", "a", "a"],
+            ["a", "a", "a", "a", "a", "a"],
+            ["a", "a"],
+        ],
+    )
+
+    actual = load_tsv(TEST_LOAD_TSV)
+
+    assert actual == expected
+

 def test_yield_toke_label_pairs():

@@ -197,3 +252,24 @@ def test_break_into_chunks():
     actual = break_into_chunks(before, max_words=2)

     assert expected == actual
+
+def test_split_list_by_linebreaks():
+
+    lst = ["a", "b", "c", None, "d"]
+    expected = [["a", "b", "c"], ["d"]]
+
+    actual = split_list_by_linebreaks(lst)
+
+def test_list_by_linebreaks_ending_in_None():
+
+    lst = ["a", "b", "c", float("nan"), "d", None]
+    expected = [["a", "b", "c"], ["d"]]
+
+    actual = split_list_by_linebreaks(lst)
+
+def test_list_by_linebreaks_starting_in_None():
+
+    lst = [None, "a", "b", "c", None, "d"]
+    expected = [["a", "b", "c"], ["d"]]
+
+    actual = split_list_by_linebreaks(lst)
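As a usage sketch grounded in test_load_tsv_train_multiple_labels (path relative to the tests directory; the printed values are taken from the expected tuple in that test):

```python
from deep_reference_parser import load_tsv

# The fixture has three columns, so load_tsv yields a 3-tuple:
# token sequences plus one sequence list per label column.
tokens, labels, extra_labels = load_tsv("test_data/test_load_tsv.tsv")

print(tokens[0])        # ['the', 'focus', 'in', 'Daloa', ',', 'Côte', 'd’Ivoire].']
print(labels[0])        # ['i-r', 'i-r', 'i-r', 'i-r', 'i-r', 'i-r', 'i-r']
print(extra_labels[0])  # ['a', 'a', 'a', 'a', 'a', 'a', 'a']
```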

0 commit comments