
Commit 978d8b7

Author: Yada Pruksachatkun
Merge of parents 0e622dd and b43fba6

File tree: 6 files changed (+185, -77 lines)


config/superglue-bert.conf

Lines changed: 3 additions & 1 deletion
```diff
@@ -8,7 +8,6 @@ exp_name = "bert-large-cased"
 max_seq_len = 256 // Mainly needed for MultiRC, to avoid over-truncating
 // But not 512 as that is really hard to fit in memory.
 tokenizer = "bert-large-cased"
-
 // Model settings
 bert_model_name = "bert-large-cased"
 bert_embeddings_mode = "top"
@@ -42,3 +41,6 @@ do_full_eval = 1
 write_preds = "val,test"
 write_strict_glue_format = 1
 
+// For WSC
+classifier_loss_fn = "softmax"
+classifier_span_pooling = "attn"
```
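
The two options added here select a softmax-style loss and attention-based span pooling for the WSC span-classification head. A rough sketch of what those choices plausibly correspond to; `AttnSpanPooler` and `classification_loss` are invented names for illustration, not the repo's actual modules:

```python
# Rough sketch only: illustrates what classifier_span_pooling = "attn" and
# classifier_loss_fn = "softmax" plausibly select; not jiant's actual classes.
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnSpanPooler(nn.Module):
    """Attention-weighted pooling over the token vectors inside one span."""

    def __init__(self, d_model: int):
        super().__init__()
        self.scorer = nn.Linear(d_model, 1)

    def forward(self, span_token_reprs: torch.Tensor) -> torch.Tensor:
        # span_token_reprs: [span_len, d_model]
        weights = F.softmax(self.scorer(span_token_reprs), dim=0)  # [span_len, 1]
        return (weights * span_token_reprs).sum(dim=0)             # [d_model]


def classification_loss(logits, labels, loss_fn="softmax"):
    # "softmax" -> cross-entropy over the label set, as configured above;
    # a "sigmoid" setting would instead use binary cross-entropy.
    if loss_fn == "softmax":
        return F.cross_entropy(logits, labels)
    return F.binary_cross_entropy_with_logits(logits, labels.float())
```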

src/tasks/tasks.py

Lines changed: 29 additions & 37 deletions
```diff
@@ -27,7 +27,13 @@
 from ..allennlp_mods.correlation import Correlation
 from ..allennlp_mods.numeric_field import NumericField
 from ..utils import utils
-from ..utils.data_loaders import get_tag_list, load_diagnostic_tsv, load_tsv, process_sentence
+from ..utils.data_loaders import (
+    get_tag_list,
+    load_diagnostic_tsv,
+    load_span_data,
+    load_tsv,
+    process_sentence,
+)
 from ..utils.tokenizers import get_tokenizer
 from .registry import register_task  # global task registry
 
@@ -1995,24 +2001,13 @@ def load_data(self):
 
 class SpanClassificationTask(Task):
     """
-    Generic class for span tasks.
+    Generic class for span tasks.
     Acts as a classifier, but with multiple targets for each input text.
     Targets are of the form (span1, span2,..., span_n, label), where the spans are
     half-open token intervals [i, j).
     The number of spans is constant across examples.
     """
 
-    @property
-    def _tokenizer_suffix(self):
-        """"
-        Suffix to make sure we use the correct source files,
-        based on the given tokenizer.
-        """
-        if self.tokenizer_name:
-            return ".retokenized." + self.tokenizer_name
-        else:
-            return ""
-
     def tokenizer_is_supported(self, tokenizer_name):
         """ Check if the tokenizer is supported for this task. """
         # Assume all tokenizers supported; if retokenized data not found
@@ -2049,8 +2044,7 @@ def __init__(
         assert label_file is not None
         assert files_by_split is not None
         self._files_by_split = {
-            split: os.path.join(path, fname) + self._tokenizer_suffix
-            for split, fname in files_by_split.items()
+            split: os.path.join(path, fname) for split, fname in files_by_split.items()
         }
         self.num_spans = num_spans
         self.max_seq_len = max_seq_len
@@ -2089,15 +2083,6 @@ def _stream_records(self, filename):
            filename,
        )
 
-    def load_data(self):
-        iters_by_split = collections.OrderedDict()
-        for split, filename in self._files_by_split.items():
-            iter = list(self._stream_records(filename))
-            iters_by_split[split] = iter
-        self._iters_by_split = iters_by_split
-        self.all_labels = list(utils.load_lines(self.label_file))
-        self.n_classes = len(self.all_labels)
-
     def get_split_text(self, split: str):
         """
         Get split text as iterable of records.
@@ -2139,19 +2124,15 @@ def make_instance(self, record, idx, indexers) -> Type[Instance]:
 
         for i in range(self.num_spans):
             example["span" + str(i + 1) + "s"] = ListField(
-                [
-                    self._make_span_field(t["span" + str(i + 1)], text_field, 1)
-                    for t in record["targets"]
-                ]
+                [self._make_span_field(record["target"]["span" + str(i + 1)], text_field, 1)]
            )
-
-        labels = [utils.wrap_singleton_string(t["label"]) for t in record["targets"]]
        example["labels"] = ListField(
            [
                MultiLabelField(
-                    label_set, label_namespace=self._label_namespace, skip_indexing=False
+                    [str(record["label"])],
+                    label_namespace=self._label_namespace,
+                    skip_indexing=False,
                )
-                for label_set in labels
            ]
        )
        return Instance(example)
@@ -2533,17 +2514,28 @@ def get_metrics(self, reset=False):
 @register_task("winograd-coreference", rel_path="winograd-coref")
 class WinogradCoreferenceTask(SpanClassificationTask):
     def __init__(self, path, **kw):
-        self._files_by_split = {
-            "train": "train.jsonl",
-            "val": "val.jsonl",
-            "test": "test_with_labels.jsonl",
-        }
+        self._files_by_split = {"train": "train.jsonl", "val": "val.jsonl", "test": "test.jsonl"}
         self.num_spans = 2
         super().__init__(
             files_by_split=self._files_by_split, label_file="labels.txt", path=path, **kw
         )
+        self.n_classes = 2
         self.val_metric = "%s_acc" % self.name
 
+    def load_data(self):
+        iters_by_split = collections.OrderedDict()
+        for split, filename in self._files_by_split.items():
+            if filename.endswith("test.jsonl"):
+                iters_by_split[split] = load_span_data(
+                    self.tokenizer_name, filename, has_labels=False
+                )
+            else:
+                iters_by_split[split] = load_span_data(self.tokenizer_name, filename)
+        self._iters_by_split = iters_by_split
+
+    def get_all_labels(self):
+        return ["True", "False"]
+
     def update_metrics(self, logits, labels, tagmask=None):
        logits, labels = logits.detach(), labels.detach()
```

src/trainer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1080,7 +1080,7 @@ def _save_checkpoint(self, training_state, phase="pretrain", new_best_macro=Fals
             training_state,
             os.path.join(
                 self._serialization_dir,
-                "pretraining_state_{}_epoch_{}{}.th".format(phase, epoch, best_str),
+                "metric_state_{}_epoch_{}{}.th".format(phase, epoch, best_str),
             ),
         )
```
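
The only change here is the checkpoint filename prefix, so a saved training-state file now looks roughly like the path below; the directory, epoch, and `best_str` values are placeholders, not taken from the repo.

```python
# Placeholder values, just to show the resulting filename pattern.
import os

serialization_dir = "/tmp/jiant_run"  # placeholder
phase, epoch, best_str = "pretrain", 3, ".best_macro"
print(os.path.join(
    serialization_dir,
    "metric_state_{}_epoch_{}{}.th".format(phase, epoch, best_str),
))
# -> /tmp/jiant_run/metric_state_pretrain_epoch_3.best_macro.th
```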

src/utils/data_loaders.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -11,11 +11,39 @@
 from allennlp.data import vocabulary
 
 from .tokenizers import get_tokenizer
+from .retokenize import realign_spans
 
 BERT_CLS_TOK, BERT_SEP_TOK = "[CLS]", "[SEP]"
 SOS_TOK, EOS_TOK = "<SOS>", "<EOS>"
 
 
+def load_span_data(tokenizer_name, file_name, label_fn=None, has_labels=True):
+    """
+    Loads a span-related task file in .jsonl format, re-aligns the spans, and tokenizes the text.
+    Re-alignment of spans involves transforming the spans so that they match the text after
+    tokenization.
+    For example, given the original text [Mr., Porter, is, nice] and bert-base-cased tokenization,
+    we get [Mr, ., Por, ter, is, nice]. If the original span indices were [0, 2], under the new
+    tokenization they become [0, 3].
+    The task file should be of the following form:
+        text: str,
+        label: bool
+        target: dict that contains the spans
+    Args:
+        tokenizer_name: str,
+        file_name: str,
+        label_fn: function that expects a row and outputs a transformed row with labels transformed.
+    Returns:
+        List of dictionaries of the aligned spans and tokenized text.
+    """
+    rows = pd.read_json(file_name, lines=True)
+    # realign spans
+    rows = rows.apply(lambda x: realign_spans(x, tokenizer_name), axis=1)
+    if has_labels is False:
+        rows["label"] = False
+    return list(rows.T.to_dict().values())
+
+
 def load_tsv(
     tokenizer_name,
     data_file,
```
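
A minimal sketch of the `.jsonl` layout `load_span_data` expects and of a call to it; the file path and example content are made up, and the commented call assumes the jiant package is importable.

```python
# Made-up single-line .jsonl input in the shape load_span_data expects:
# "text", "label", and a "target" dict with span indices and span text.
import json

example = {
    "text": "Mr. Porter is nice and he smiles a lot .",
    "label": True,
    "target": {
        "span1_index": 0, "span1_text": "Mr. Porter",
        "span2_index": 5, "span2_text": "he",
    },
}
with open("/tmp/wsc_sample.jsonl", "w") as f:
    f.write(json.dumps(example) + "\n")

# Assuming jiant is on the path:
# from src.utils.data_loaders import load_span_data
# records = load_span_data("bert-base-cased", "/tmp/wsc_sample.jsonl")
# Each returned dict has the tokenized "text", the original "label", and
# "target" extended with half-open token spans "span1" and "span2".
```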

src/utils/retokenize.py

Lines changed: 73 additions & 0 deletions
```diff
@@ -94,6 +94,79 @@ def _mat_from_spans_sparse(spans: Sequence[Tuple[int, int]], n_chars: int) -> Ma
     return sparse.csr_matrix((data, (ridxs, cidxs)), shape=(len(spans), n_chars))
 
 
+def realign_spans(record, tokenizer_name):
+    """
+    Builds the indices alignment while also tokenizing the input
+    piece by piece.
+    Only BERT and Moses tokenization are supported currently.
+
+    Parameters
+    -----------------------
+    record: dict with the below fields
+        text: str
+        target: dict with the below fields
+            label: bool
+            span1_index: int, start index of first span
+            span1_text: str, text of first span
+            span2_index: int, start index of second span
+            span2_text: str, text of second span
+    tokenizer_name: str
+
+    Returns
+    ------------------------
+    record: dict with the below fields:
+        text: str in tokenized form
+        target: dictionary with the below fields
+            -label: bool
+            -span1: (int, int) of token indices
+            -span1_text: str, the string
+            -span2: (int, int) of token indices
+            -span2_text: str, the string
+    """
+
+    # find span indices and text
+    text = record["text"].split()
+    span1 = record["target"]["span1_index"]
+    span1_text = record["target"]["span1_text"]
+    span2 = record["target"]["span2_index"]
+    span2_text = record["target"]["span2_text"]
+
+    # construct end spans given span text space-tokenized length
+    span1 = [span1, span1 + len(span1_text.strip().split())]
+    span2 = [span2, span2 + len(span2_text.strip().split())]
+    indices = [span1, span2]
+
+    sorted_indices = sorted(indices, key=lambda x: x[0])
+    current_tokenization = []
+    span_mapping = {}
+
+    # align first span to tokenized text
+    aligner_fn = get_aligner_fn(tokenizer_name)
+    _, new_tokens = aligner_fn(" ".join(text[: sorted_indices[0][0]]))
+    current_tokenization.extend(new_tokens)
+    new_span1start = len(current_tokenization)
+    _, span_tokens = aligner_fn(" ".join(text[sorted_indices[0][0] : sorted_indices[0][1]]))
+    current_tokenization.extend(span_tokens)
+    new_span1end = len(current_tokenization)
+    span_mapping[sorted_indices[0][0]] = [new_span1start, new_span1end]
+
+    # re-indexing second span
+    _, new_tokens = aligner_fn(" ".join(text[sorted_indices[0][1] : sorted_indices[1][0]]))
+    current_tokenization.extend(new_tokens)
+    new_span2start = len(current_tokenization)
+    _, span_tokens = aligner_fn(" ".join(text[sorted_indices[1][0] : sorted_indices[1][1]]))
+    current_tokenization.extend(span_tokens)
+    new_span2end = len(current_tokenization)
+    span_mapping[sorted_indices[1][0]] = [new_span2start, new_span2end]
+
+    # save back into record
+    _, all_text = aligner_fn(" ".join(text))
+    record["target"]["span1"] = span_mapping[record["target"]["span1_index"]]
+    record["target"]["span2"] = span_mapping[record["target"]["span2_index"]]
+    record["text"] = " ".join(all_text)
+    return record
+
+
 class TokenAligner(object):
     """Align two similiar tokenizations.
```
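
To make the span arithmetic above concrete, here is a small self-contained walk-through of what `realign_spans` does for one span, with a toy word-piece splitter standing in for the aligner returned by `get_aligner_fn`; the splitter and the sentence are invented for illustration.

```python
# Toy stand-in for the aligner: pretend only "Porter" is split into word pieces.
def toy_pieces(text):
    out = []
    for tok in text.split():
        out.extend(["Por", "ter"] if tok == "Porter" else [tok])
    return out


record = {"text": "Bob Porter is nice", "target": {"span1_index": 1, "span1_text": "Porter"}}

words = record["text"].split()
start = record["target"]["span1_index"]
end = start + len(record["target"]["span1_text"].split())  # half-open [1, 2) over words

prefix = toy_pieces(" ".join(words[:start]))        # tokens before the span -> ["Bob"]
span_toks = toy_pieces(" ".join(words[start:end]))  # the span itself -> ["Por", "ter"]

record["target"]["span1"] = [len(prefix), len(prefix) + len(span_toks)]
record["text"] = " ".join(toy_pieces(record["text"]))

print(record["target"]["span1"])  # [1, 3]
print(record["text"])             # Bob Por ter is nice
```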
