@@ -88,11 +88,11 @@ def parameter_setup(self, args):
    def correct_sent_indexing(self, sent):
        """ Correct id difference between pytorch_transformers and AllenNLP.
        The AllenNLP indexer adds '@@UNKNOWN@@' token as index 1, and '@@PADDING@@' as index 0
-
+
        args:
-            sent: batch dictionary, in which
+            sent: batch dictionary, in which
            sent[self.tokenizer_required]: <long> [batch_size, var_seq_len] input token IDs
-
+
        returns:
            ids: <long> [batch_size, var_seq_len] corrected token IDs
            input_mask: <long> [batch_size, var_seq_len] mask of input sequence
@@ -185,42 +185,47 @@ def get_seg_ids(self, token_ids, input_mask):
        return seg_ids

    @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
        """
        A function that applies the appropriate EOS/SOS/SEP/CLS tokens to a token sequence or
-        token sequence pair for most tasks.
+        token sequence pair for most tasks.
        This function should be implemented in subclasses.
-
+
        args:
            s1: list[str], tokens from sentence 1
            s2: list[str] (optional), tokens from sentence 2, used for pair embedding
-
+            get_offset: bool, returns offset if True
+
        returns
            s: list[str], token sequence with boundary tokens
+            offset_s1 (optional): int, index offset of s1
+            offset_s2 (optional): int, index offset of s2
        """
        raise NotImplementedError

    @staticmethod
-    def apply_lm_boundary_tokens(s1):
+    def apply_lm_boundary_tokens(s1, get_offset=False):
        """
        A function that applies the appropriate EOS/SOS/SEP/CLS tokens to a token sequence for
        language modeling tasks.
        This function should be implemented in subclasses.
-
+
        args:
            s1: list[str], tokens from sentence
-
+            get_offset: bool, returns offset if True
+
        returns
            s: list[str], token sequence with boundary tokens
+            offset_s1 (optional): int, index offset of s1
        """
        raise NotImplementedError

    def forward(self, sent, task_name):
-        """ Run pytorch_transformers model and return output representation
+        """ Run pytorch_transformers model and return output representation
        This function should be implemented in subclasses.
-
+
        args:
-            sent: batch dictionary, in which
+            sent: batch dictionary, in which
            sent[self.tokenizer_required]: <long> [batch_size, var_seq_len] input token IDs
            task_name: task_name string, this can be used to implement different mixing scalars for
                different tasks. See the TODO in parameter_setup for more details.
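As an aside, a minimal usage sketch of the new get_offset flag as documented above; the subclass name SomeEmbedderModule is a hypothetical stand-in, not part of this diff:

# Illustrative sketch only; SomeEmbedderModule is any concrete subclass.
# Without the flag, behavior is unchanged and a plain token list comes back:
s = SomeEmbedderModule.apply_boundary_tokens(s1, s2)
# With get_offset=True, the indices of the first s1/s2 tokens inside s are returned as well:
s, offset_s1, offset_s2 = SomeEmbedderModule.apply_boundary_tokens(s1, s2, get_offset=True)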
@@ -235,7 +240,7 @@ def get_pretrained_lm_head(self):
        weight to the input token embedding. In most cases, this module needs to work with
        output_mode as "top" or "none"
        This function should be implemented in subclasses.
-
+
        returns:
            lm_head: module [*, hidden_size] -> [*, vocab_size]
        """
@@ -265,12 +270,17 @@ def __init__(self, args):
        self.parameter_setup(args)

    @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
        # BERT-style boundary token padding on string token sequences
        if s2:
-            return ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"]
+            s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"]
+            if get_offset:
+                return s, 1, len(s1) + 2
        else:
-            return ["[CLS]"] + s1 + ["[SEP]"]
+            s = ["[CLS]"] + s1 + ["[SEP]"]
+            if get_offset:
+                return s, 1
+        return s

    def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
        ids, input_mask = self.correct_sent_indexing(sent)
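For illustration, a worked example of the BERT-style version above; the class name BertEmbedderModule is assumed from context, and the token lists are made up:

# Sketch only, assuming the BERT-style apply_boundary_tokens shown above:
tokens, offset_s1, offset_s2 = BertEmbedderModule.apply_boundary_tokens(
    ["he", "likes", "cats"], ["she", "does", "too"], get_offset=True
)
# tokens    == ["[CLS]", "he", "likes", "cats", "[SEP]", "she", "does", "too", "[SEP]"]
# offset_s1 == 1                 (first s1 token sits right after "[CLS]")
# offset_s2 == len(s1) + 2 == 5  (tokens[5] is "she")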
@@ -317,12 +327,17 @@ def __init__(self, args):
        self.parameter_setup(args)

    @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
        # RoBERTa-style boundary token padding on string token sequences
        if s2:
-            return ["<s>"] + s1 + ["</s>", "</s>"] + s2 + ["</s>"]
+            s = ["<s>"] + s1 + ["</s>", "</s>"] + s2 + ["</s>"]
+            if get_offset:
+                return s, 1, len(s1) + 3
        else:
-            return ["<s>"] + s1 + ["</s>"]
+            s = ["<s>"] + s1 + ["</s>"]
+            if get_offset:
+                return s, 1
+        return s

    def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
        ids, input_mask = self.correct_sent_indexing(sent)
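A quick sanity check of the RoBERTa offsets above (an illustrative trace, not part of the diff):

# s = ["<s>"] + s1 + ["</s>", "</s>"] + s2 + ["</s>"]
# s1 starts at index 1 (after "<s>"); s2 starts after "<s>", s1, and the two "</s>" separators,
# i.e. at index len(s1) + 3, which is what get_offset=True returns.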
@@ -372,12 +387,17 @@ def __init__(self, args):
        self._SEG_ID_SEP = 3

    @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
        # XLNet-style boundary token marking on string token sequences
        if s2:
-            return s1 + ["<sep>"] + s2 + ["<sep>", "<cls>"]
+            s = s1 + ["<sep>"] + s2 + ["<sep>", "<cls>"]
+            if get_offset:
+                return s, 0, len(s1) + 1
        else:
-            return s1 + ["<sep>", "<cls>"]
+            s = s1 + ["<sep>", "<cls>"]
+            if get_offset:
+                return s, 0
+        return s

    def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
        ids, input_mask = self.correct_sent_indexing(sent)
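The XLNet offsets differ because its special tokens trail the sequence; a short illustrative trace:

# s = s1 + ["<sep>"] + s2 + ["<sep>", "<cls>"]
# Special tokens go at the end, so s1 starts at index 0 and s2 at index len(s1) + 1.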
@@ -425,17 +445,25 @@ def __init__(self, args):
        self.parameter_setup(args)

    @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
        # OpenAI-GPT-style boundary token marking on string token sequences
        if s2:
-            return ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
+            s = ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
+            if get_offset:
+                return s, 1, len(s1) + 2
        else:
-            return ["<start>"] + s1 + ["<extract>"]
+            s = ["<start>"] + s1 + ["<extract>"]
+            if get_offset:
+                return s, 1
+        return s

    @staticmethod
-    def apply_lm_boundary_tokens(s1):
+    def apply_lm_boundary_tokens(s1, get_offset=False):
        # OpenAI-GPT-style boundary token marking on string token sequences for LM tasks
-        return ["\n</w>"] + s1 + ["\n</w>"]
+        s = ["\n</w>"] + s1 + ["\n</w>"]
+        if get_offset:
+            return s, 1
+        return s

    def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
        ids, input_mask = self.correct_sent_indexing(sent)
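For the LM variant, a hedged usage sketch based on the OpenAI-GPT code above; the class name OpenAIGPTEmbedderModule is assumed from context and the input tokens are made up:

# Sketch only, assuming the LM boundary-token method shown above:
s, offset_s1 = OpenAIGPTEmbedderModule.apply_lm_boundary_tokens(
    ["he", "likes", "cats"], get_offset=True
)
# s         == ["\n</w>", "he", "likes", "cats", "\n</w>"]
# offset_s1 == 1 (the sentence starts right after the leading "\n</w>" marker)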
@@ -479,17 +507,25 @@ def __init__(self, args):
        self.parameter_setup(args)

    @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
        # GPT-2-style boundary token marking on string token sequences
        if s2:
-            return ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
+            s = ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
+            if get_offset:
+                return s, 1, len(s1) + 2
        else:
-            return ["<start>"] + s1 + ["<extract>"]
+            s = ["<start>"] + s1 + ["<extract>"]
+            if get_offset:
+                return s, 1
+        return s

    @staticmethod
-    def apply_lm_boundary_tokens(s1):
+    def apply_lm_boundary_tokens(s1, get_offset=False):
        # GPT-2-style boundary token marking on string token sequences for LM tasks
-        return ["<|endoftext|>"] + s1 + ["<|endoftext|>"]
+        s = ["<|endoftext|>"] + s1 + ["<|endoftext|>"]
+        if get_offset:
+            return s, 1
+        return s

    def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
        ids, input_mask = self.correct_sent_indexing(sent)
@@ -533,17 +569,25 @@ def __init__(self, args):
        self.parameter_setup(args)

    @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
        # TransformerXL-style boundary token marking on string token sequences
        if s2:
-            return ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
+            s = ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
+            if get_offset:
+                return s, 1, len(s1) + 2
        else:
-            return ["<start>"] + s1 + ["<extract>"]
+            s = ["<start>"] + s1 + ["<extract>"]
+            if get_offset:
+                return s, 1
+        return s

    @staticmethod
-    def apply_lm_boundary_tokens(s1):
+    def apply_lm_boundary_tokens(s1, get_offset=False):
        # TransformerXL-style boundary token marking on string token sequences for LM tasks
-        return ["<\n>"] + s1 + ["<\n>"]
+        s = ["<\n>"] + s1 + ["<\n>"]
+        if get_offset:
+            return s, 1
+        return s

    def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
        ids, input_mask = self.correct_sent_indexing(sent)
@@ -592,12 +636,17 @@ def __init__(self, args):
        self.parameter_setup(args)

    @staticmethod
-    def apply_boundary_tokens(s1, s2=None):
+    def apply_boundary_tokens(s1, s2=None, get_offset=False):
        # XLM-style boundary token marking on string token sequences
        if s2:
-            return ["</s>"] + s1 + ["</s>"] + s2 + ["</s>"]
+            s = ["</s>"] + s1 + ["</s>"] + s2 + ["</s>"]
+            if get_offset:
+                return s, 1, len(s1) + 2
        else:
-            return ["</s>"] + s1 + ["</s>"]
+            s = ["</s>"] + s1 + ["</s>"]
+            if get_offset:
+                return s, 1, len(s1) + 1
+        return s

    def forward(self, sent: Dict[str, torch.LongTensor], task_name: str = "") -> torch.FloatTensor:
        ids, input_mask = self.correct_sent_indexing(sent)