
Commit a1e9abf

XLNet support and overhaul/cleanup of BERT support (#845)
* Rename namespaces to suppress warnings.
* Revert "Rename namespaces to suppress warnings." This reverts commit 0cf7b23.
* Initial working-ish attempt.
* Intermediate check-in...
* More partial progress.
* Another pass...
* Fix sep/cls handling, cleanup.
* Further cleanup.
* Keyword name fix.
* Another flag fix.
* Pull debug print.
* Line length cleanup.
* WiC fix.
* Two task setup bugs.
* BoolQ typo
* Improved segment handling.
* Delete unused is_pair_task, other cleanup/fixes.
* Fix deleted path from merge.
* Fix cache path.
* Address (spurious?) tokenization warning.
* Select pool_type automatically to match model. h/t Haokun Liu
* Config updates.
* Path fix
* Fix XLNet UNK handling.
* Internal temporary MNLI alternate.
* Revert "Internal temporary MNLI alternate." This reverts commit 455792a.
* Add helper fn tests
* Finish merge
* Remove unused argument.
* Possible ReCoRD bug fix
* Cleanup
* Fix merge issues.
* Revert "Remove unused argument." This reverts commit 96a7c37.
* Assorted responses to Alex's comments.
* Further ReCoRD fix.
* @iftenney's comments.
* Fix/simplify segment logic.
* @W4ngatang's comments
* Cleanup.
* Cleanup
* Fix issues with alternative embeddings_mode settings, max_layer.
* More mix cleanup.
* Masking fix.
* Address (most of) @iftenney's comments
* Tidying.
* Misc cleanup.
* Comment.
1 parent 23ad1a7 commit a1e9abf

38 files changed: +940 -577 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ user_config.sh
 .idea
 .ipynb_checkpoints/
 perluniprops/
+.DS_Store

README.md

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@ A few things you might want to know about `jiant`:
 - `jiant` is configuration-driven. You can run an enormous variety of experiments by simply writing configuration files. Of course, if you need to add any major new features, you can also easily edit or extend the code.
 - `jiant` contains implementations of strong baselines for the [GLUE](https://gluebenchmark.com) and [SuperGLUE](https://super.gluebenchmark.com/) benchmarks, and it's the recommended starting point for work on these benchmarks.
 - `jiant` was developed at [the 2018 JSALT Workshop](https://www.clsp.jhu.edu/workshops/18-workshop/) by [the General-Purpose Sentence Representation Learning](https://jsalt18-sentence-repl.github.io/) team and is maintained by [the NYU Machine Learning for Language Lab](https://wp.nyu.edu/ml2/people/), with help from [many outside collaborators](https://github.com/nyu-mll/jiant/graphs/contributors) (especially Google AI Language's [Ian Tenney](https://ai.google/research/people/IanTenney)).
-- `jiant` is built on [PyTorch](https://pytorch.org). It also uses many components from [AllenNLP](https://github.com/allenai/allennlp) and the HuggingFace PyTorch [implementations](https://github.com/huggingface/pytorch-pretrained-BERT) of BERT and GPT.
+- `jiant` is built on [PyTorch](https://pytorch.org). It also uses many components from [AllenNLP](https://github.com/allenai/allennlp) and the HuggingFace PyTorch [implementations](https://github.com/huggingface/pytorch-transformers) of GPT, BERT, and XLNet.
 - The name `jiant` doesn't mean much. The 'j' stands for JSALT. That's all the acronym we have.

 ## Getting Started
@@ -84,10 +84,10 @@ This package is released under the [MIT License](LICENSE.md). The material in th

 ## Acknowledgments

-- Part of the development of `jiant` took at the 2018 Frederick Jelinek Memorial Summer Workshop on Speech and Language Technologies, and was supported by Johns Hopkins University with unrestricted gifts from Amazon, Facebook, Google, Microsoft and Mitsubishi Electric Research Laboratories.
+- Part of the development of `jiant` took at the 2018 Frederick Jelinek Memorial Summer Workshop on Speech and Language Technologies, and was supported by Johns Hopkins University with unrestricted gifts from Amazon, Facebook, Google, Microsoft and Mitsubishi Electric Research Laboratories.
 - This work was made possible in part by a donation to NYU from Eric and Wendy Schmidt made
 by recommendation of the Schmidt Futures program.
-- We gratefully acknowledge the support of NVIDIA Corporation with the donation of a Titan V GPU used at NYU in this work.
+- We gratefully acknowledge the support of NVIDIA Corporation with the donation of a Titan V GPU used at NYU in this work.
 - Developer Alex Wang is supported by the National Science Foundation Graduate Research Fellowship Program under Grant
 No. DGE 1342536. Any opinions, findings, and conclusions or recommendations expressed in this
 material are those of the author(s) and do not necessarily reflect the views of the National Science

cola_inference.py

Lines changed: 10 additions & 4 deletions
@@ -47,10 +47,10 @@

 from jiant.models import build_model
 from jiant.preprocess import build_indexers, build_tasks
-from jiant.tasks.tasks import process_sentence, sentence_to_text_field
+from jiant.tasks.tasks import tokenize_and_truncate, sentence_to_text_field
 from jiant.utils import config
 from jiant.utils.data_loaders import load_tsv
-from jiant.utils.utils import check_arg_name, load_model_state
+from jiant.utils.utils import check_arg_name, load_model_state, select_pool_type

 log.basicConfig(format="%(asctime)s: %(message)s", datefmt="%m/%d %I:%M:%S %p", level=log.INFO)

@@ -121,6 +121,7 @@ def main(cl_arguments):
     cl_args = handle_arguments(cl_arguments)
     args = config.params_from_file(cl_args.config_file, cl_args.overrides)
     check_arg_name(args)
+
     assert args.target_tasks == "cola", "Currently only supporting CoLA. ({})".format(
         args.target_tasks
     )
@@ -138,6 +139,11 @@
     )
     args.cuda = -1

+    if args.tokenizer == "auto":
+        args.tokenizer = tokenizers.select_tokenizer(args)
+    if args.pool_type == "auto":
+        args.pool_type = select_pool_type(args)
+
     # Prepare data #
     _, target_tasks, vocab, word_embs = build_tasks(args)
     tasks = sorted(set(target_tasks), key=lambda x: x.name)
@@ -185,7 +191,7 @@ def run_repl(model, vocab, indexers, task, args):
         if input_string == "QUIT":
             break

-        tokens = process_sentence(
+        tokens = tokenize_and_truncate(
            tokenizer_name=task.tokenizer_name, sent=input_string, max_seq_len=args.max_seq_len
        )
        print("TOKENS:", " ".join("[{}]".format(tok) for tok in tokens))
@@ -282,7 +288,7 @@ def load_cola_data(input_path, task, input_format, max_seq_len):
        with open(input_path, "r") as f_in:
            sentences = f_in.readlines()
        tokens = [
-            process_sentence(
+            tokenize_and_truncate(
                tokenizer_name=task.tokenizer_name, sent=sentence, max_seq_len=max_seq_len
            )
            for sentence in sentences
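
The `process_sentence` helper is renamed to `tokenize_and_truncate` throughout, with the same call signature. As a minimal, self-contained sketch of that contract (not jiant's implementation — the real helper dispatches on `tokenizer_name`; the whitespace split here is only a stand-in):

```python
from typing import List


def tokenize_and_truncate(tokenizer_name: str, sent: str, max_seq_len: int) -> List[str]:
    """Tokenize `sent` and keep at most `max_seq_len` tokens (illustrative stand-in)."""
    # Assumption: whitespace splitting stands in for the tokenizer named by `tokenizer_name`.
    tokens = sent.split()
    return tokens[:max_seq_len]


print(tokenize_and_truncate("", "The quick brown fox jumps over the lazy dog .", max_seq_len=5))
# -> ['The', 'quick', 'brown', 'fox', 'jumps']
```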

config/ccg_bert.conf

Lines changed: 0 additions & 2 deletions
@@ -4,7 +4,6 @@ include "defaults.conf"
 pretrain_tasks = ccg
 target_tasks = ccg
 input_module = bert-base-uncased
-tokenizer = ${input_module}
 do_target_task_training = 0
 transfer_paradigm = finetune

@@ -16,7 +15,6 @@ skip_embs = 1

 // BERT-specific setup
 classifier = log_reg // following BERT paper
-pool_type = first

 dropout = 0.1 // following BERT paper
 optimizer = bert_adam

config/copa_bert.conf

Lines changed: 0 additions & 2 deletions
@@ -19,10 +19,8 @@ do_full_eval = 1

 // Typical BERT base setup
 input_module = bert-base-uncased
-tokenizer = bert-base-uncased
 transfer_paradigm = finetune
 classifier = log_reg
-pool_type = first
 optimizer = bert_adam
 lr = 0.00001
 sent_enc = none

config/defaults.conf

Lines changed: 60 additions & 56 deletions
@@ -90,7 +90,7 @@ transfer_paradigm = "frozen" // How to use pretrained model parameters during ta
                              // "frozen" will train the downstream models on fixed
                              // representations from the encoder model.
                              // "finetune" will update the parameters of the encoders models as
-                             // well as the downstream models.
+                             // well as the downstream models. (This disables d_proj.)
 load_target_train_checkpoint = none // If not "none", load the specified model_state checkpoint
                                     // file when starting do_target_task_training.
                                     // Supports * wildcards.
@@ -140,6 +140,9 @@ batch_size = 32 // Training batch size.
 optimizer = adam // Optimizer. All valid AllenNLP options are available, including 'sgd'.
                  // Use 'bert_adam' for reproducing BERT experiments.
                  // 'adam' uses the newer AMSGrad variant.
+                 // Warning: bert_adam is designed for cases where the number of epochs is known
+                 // in advance, so it may not behave reasonably unless max_epochs is set to a
+                 // reasonable positive value.
 lr = 0.0001 // Initial learning rate.
 min_lr = 0.000001 // Minimum learning rate. Training will stop when our explicit LR decay lowers
                   // the LR below this point or if any other stopping criterion applies.
@@ -221,42 +224,41 @@ max_targ_word_v_size = 20000 // Maximum target word vocab size for seq2seq task

 // Input Handling //

-input_module = "" // The word embedding or contextual word representation layer.
-                  // Currently supported options:
-                  // - scratch: Word embeddings trained from scratch.
-                  // - glove: Leaded GloVe word embeddings. Typically used with
-                  //     tokenizer = MosesTokenizer. Note that this is not quite identical to the
-                  //     Stanford tokenizer used to train GloVe.
-                  // - fastText: Leaded GloVe word embeddings. Use with
-                  //     tokenizer = MosesTokenizer.
-                  // - elmo: AllenNLP's ELMo contextualized word vector model hidden states. Use
-                  //     with tokenizer = MosesTokenizer.
-                  // - elmo-chars-only: The dynamic CNN-based word embedding layer of AllenNLP's
-                  //     ELMo, but not ELMo's LSTM layer hidden states. Use with
-                  //     tokenizer = MosesTokenizer.
-                  // - bert-base-uncased, etc.: Any BERT model specifier that is valid for
-                  //     pytorch-pretrained-bert may be specified here. Use with
-                  //     tokenizer = ${input_module}
-                  //     We support the newer bert-large-uncased-whole-word-masking and
-                  //     bert-large-cased-whole-word-masking cased models, but they require
-                  //     the git development version of pytorch-pretrained-bert. To use these
-                  //     models, follow the instructions under 'From source' here:
-                  //     https://github.com/huggingface/pytorch-pretrained-BERT
-                  // Most of these options use MosesTokenizer tokenization, but
-                  // BERT and GPT need more specific tokenization (tokenizer config
-                  // parameter should be equal to input_module for BERT, and should be
-                  // equal to 'OpenAI.BPE' if input_module = gpt).
-                  // For ELMo, BERT, and GPT, there are additional config parameters below.
-
-tokenizer = "MosesTokenizer" // The name of the tokenizer, passed to the Task constructor for
-                             // appropriate handling during data loading. Currently supported
-                             // options:
-                             // - "": Split the input data on whitespace.
-                             // - MosesTokenizer: Our standard word tokenizer. (Support for
-                             //     other NLTK tokenizers is pending.)
-                             // - bert-uncased-base, etc.: Use the tokenizer supplied with
-                             //     pytorch-pretrained-bert that corresponds to that BERT model.
-                             // - OpenAI.BPE: The tokenizer supplied with OpenAI GPT.
+input_module = "" // The word embedding or contextual word representation layer.
+                  // Currently supported options:
+                  // - scratch: Word embeddings trained from scratch.
+                  // - glove: Loaded GloVe word embeddings. Typically used with
+                  //     tokenizer = MosesTokenizer. Note that this is not quite identical to
+                  //     the Stanford tokenizer used to train GloVe.
+                  // - fastText: Loaded fastText word embeddings. Use with
+                  //     tokenizer = MosesTokenizer.
+                  // - elmo: AllenNLP's ELMo contextualized word vector model hidden states. Use
+                  //     with tokenizer = MosesTokenizer.
+                  // - elmo-chars-only: The dynamic CNN-based word embedding layer of AllenNLP's
+                  //     ELMo, but not ELMo's LSTM layer hidden states. Use with
+                  //     tokenizer = MosesTokenizer.
+                  // - gpt: The OpenAI GPT language model encoder.
+                  //     Use with tokenizer = OpenAI.BPE.
+                  // - bert-base-uncased, etc.: Any BERT model specifier that is valid for
+                  //     pytorch-pretrained-bert may be specified here. Use with
+                  //     tokenizer = ${input_module}
+                  //     We support the newer bert-large-uncased-whole-word-masking and
+                  //     bert-large-cased-whole-word-masking cased models, but they require
+                  //     the git development version of pytorch-pretrained-bert. To use these
+                  //     models, follow the instructions under 'From source' here:
+                  //     https://github.com/huggingface/pytorch-pretrained-BERT
+
+tokenizer = auto // The name of the tokenizer, passed to the Task constructor for
+                 // appropriate handling during data loading. Currently supported
+                 // options:
+                 // - auto: Select the tokenizer that matches the model specified in
+                 //     input_module above. Usually a safe default.
+                 // - "": Split the input data on whitespace.
+                 // - MosesTokenizer: Our standard word tokenizer. (Support for
+                 //     other NLTK tokenizers is pending.)
+                 // - bert-uncased-base, etc.: Use the tokenizer supplied with
+                 //     pytorch-pretrained-bert that corresponds to that BERT model.
+                 // - OpenAI.BPE: The tokenizer supplied with OpenAI GPT.

 word_embs_file = ${WORD_EMBS_FILE} // Path to embeddings file, used with glove and fastText.
 d_word = 300 // Dimension of word embeddings, used with scratch, glove, or fastText.
@@ -282,22 +284,21 @@ openai_embeddings_mode = "none" // How to handle the embedding layer of the Ope
                                 // "mix" uses ELMo-style scalar mixing (with
                                 // learned weights) across all layers.

-bert_embeddings_mode = "none" // How to handle the embedding layer of the
-                              // BERT model:
-                              // "none" or "top" returns only top-layer activation,
-                              // "cat" returns top-layer concatenated with
-                              //     lexical layer,
-                              // "only" returns only lexical layer,
-                              // "mix" uses ELMo-style scalar mixing (with
-                              // learned weights) across all layers.
-bert_max_layer = -1 // Maximum layer to return from BERT encoder. Layer 0 is
-                    // wordpiece embeddings.
-                    // bert_embeddings_mode will behave as if the BERT encoder
-                    // is truncated at this layer, so 'top' will return this
-                    // layer, and 'mix' will return a mix of all layers up to
-                    // and including this layer.
-                    // Set to -1 to use all layers.
-                    // Used for probing experiments.
+pytorch_transformers_output_mode = "none" // How to handle the embedding layer of the
+                                          // BERT/XLNet model:
+                                          // "none" or "top" returns only top-layer activation,
+                                          // "cat" returns top-layer concatenated with
+                                          //     lexical layer,
+                                          // "only" returns only lexical layer,
+                                          // "mix" uses ELMo-style scalar mixing (with learned
+                                          // weights) across all layers.
+pytorch_transformers_max_layer = -1 // Maximum layer to return from BERT etc. encoder. Layer 0 is
+                                    // wordpiece embeddings. pytorch_transformers_embeddings_mode
+                                    // will behave as if the is truncated at this layer, so 'top'
+                                    // will return this layer, and 'mix' will return a mix of all
+                                    // layers up to and including this layer.
+                                    // Set to -1 to use all layers.
+                                    // Used for probing experiments.

 force_include_wsj_vocabulary = 0 // Set if using PTB parsing (grammar induction) task. Makes sure
                                  // to include WSJ vocabulary.
@@ -320,7 +321,7 @@ n_layers_enc = 2 // Number of layers for a 'rnn' sent_enc.
 skip_embs = 1 // If true, concatenate the sent_enc's input (ELMo/GPT/BERT output or
               // embeddings) with the sent_enc's output.
 sep_embs_for_skip = 0 // Whether the skip embedding uses the same embedder object as the original
-                      //embedding (before skip).
+                      // embedding (before skip).
                       // Only makes a difference if we are using ELMo weights, where it allows
                       // the four tuned ELMo scalars to vary separately for each target task.
 n_layers_highway = 0 // Number of highway layers between the embedding layer and the sent_enc layer. [Deprecated.]
@@ -364,8 +365,11 @@ pair_attn = 1 // If true, use attn in sentence-pair classification/regression t
 d_hid_attn = 512 // Post-attention LSTM state size.
 shared_pair_attn = 0 // If true, share pair_attn parameters across all tasks that use it.
 d_proj = 512 // Size of task-specific linear projection applied before before pooling.
-pool_type = "max" // Type of pooling to reduce sequences of vectors into a single vector.
-                  // Options: "max", "mean", "first", "final"
+             // Disabled when fine-tuning pytorch_transformers models.
+pool_type = "auto" // Type of pooling to reduce sequences of vectors into a single vector.
+                   // Options: "auto", "max", "mean", "first", "final"
+                   // "auto" uses "first" for plain BERT (with no sent_enc), "final" for plain
+                   // XLNet and GPT, and "max" in all other settings.
 span_classifier_loss_fn = "softmax" // Classifier loss function. Used only in some tasks (notably
                                     // span-related tasks), not mlp/fancy_mlp. Currently supports
                                     // sigmoid and softmax.
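
The new `pool_type = "auto"` rule documented above fits in a few lines. A minimal sketch, assuming the decision depends only on `input_module` and `sent_enc` (the real `select_pool_type(args)` takes the full config object, and the exact conditions are an assumption):

```python
def select_pool_type(input_module: str, sent_enc: str) -> str:
    """Guess at the 'auto' pooling rule described in defaults.conf (illustrative only)."""
    plain = sent_enc == "none"  # assumption: "plain" means no extra sentence encoder on top
    if plain and input_module.startswith("bert-"):
        return "first"  # BERT's [CLS] summary token sits at the first position
    if plain and (input_module.startswith("xlnet-") or input_module == "gpt"):
        return "final"  # XLNet's <cls> and GPT's summary sit at the end of the sequence
    return "max"


print(select_pool_type("bert-base-uncased", "none"))  # first
print(select_pool_type("xlnet-base-cased", "none"))   # final
print(select_pool_type("bert-base-uncased", "rnn"))   # max
```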

config/examples/copa_bert.conf

Lines changed: 0 additions & 2 deletions
@@ -19,10 +19,8 @@ do_full_eval = 1

 // Typical BERT base setup
 input_module = bert-base-uncased
-tokenizer = bert-base-uncased
 transfer_paradigm = finetune
 classifier = log_reg
-pool_type = first
 optimizer = bert_adam
 lr = 0.00001
 sent_enc = none

config/examples/stilts_example.conf

Lines changed: 1 addition & 3 deletions
@@ -18,8 +18,7 @@ batch_size = 24
 write_preds = "val,test"

 //BERT-specific parameters
-bert_embeddings_mode = "top"
-pool_type = "first"
+pytorch_transformers_output_mode = "top"
 sep_embs_for_skip = 1
 sent_enc = "none"
 classifier = log_reg // following BERT paper
@@ -34,6 +33,5 @@ patience = 20
 max_vals = 10000
 transfer_paradigm = "finetune"

-tokenizer = "bert-base-uncased"
 input_module = "bert-base-uncased"


config/superglue-bert.conf

Lines changed: 2 additions & 3 deletions
@@ -7,11 +7,10 @@ exp_name = "bert-large-cased"
 // Data and preprocessing settings
 max_seq_len = 256 // Mainly needed for MultiRC, to avoid over-truncating
                   // But not 512 as that is really hard to fit in memory.
-tokenizer = "bert-large-cased"
+
 // Model settings
 input_module = "bert-large-cased"
-bert_embeddings_mode = "top"
-pool_type = "first"
+pytorch_transformers_output_mode = "top"
 pair_attn = 0 // shouldn't be needed but JIC
 s2s = {
     attention = none
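
For intuition, the `pytorch_transformers_output_mode` values used in these configs (and documented in `defaults.conf` above) choose how per-layer encoder activations are combined. A hedged sketch, not jiant's actual module — in particular the uniform mix weights below are learned scalars in practice:

```python
import torch


def combine_layers(layer_acts, mode):
    """Combine per-layer activations; layer_acts[0] is the lexical (wordpiece embedding) layer."""
    if mode in ("none", "top"):
        return layer_acts[-1]                                    # top-layer activations only
    if mode == "only":
        return layer_acts[0]                                     # lexical layer only
    if mode == "cat":
        return torch.cat([layer_acts[-1], layer_acts[0]], -1)    # top layer + lexical layer
    if mode == "mix":
        weights = torch.softmax(torch.zeros(len(layer_acts)), 0) # learned scalars in practice
        return sum(w * h for w, h in zip(weights, layer_acts))   # ELMo-style scalar mix
    raise ValueError("unknown output mode: " + mode)


acts = [torch.randn(2, 5, 768) for _ in range(13)]  # embeddings + 12 transformer layers
print(combine_layers(acts, "top").shape)  # torch.Size([2, 5, 768])
print(combine_layers(acts, "cat").shape)  # torch.Size([2, 5, 1536])
```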

environment.yml

Lines changed: 6 additions & 0 deletions
@@ -32,3 +32,9 @@ dependencies:
   - ftfy==5.4.1
   - spacy==2.0.11

+  # Warning: jiant currently depends on *both* pytorch_pretrained_bert > 0.6 _and_
+  # pytorch_transformers > 1.0. These are the same package, though the name changed between
+  # these two versions. AllenNLP requires 0.6 to support the BertAdam optimizer, and jiant
+  # directly requires 1.0 to support XLNet and WWM-BERT.
+  # This AllenNLP issue is relevant: https://github.com/allenai/allennlp/issues/3067
+  - pytorch-transformers==1.0.0
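
Because the old and new HuggingFace releases install under different top-level package names, pinning both works: the two can be imported side by side. A small sanity-check sketch (assumes both pins above are installed in the environment):

```python
# pytorch_pretrained_bert (0.6.x) is pulled in via AllenNLP and provides BertAdam;
# pytorch_transformers (1.0.x) is what jiant calls directly for XLNet and WWM-BERT.
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_transformers import BertModel, XLNetModel

print(BertAdam.__module__)    # pytorch_pretrained_bert.optimization
print(XLNetModel.__module__)  # pytorch_transformers.modeling_xlnet
```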

gcp/config/jiant_paths.sh

Lines changed: 2 additions & 2 deletions
@@ -12,8 +12,8 @@ export JIANT_PROJECT_PREFIX="$HOME/exp"

 # pre-downloaded ELMo models
 export ELMO_SRC_DIR="/nfs/jiant/share/elmo"
-# cache for BERT models
-export PYTORCH_PRETRAINED_BERT_CACHE="/nfs/jiant/share/bert_cache"
+# cache for BERT etc. models
+export PYTORCH_PRETRAINED_BERT_CACHE="/nfs/jiant/share/pytorch_transformers_cache"
 # word embeddings
 export WORD_EMBS_FILE="/nfs/jiant/share/wiki-news-300d-1M.vec"

gcp/kubernetes/run_batch.sh

Lines changed: 0 additions & 1 deletion
@@ -99,4 +99,3 @@ jsonnet -S -o "${YAML_FILE}" \
 ##
 # Create the Kubernetes pod; this will actually launch the job.
 kubectl ${KUBECTL_MODE} -f "${YAML_FILE}"
-
