
Commit c36b74e

HaokunLiu authored and sleepinyourhat committed

Fixing index problem & minor pytorch_transformers_interface cleanup (#916)

* update boundry func with offsets
* update tasks that use indexes
* remove outdated temporary fix

1 parent 10fb192 · commit c36b74e

File tree: 3 files changed, +107 −58 lines changed


jiant/preprocess.py

Lines changed: 0 additions & 7 deletions
@@ -619,13 +619,6 @@ def add_pytorch_transformers_vocab(vocab, tokenizer_name):

     vocab_size = len(tokenizer)
     # do not use tokenizer.vocab_size, it does not include newly added token
-    if tokenizer_name.startswith("roberta-"):
-        if tokenizer.convert_ids_to_tokens(vocab_size - 1) is None:
-            vocab_size -= 1
-        else:
-            log.info("Time to delete vocab_size-1 in preprocess.py !!!")
-    # due to a quirk in huggingface's file, the last token of RobertaTokenizer is None, remove
-    # this when they fix the problem

     ordered_vocab = tokenizer.convert_ids_to_tokens(range(vocab_size))
     log.info("Added pytorch_transformers vocab (%s): %d tokens", tokenizer_name, len(ordered_vocab))
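
The comment kept above explains the `len(tokenizer)` choice: in pytorch_transformers, `tokenizer.vocab_size` reports only the base vocabulary, while `len(tokenizer)` also counts tokens added afterwards. A minimal sketch of that behavior, assuming pytorch_transformers is installed (the model name and added token are placeholders, not part of the commit):

```python
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["<extra_token>"])  # hypothetical newly added token

print(tokenizer.vocab_size)  # base vocabulary only
print(len(tokenizer))        # base vocabulary + newly added tokens
# preprocess.py iterates over range(len(tokenizer)), so the vocab copied
# into AllenNLP includes any tokens added on top of the pretrained one.
```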

jiant/pytorch_transformers_interface/modules.py

Lines changed: 90 additions & 41 deletions (several paired removals/additions below differ only by trailing whitespace)
@@ -88,11 +88,11 @@ def parameter_setup(self, args):
     def correct_sent_indexing(self, sent):
         """ Correct id difference between pytorch_transformers and AllenNLP.
         The AllenNLP indexer adds'@@UNKNOWN@@' token as index 1, and '@@PADDING@@' as index 0
-
+
         args:
-            sent: batch dictionary, in which
+            sent: batch dictionary, in which
                 sent[self.tokenizer_required]: <long> [batch_size, var_seq_len] input token IDs
-
+
         returns:
             ids: <long> [bath_size, var_seq_len] corrected token IDs
             input_mask: <long> [bath_size, var_seq_len] mask of input sequence
@@ -185,42 +185,47 @@ def get_seg_ids(self, token_ids, input_mask):
         return seg_ids

     @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
         """
         A function that appliese the appropriate EOS/SOS/SEP/CLS tokens to token sequence or
-        token sequence pair for most tasks.
+        token sequence pair for most tasks.
         This function should be implmented in subclasses.
-
+
         args:
             s1: list[str], tokens from sentence 1
             s2: list[str] (optional), tokens from sentence 2, used for pair embedding
-
+            get_offset: bool, returns offset if True
+
         returns
             s: list[str], token sequence with boundry tokens
+            offset_s1 (optional): int, index offset of s1
+            offset_s2 (optional): int, index offset of s2
         """
         raise NotImplementedError

     @staticmethod
-    def apply_lm_boundary_tokens(s1):
+    def apply_lm_boundary_tokens(s1, get_offset=False):
         """
         A function that appliese the appropriate EOS/SOS/SEP/CLS tokens to a token sequence for
         language modeling tasks.
         This function should be implmented in subclasses.
-
+
         args:
             s1: list[str], tokens from sentence
-
+            get_offset: bool, returns offset if True
+
         returns
             s: list[str], token sequence with boundry tokens
+            offset_s1 (optional): int, index offset of s1
         """
         raise NotImplementedError

     def forward(self, sent, task_name):
-        """ Run pytorch_transformers model and return output representation
+        """ Run pytorch_transformers model and return output representation
         This function should be implmented in subclasses.
-
+
         args:
-            sent: batch dictionary, in which
+            sent: batch dictionary, in which
                 sent[self.tokenizer_required]: <long> [batch_size, var_seq_len] input token IDs
             task_name: task_name string, this can used to implement different mixing scalars for
                 differnt tasks. See the TODO in parameter_setup for more details.
@@ -235,7 +240,7 @@ def get_pretrained_lm_head(self):
         weight to the input token embedding. In most cases, this module needs to work with
         output_mode as "top" or "none"
         This function should be implmented in subclasses.
-
+
         returns:
             lm_head: module [*, hidden_size] -> [*, vocab_size]
         """
@@ -265,12 +270,17 @@ def __init__(self, args):
         self.parameter_setup(args)

     @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
         # BERT-style boundary token padding on string token sequences
         if s2:
-            return ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"]
+            s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"]
+            if get_offset:
+                return s, 1, len(s1) + 2
         else:
-            return ["[CLS]"] + s1 + ["[SEP]"]
+            s = ["[CLS]"] + s1 + ["[SEP]"]
+            if get_offset:
+                return s, 1
+        return s

     def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
         ids, input_mask = self.correct_sent_indexing(sent)
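
As a quick illustration of the new `get_offset` contract (not part of the diff; `BertEmbedderModule` and the import path are assumed from the file layout), the BERT variant now reports where s1 and s2 start inside the wrapped sequence:

```python
from jiant.pytorch_transformers_interface.modules import BertEmbedderModule  # assumed class name

# Single sentence: s1 starts right after [CLS], so its offset is 1.
BertEmbedderModule.apply_boundary_tokens(["a", "b"], get_offset=True)
# -> (["[CLS]", "a", "b", "[SEP]"], 1)

# Sentence pair: s2 starts after [CLS] + s1 + [SEP], i.e. at len(s1) + 2.
BertEmbedderModule.apply_boundary_tokens(["a", "b"], ["c"], get_offset=True)
# -> (["[CLS]", "a", "b", "[SEP]", "c", "[SEP]"], 1, 4)
```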
@@ -317,12 +327,17 @@ def __init__(self, args):
         self.parameter_setup(args)

     @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
         # RoBERTa-style boundary token padding on string token sequences
         if s2:
-            return ["<s>"] + s1 + ["</s>", "</s>"] + s2 + ["</s>"]
+            s = ["<s>"] + s1 + ["</s>", "</s>"] + s2 + ["</s>"]
+            if get_offset:
+                return s, 1, len(s1) + 3
         else:
-            return ["<s>"] + s1 + ["</s>"]
+            s = ["<s>"] + s1 + ["</s>"]
+            if get_offset:
+                return s, 1
+        return s

     def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
         ids, input_mask = self.correct_sent_indexing(sent)
@@ -372,12 +387,17 @@ def __init__(self, args):
         self._SEG_ID_SEP = 3

     @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
         # XLNet-style boundary token marking on string token sequences
         if s2:
-            return s1 + ["<sep>"] + s2 + ["<sep>", "<cls>"]
+            s = s1 + ["<sep>"] + s2 + ["<sep>", "<cls>"]
+            if get_offset:
+                return s, 0, len(s1) + 1
         else:
-            return s1 + ["<sep>", "<cls>"]
+            s = s1 + ["<sep>", "<cls>"]
+            if get_offset:
+                return s, 0
+        return s

     def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
         ids, input_mask = self.correct_sent_indexing(sent)
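
This is the case a fixed offset gets wrong: XLNet appends its special tokens rather than prepending them, so s1 starts at offset 0 (illustration only; `XLNetEmbedderModule` is the assumed class name):

```python
XLNetEmbedderModule.apply_boundary_tokens(["a", "b"], get_offset=True)
# -> (["a", "b", "<sep>", "<cls>"], 0)

XLNetEmbedderModule.apply_boundary_tokens(["a", "b"], ["c"], get_offset=True)
# -> (["a", "b", "<sep>", "c", "<sep>", "<cls>"], 0, 3)
```

Any caller that shifted span indices by a hard-coded 1 was therefore off by one for XLNet-style inputs, which is what the tasks.py changes below correct.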
@@ -425,17 +445,25 @@ def __init__(self, args):
         self.parameter_setup(args)

     @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
         # OpenAI-GPT-style boundary token marking on string token sequences
         if s2:
-            return ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
+            s = ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
+            if get_offset:
+                return s, 1, len(s1) + 2
         else:
-            return ["<start>"] + s1 + ["<extract>"]
+            s = ["<start>"] + s1 + ["<extract>"]
+            if get_offset:
+                return s, 1
+        return s

     @staticmethod
-    def apply_lm_boundary_tokens(s1):
+    def apply_lm_boundary_tokens(s1, get_offset=False):
         # OpenAI-GPT-style boundary token marking on string token sequences for LM tasks
-        return ["\n</w>"] + s1 + ["\n</w>"]
+        s = ["\n</w>"] + s1 + ["\n</w>"]
+        if get_offset:
+            return s, 1
+        return s

     def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
         ids, input_mask = self.correct_sent_indexing(sent)
@@ -479,17 +507,25 @@ def __init__(self, args):
         self.parameter_setup(args)

     @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
         # GPT-2-style boundary token marking on string token sequences
         if s2:
-            return ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
+            s = ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
+            if get_offset:
+                return s, 1, len(s1) + 2
         else:
-            return ["<start>"] + s1 + ["<extract>"]
+            s = ["<start>"] + s1 + ["<extract>"]
+            if get_offset:
+                return s, 1
+        return s

     @staticmethod
-    def apply_lm_boundary_tokens(s1):
+    def apply_lm_boundary_tokens(s1, get_offset=False):
         # GPT-2-style boundary token marking on string token sequences for LM tasks
-        return ["<|endoftext|>"] + s1 + ["<|endoftext|>"]
+        s = ["<|endoftext|>"] + s1 + ["<|endoftext|>"]
+        if get_offset:
+            return s, 1
+        return s

     def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
         ids, input_mask = self.correct_sent_indexing(sent)
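
The LM variants follow the same pattern, so language-modeling callers can also recover where the original tokens begin (illustration only; `GPT2EmbedderModule` is the assumed class name):

```python
GPT2EmbedderModule.apply_lm_boundary_tokens(["a", "b"], get_offset=True)
# -> (["<|endoftext|>", "a", "b", "<|endoftext|>"], 1)
```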
@@ -533,17 +569,25 @@ def __init__(self, args):
         self.parameter_setup(args)

     @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
         # TransformerXL-style boundary token marking on string token sequences
         if s2:
-            return ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
+            s = ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
+            if get_offset:
+                return s, 1, len(s1) + 2
         else:
-            return ["<start>"] + s1 + ["<extract>"]
+            s = ["<start>"] + s1 + ["<extract>"]
+            if get_offset:
+                return s, 1
+        return s

     @staticmethod
-    def apply_lm_boundary_tokens(s1):
+    def apply_lm_boundary_tokens(s1, get_offset=False):
         # TransformerXL-style boundary token marking on string token sequences for LM tasks
-        return ["<\n>"] + s1 + ["<\n>"]
+        s = ["<\n>"] + s1 + ["<\n>"]
+        if get_offset:
+            return s, 1
+        return s

     def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
         ids, input_mask = self.correct_sent_indexing(sent)
@@ -592,12 +636,17 @@ def __init__(self, args):
         self.parameter_setup(args)

     @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
         # XLM-style boundary token marking on string token sequences
         if s2:
-            return ["</s>"] + s1 + ["</s>"] + s2 + ["</s>"]
+            s = ["</s>"] + s1 + ["</s>"] + s2 + ["</s>"]
+            if get_offset:
+                return s, 1, len(s1) + 2
         else:
-            return ["</s>"] + s1 + ["</s>"]
+            s = ["</s>"] + s1 + ["</s>"]
+            if get_offset:
+                return s, 1, len(s1) + 1
+        return s

     def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
         ids, input_mask = self.correct_sent_indexing(sent)

jiant/tasks/tasks.py

Lines changed: 17 additions & 10 deletions
@@ -2449,7 +2449,7 @@ def _make_span_field(self, s, text_field, offset=1):
     def make_instance(self, record, idx, indexers, model_preprocessing_interface) -> Type[Instance]:
         """Convert a single record to an AllenNLP Instance."""
         tokens = record["text"].split()
-        tokens = model_preprocessing_interface.boundary_token_fn(tokens)
+        tokens, offset = model_preprocessing_interface.boundary_token_fn(tokens, get_offset=True)
         text_field = sentence_to_text_field(tokens, indexers)

         example = {}
@@ -2459,7 +2459,7 @@ def make_instance(self, record, idx, indexers, model_preprocessing_interface) ->

         for i in range(self.num_spans):
             example["span" + str(i + 1) + "s"] = ListField(
-                [self._make_span_field(record["target"]["span" + str(i + 1)], text_field, 1)]
+                [self._make_span_field(record["target"]["span" + str(i + 1)], text_field, offset)]
             )
         example["labels"] = LabelField(
             record["label"], label_namespace="labels", skip_indexing=True
@@ -2657,18 +2657,25 @@ def _make_instance(input1, input2, idxs1, idxs2, labels, idx):
         d["sent1_str"] = MetadataField(" ".join(input1))
         d["sent2_str"] = MetadataField(" ".join(input2))
         if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
-            inp = model_preprocessing_interface.boundary_token_fn(input1, input2)
+            inp, offset1, offset2 = model_preprocessing_interface.boundary_token_fn(
+                input1, input2, get_offset=True
+            )
             d["inputs"] = sentence_to_text_field(inp, indexers)
-            idxs2 = (idxs2[0] + len(input1), idxs2[1] + len(input1))
         else:
-            d["input1"] = sentence_to_text_field(
-                model_preprocessing_interface.boundary_token_fn(input1), indexers
+            inp1, offset1 = model_preprocessing_interface.boundary_token_fn(
+                input1, get_offset=True
             )
-            d["input2"] = sentence_to_text_field(
-                model_preprocessing_interface.boundary_token_fn(input2), indexers
+            inp2, offset2 = model_preprocessing_interface.boundary_token_fn(
+                input2, get_offset=True
             )
-        d["idx1"] = ListField([NumericField(i) for i in range(idxs1[0], idxs1[1])])
-        d["idx2"] = ListField([NumericField(i) for i in range(idxs2[0], idxs2[1])])
+            d["input1"] = sentence_to_text_field(inp1, indexers)
+            d["input2"] = sentence_to_text_field(inp2, indexers)
+        d["idx1"] = ListField(
+            [NumericField(i) for i in range(idxs1[0] + offset1, idxs1[1] + offset1)]
+        )
+        d["idx2"] = ListField(
+            [NumericField(i) for i in range(idxs2[0] + offset2, idxs2[1] + offset2)]
+        )
         d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True)
         d["idx"] = LabelField(idx, label_namespace="idxs_tags", skip_indexing=True)
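
This hunk is the index fix itself: the old code left idxs1 unshifted and shifted idxs2 only by len(input1), ignoring the boundary tokens inserted before and between the two sentences; the new code shifts both ranges by the offsets returned from boundary_token_fn. A rough worked example under BERT-style markers (values, class name, and import path are illustrative, not from the commit):

```python
from jiant.pytorch_transformers_interface.modules import BertEmbedderModule  # assumed class name

input1 = ["the", "dog", "barks"]
input2 = ["it", "is", "loud"]
idxs1 = (1, 2)  # span over input1: "dog"
idxs2 = (0, 1)  # span over input2: "it"

inp, offset1, offset2 = BertEmbedderModule.apply_boundary_tokens(
    input1, input2, get_offset=True
)
# inp     = ["[CLS]", "the", "dog", "barks", "[SEP]", "it", "is", "loud", "[SEP]"]
# offset1 = 1, offset2 = len(input1) + 2 = 5

print(inp[idxs1[0] + offset1])  # "dog"
print(inp[idxs2[0] + offset2])  # "it"
# The old indexing pointed at inp[1] = "the" and inp[0 + len(input1)] = inp[3]
# = "barks", i.e. both spans drifted once boundary tokens were added.
```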

0 commit comments