Commit 9f688c3

Merge pull request #33 from wellcometrust/update-changelog
Corrections post release
2 parents b301d12 + e4de5a4 commit 9f688c3

9 files changed: +140 −67 lines

CHANGELOG.md

Lines changed: 7 additions & 0 deletions

@@ -1,5 +1,12 @@
 # Changelog
 
+## 2020.4.23 - Pre-release
+
+* Add multitask split_parse command and tests, called with python -m deep_reference_parser split_parse
+* Fix issues with training data creation
+* Output predictions of validation data by default
+* Various improvements: use tox for testing, refactoring, better error messages, README, and tests
+
 ## 2020.3.3 - Pre-release
 
 NOTE: This version includes changes to both the way that model artefacts are packaged and saved, and the way that data are loaded and parsed from tsv files. This results in a significantly faster training time (c.14 hours -> c.0.5 hours), but older models will no longer be compatible. For compatibility you must use multitask models > 2020.3.19, splitting models > 2020.3.6, and parsing models > 2020.3.8. These models currently perform less well than previous versions, but performance is expected to improve with more data and experimentation, predominantly around sequence length.
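
For orientation, a minimal sketch of exercising the new command programmatically, mirroring the updated entrypoint tests at the end of this diff; the config path and example text are assumptions, not part of this commit:

```python
# CLI form, as given in the changelog entry above:
#   python -m deep_reference_parser split_parse
#
# Programmatic form; the config path below is an assumption.
from deep_reference_parser.split_parse import SplitParser

split_parser = SplitParser("deep_reference_parser/configs/2020.4.5_multitask.ini")

# verbose=True prints the found references via wasabi, as wired up in
# split_parse.py later in this diff.
out = split_parser.split_parse("1. Smith J. (2020). An example title.", verbose=True)
```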

Makefile

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ WORD_EMBEDDING := 2020.1.1-wellcome-embeddings-300
 WORD_EMBEDDING_TEST := 2020.1.1-wellcome-embeddings-10-test
 
 MODEL_PATH := models
-MODEL_VERSION := 2019.12.0
+MODEL_VERSION := multitask/2020.4.5_multitask
 
 #
 # S3 Bucket

README.md

Lines changed: 12 additions & 12 deletions

@@ -59,20 +59,20 @@ Current model version: *2020.3.8_parsing*
 
 #### Multitask model (splitting and parsing)
 
-Current model version: *2020.3.19_multitask*
+Current model version: *2020.4.5_multitask*
 
 |token|f1|
 |---|---|
-|author|0.9102|
-|title|0.8809|
-|year|0.7469|
-|o|0.8892|
-|parsing weighted avg|0.8869|
-|b-r|0.8254|
-|e-r|0.7908|
-|i-r|0.9563|
-|o|0.7560|
-|weighted avg|0.9240|
+|author|0.9458|
+|title|0.9002|
+|year|0.8704|
+|o|0.9407|
+|parsing weighted avg|0.9285|
+|b-r|0.9111|
+|e-r|0.8788|
+|i-r|0.9726|
+|o|0.9332|
+|weighted avg|0.9591|
 
 #### Computing requirements
 
@@ -82,7 +82,7 @@ Models are trained on AWS instances using CPU only.
 |---|---|---|---|---|
 |Span detection|00:26:41|m4.4xlarge|$0.88|$0.39|
 |Components|00:17:22|m4.4xlarge|$0.88|$0.25|
-|MultiTask|00:19:56|m4.4xlarge|$0.88|$0.29|
+|MultiTask|00:42:43|c4.4xlarge|$0.91|$0.63|
 
 ## tl;dr: Just get me to the references!

deep_reference_parser/configs/2020.4.5_multitask.ini

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ policy_valid = data/processed/annotated/deep_reference_parser/multitask/2020.3.1
 s3_slug = https://datalabs-public.s3.eu-west-2.amazonaws.com/deep_reference_parser/
 
 [build]
-output_path = data/models/multitask/2020.4.5_multitask/
+output_path = models/multitask/2020.4.5_multitask/
 output = crf
 word_embeddings = embeddings/2020.1.1-wellcome-embeddings-300.txt
 pretrained_embedding = 0

deep_reference_parser/split_parse.py

Lines changed: 17 additions & 21 deletions

@@ -25,7 +25,7 @@
 from deep_reference_parser.logger import logger
 from deep_reference_parser.model_utils import get_config
 from deep_reference_parser.reference_utils import break_into_chunks
-from deep_reference_parser.tokens_to_references import tokens_to_references
+from deep_reference_parser.tokens_to_references import tokens_to_reference_lists
 
 msg = wasabi.Printer(icons={"check": "\u2023"})
 
@@ -138,35 +138,31 @@ def split_parse(self, text, return_tokens=False, verbose=False):
 
         else:
 
-            # TODO: return references with attributes (author, title, year)
-            # in json format. For now just return predictions as they are to
-            # allow testing of endpoints.
+            # Return references with attributes (author, title, year)
+            # in json format. One list per reference, each containing the
+            # (token, attribute) predictions for every token in that reference:
+            # [[(token, attribute), ..., (token, attribute)], ..., [(token, attribute), ...]]
 
-            return preds
-
-            # # Otherwise convert the tokens into references and return
-
-            # refs = tokens_to_references(tokens, preds)
-
-            # if verbose:
+            references_components = tokens_to_reference_lists(tokens, spans=preds[1], components=preds[0])
+            if verbose:
 
-            #     msg.divider("Results")
+                msg.divider("Results")
 
-            #     if refs:
+                if references_components:
 
-            #         msg.good(f"Found {len(refs)} references.")
-            #         msg.info("Printing found references:")
+                    msg.good(f"Found {len(references_components)} references.")
+                    msg.info("Printing found references:")
 
-            #         for ref in refs:
-            #             msg.text(ref, icon="check", spaced=True)
+                    for ref in references_components:
+                        msg.text(ref['Reference'], icon="check", spaced=True)
 
-            #     else:
+                else:
 
-            #         msg.fail("Failed to find any references.")
+                    msg.fail("Failed to find any references.")
 
-            #     out = refs
+            out = references_components
 
-            #return out
+            return out
 
 
 @plac.annotations(
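
To make the two return shapes concrete, a hedged sketch based on the code above and the updated entrypoint tests below; the config path, input text, and the (token, span, component) tuple ordering are assumptions (the tests only check types and arity):

```python
from deep_reference_parser.split_parse import SplitParser

# Config path is an assumption; any multitask config should work.
split_parser = SplitParser("deep_reference_parser/configs/2020.4.5_multitask.ini")
text = "1. Smith J. (2020). An example title."

# return_tokens=True: a list of 3-tuples of strings, presumably
# (token, span_label, component_label).
tokens_out = split_parser.split_parse(text, return_tokens=True, verbose=False)
token, span_label, component_label = tokens_out[0]

# return_tokens=False: one dict per detected reference, as built by
# tokens_to_reference_lists (see tokens_to_references.py below), e.g.
# [{'Reference': 'Smith J. (2020). An example title.',
#   'Attributes': [('Smith', 'author'), ('J.', 'author'), ...]}]
refs_out = split_parser.split_parse(text, return_tokens=False, verbose=False)
```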

deep_reference_parser/tokens_to_references.py

Lines changed: 53 additions & 20 deletions

@@ -10,21 +10,12 @@
 from .deep_reference_parser import logger
 
 
-def tokens_to_references(tokens, labels):
-    """
-    Given a corresponding list of tokens and a list of labels: split the tokens
-    and return a list of references.
-
-    Args:
-        tokens(list): A list of tokens.
-        labels(list): A corresponding list of labels.
-
-    """
+def get_reference_spans(tokens, spans):
 
     # Flatten the lists of tokens and predictions into a single list.
 
     flat_tokens = list(itertools.chain.from_iterable(tokens))
-    flat_predictions = list(itertools.chain.from_iterable(labels))
+    flat_predictions = list(itertools.chain.from_iterable(spans))
 
     # Find all b-r and e-r tokens.
 
@@ -37,25 +28,67 @@ def tokens_to_references(tokens, labels):
     logger.debug("Found %s b-r tokens", len(ref_starts))
     logger.debug("Found %s e-r tokens", len(ref_ends))
 
-    references = []
-
     n_refs = len(ref_starts)
 
     # Split on each b-r.
-    # TODO: It may be worth including some simple post processing steps here
-    # to pick up false positives, for instance cutting short a reference
-    # after n tokens.
 
+    token_starts = []
+    token_ends = []
     for i in range(0, n_refs):
-        token_start = ref_starts[i]
+        token_starts.append(ref_starts[i])
         if i + 1 < n_refs:
-
-            token_end = ref_starts[i + 1] - 1
+            token_ends.append(ref_starts[i + 1] - 1)
         else:
-            token_end = len(flat_tokens)
+            token_ends.append(len(flat_tokens))
+
+    return token_starts, token_ends, flat_tokens
+
+
+def tokens_to_references(tokens, labels):
+    """
+    Given a corresponding list of tokens and a list of labels: split the tokens
+    and return a list of references.
 
+    Args:
+        tokens(list): A list of tokens.
+        labels(list): A corresponding list of labels.
+
+    """
+
+    token_starts, token_ends, flat_tokens = get_reference_spans(tokens, labels)
+
+    references = []
+    for token_start, token_end in zip(token_starts, token_ends):
         ref = flat_tokens[token_start : token_end + 1]
         flat_ref = " ".join(ref)
         references.append(flat_ref)
 
     return references
+
+
+def tokens_to_reference_lists(tokens, spans, components):
+    """
+    Given a corresponding list of tokens, a list of
+    reference spans (e.g. 'b-r') and components (e.g. 'author'):
+    split the tokens according to the spans and return a
+    list of reference components for each reference.
+
+    Args:
+        tokens(list): A list of tokens.
+        spans(list): A corresponding list of reference spans.
+        components(list): A corresponding list of reference components.
+
+    """
+
+    token_starts, token_ends, flat_tokens = get_reference_spans(tokens, spans)
+    flat_components = list(itertools.chain.from_iterable(components))

+    references_components = []
+    for token_start, token_end in zip(token_starts, token_ends):
+
+        ref_tokens = flat_tokens[token_start : token_end + 1]
+        ref_components = flat_components[token_start : token_end + 1]
+        flat_ref = " ".join(ref_tokens)
+
+        references_components.append({'Reference': flat_ref, 'Attributes': list(zip(ref_tokens, ref_components))})
+
+    return references_components

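A toy walkthrough of the new tokens_to_reference_lists helper; the tokens and labels below are fabricated for illustration, assuming the b-r/i-r/e-r span scheme used above:

```python
from deep_reference_parser.tokens_to_references import tokens_to_reference_lists

# One inner list per input line, as produced by the model's predictions.
tokens = [["Smith", "J", "2020", "A", "title"], ["Jones", "A", "2019"]]
spans = [["b-r", "i-r", "i-r", "i-r", "e-r"], ["b-r", "i-r", "e-r"]]
components = [["author", "author", "year", "title", "title"],
              ["author", "author", "year"]]

refs = tokens_to_reference_lists(tokens, spans, components)
# Splitting on each b-r yields two references:
# [{'Reference': 'Smith J 2020 A title',
#   'Attributes': [('Smith', 'author'), ('J', 'author'), ('2020', 'year'),
#                  ('A', 'title'), ('title', 'title')]},
#  {'Reference': 'Jones A 2019',
#   'Attributes': [('Jones', 'author'), ('A', 'author'), ('2019', 'year')]}]
```
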
tests/common.py

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@ def get_path(p):
 
 
 TEST_CFG = get_path("test_data/test_config.ini")
+TEST_CFG_MULTITASK = get_path("test_data/test_config_multitask.ini")
 TEST_JSONL = get_path("test_data/test_jsonl.jsonl")
 TEST_REFERENCES = get_path("test_data/test_references.txt")
 TEST_TSV_PREDICT = get_path("test_data/test_tsv_predict.tsv")
tests/test_data/test_config_multitask.ini

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+[DEFAULT]
+version = test
+
+[data]
+test_proportion = 0.25
+valid_proportion = 0.25
+data_path = data/
+respect_line_endings = 0
+respect_doc_endings = 1
+line_limit = 150
+rodrigues_train = data/rodrigues/clean_test.txt
+rodrigues_test =
+rodrigues_valid =
+policy_train = data/processed/annotated/deep_reference_parser/multitask/2020.3.18_multitask_test.tsv
+policy_test = data/processed/annotated/deep_reference_parser/multitask/2020.3.18_multitask_test.tsv
+policy_valid = data/processed/annotated/deep_reference_parser/multitask/2020.3.18_multitask_test.tsv
+# This needs to have a trailing slash!
+s3_slug = https://datalabs-public.s3.eu-west-2.amazonaws.com/deep_reference_parser/
+
+[build]
+output_path = models/multitask/2020.4.5_multitask/
+output = crf
+word_embeddings = embeddings/2020.1.1-wellcome-embeddings-300-test.txt
+pretrained_embedding = 0
+dropout = 0.5
+lstm_hidden = 400
+word_embedding_size = 300
+char_embedding_size = 100
+char_embedding_type = BILSTM
+optimizer = adam
+
+[train]
+epochs = 60
+batch_size = 100
+early_stopping_patience = 5
+metric = val_f1
+
+[evaluate]
+out_file = evaluation_data.tsv

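Since the test fixture reads this file by path, a minimal sketch of inspecting it with Python's standard configparser; the relative path is an assumption, and the package's own get_config helper (imported in split_parse.py above) presumably wraps something similar:

```python
import configparser

# Read the new multitask test config and pull out a few typed values.
cfg = configparser.ConfigParser()
cfg.read("tests/test_data/test_config_multitask.ini")

print(cfg["build"]["output_path"])        # models/multitask/2020.4.5_multitask/
print(cfg.getint("train", "epochs"))      # 60
print(cfg.getfloat("build", "dropout"))   # 0.5
```
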
tests/test_deep_reference_parser_entrypoints.py

Lines changed: 9 additions & 12 deletions

@@ -7,7 +7,7 @@
 from deep_reference_parser.split import Splitter
 from deep_reference_parser.split_parse import SplitParser
 
-from .common import TEST_CFG, TEST_REFERENCES
+from .common import TEST_CFG, TEST_CFG_MULTITASK, TEST_REFERENCES
 
 
 @pytest.fixture
@@ -22,7 +22,7 @@ def parser():
 
 @pytest.fixture
 def split_parser():
-    return SplitParser(TEST_CFG)
+    return SplitParser(TEST_CFG_MULTITASK)
 
 
 @pytest.fixture
@@ -67,7 +67,7 @@ def test_split_parser_list_output(text, split_parser):
     If the model artefacts and embeddings are not present this test will
     download them, which can be slow.
     """
-    out = split_parser.split_parse(text, verbose=False)
+    out = split_parser.split_parse(text, return_tokens=False, verbose=False)
     print(out)
 
     assert isinstance(out, list)
@@ -100,13 +100,10 @@ def test_parser_tokens_output(text, parser):
 def test_split_parser_tokens_output(text, split_parser):
     """
     """
-    out = split_parser.split_parse(text, verbose=False)
+    out = split_parser.split_parse(text, return_tokens=True, verbose=False)
 
-    assert isinstance(out, list)
-
-    # NOTE: full functionality of split_parse is not yet implemented.
-
-    # assert isinstance(out[0], tuple)
-    # assert len(out[0]) == 2
-    # assert isinstance(out[0][0], str)
-    # assert isinstance(out[0][1], str)
+    assert isinstance(out[0], tuple)
+    assert len(out[0]) == 3
+    assert isinstance(out[0][0], str)
+    assert isinstance(out[0][1], str)
+    assert isinstance(out[0][2], str)
