
Commit 743de33

Authored by Alex Wang
Reset all BiLMs (ELMo and when using BiLM); move reset to sentence encoder (#393)
* Move reset elmo to sentence encoder
* Remove extra import
* Move reset state to sentenceencoder; reset biLM states also
* Remove reset elmo util function
* Remove reset elmo fn
* Add comment
* Dont use hasattr; remove check in trainer
1 parent ec9ba0f commit 743de33

File tree

5 files changed: +19 -18 lines changed

src/evaluate.py

Lines changed: 0 additions & 2 deletions

@@ -11,7 +11,6 @@
 from allennlp.data.iterators import BasicIterator
 from . import tasks as tasks_module
 from . import preprocess
-from . import utils

 from typing import List, Sequence, Iterable, Tuple, Dict

@@ -55,7 +54,6 @@ def evaluate(model, tasks: Sequence[tasks_module.Task], batch_size: int,
         dataset = getattr(task, "%s_data" % split)
         generator = iterator(dataset, num_epochs=1, shuffle=False, cuda_device=cuda_device)
         for batch_idx, batch in enumerate(generator):
-            utils.reset_elmo_states(model)
             out = model.forward(task, batch, predict=True)
             # We don't want mnli-diagnostic to affect the micro and macro average.
             # Accuracy of mnli-diagnostic is hardcoded to 0.

src/models.py

Lines changed: 1 addition & 0 deletions

@@ -652,6 +652,7 @@ def _positive_pair_sentence_forward(self, batch, task, predict):
         So rotating sent1/sent2 and pairing with sent2/sent1 is one way to obtain -ve pairs
         '''
         out = {}
+
         # embed the sentence
         sent1, mask1 = self.sent_encoder(batch['input1'], task)
         sent2, mask2 = self.sent_encoder(batch['input2'], task)

src/modules.py

Lines changed: 15 additions & 2 deletions

@@ -62,7 +62,7 @@ def forward(self, embs, mask):
         return None

 class SentenceEncoder(Model):
-    ''' Given a sequence of tokens, embed each token and pass thru an LSTM. '''
+    ''' Given a sequence of tokens, embed each token and pass thru a sequence encoder. '''
     # NOTE: Do not apply dropout to the input of this module. Will be applied internally.

     def __init__(self, vocab, text_field_embedder, num_highway_layers, phrase_layer,

@@ -95,18 +95,23 @@ def __init__(self, vocab, text_field_embedder, num_highway_layers, phrase_layer,

         initializer(self)

-    def forward(self, sent, task):
+    def forward(self, sent, task, reset=True):
         # pylint: disable=arguments-differ
         """
         Args:
             - sent (Dict[str, torch.LongTensor]): From a ``TextField``.
             - task (Task): Used by the _text_field_embedder to pick the correct output
                    ELMo representation.
+            - reset (Bool): if True, manually reset the states of the ELMo LSTMs present
+                   (if using BiLM or ELMo embeddings). Set False, if want to preserve statefulness.
         Returns:
             - sent_enc (torch.FloatTensor): (b_size, seq_len, d_emb)
                 the padded values in sent_enc are set to 0
             - sent_mask (torch.FloatTensor): (b_size, seq_len, d_emb); all 0/1s
         """
+        if reset:
+            self.reset_states()
+
         # Embeddings
         # Note: These highway modules are actually identity functions by default.

@@ -183,6 +188,14 @@ def forward(self, sent, task):
         sent_enc = sent_enc.masked_fill(pad_mask, 0)
         return sent_enc, sent_mask

+    def reset_states(self):
+        ''' Reset ELMo if present; reset BiLM (ELMoLSTM) states if present '''
+        if 'token_embedder_elmo' in [name for name, _ in self._text_field_embedder.named_children()] and \
+                '_elmo' in [name for name, _ in self._text_field_embedder.token_embedder_elmo.named_children()]:
+            self._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.reset_states()
+        if isinstance(self._phrase_layer, BiLMEncoder):
+            self._phrase_layer.reset_states()
+
 class BiLMEncoder(ElmoLstm):
     """Wrapper around BiLM to give it an interface to comply with SentEncoder
     See base class: ElmoLstm
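
With this commit the reset logic lives on the encoder itself: forward() clears any ELMo/BiLM state up front unless told otherwise. A minimal usage sketch of the new interface (the sent_encoder and batch names here are illustrative, borrowed from the call sites elsewhere in this diff):

# By default forward() calls reset_states() first (reset=True), so no
# hidden state carries over from the previous batch.
sent_enc, sent_mask = sent_encoder(batch['input1'], task)

# To keep the ELMo/BiLM LSTMs stateful across consecutive calls, e.g. when
# feeding chunks of one long sequence, opt out per call and reset manually
# at the boundary you care about:
sent_enc, sent_mask = sent_encoder(batch['input1'], task, reset=False)
sent_encoder.reset_states()  # clears ELMo states and BiLMEncoder states, if present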

src/trainer.py

Lines changed: 1 addition & 4 deletions

@@ -21,7 +21,7 @@
 from allennlp.training.learning_rate_schedulers import LearningRateScheduler  # pylint: disable=import-error
 from allennlp.training.optimizers import Optimizer  # pylint: disable=import-error

-from .utils import device_mapping, assert_for_log, reset_elmo_states  # pylint: disable=import-error
+from .utils import device_mapping, assert_for_log  # pylint: disable=import-error
 from .evaluate import evaluate
 from . import config

@@ -434,8 +434,6 @@ def clip_function(grad): return grad.clamp(-self._grad_clipping, self._grad_clipping)
                 n_batches_since_val += 1
                 total_batches_trained += 1
                 optimizer.zero_grad()
-                if self._model.elmo:
-                    assert_for_log(self._model.sent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm._states is None, "Found carried over ELMo states!")
                 output_dict = self._forward(batch, task=task, for_training=True)
                 assert_for_log("loss" in output_dict,
                                "Model must return a dict containing a 'loss' key")

@@ -730,7 +728,6 @@ def _forward(self, batch, for_training, task=None):
         ''' At one point this does something, now it doesn't really do anything '''
         tensor_batch = batch
         model_out = self._model.forward(task, tensor_batch)
-        reset_elmo_states(self._model)
         return model_out

     def _description_from_metrics(self, metrics):

src/utils.py

Lines changed: 2 additions & 10 deletions

@@ -16,18 +16,16 @@
 import numpy as np
 import torch
 from torch.autograd import Variable
-
-from allennlp.common.checks import ConfigurationError
-
-# Masked Multi headed self attention
 from torch.nn import Dropout, Linear
 from torch.nn import Parameter
 from torch.nn import init

+from allennlp.common.checks import ConfigurationError
 from allennlp.nn.util import last_dim_softmax
 from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
 from allennlp.common.params import Params

+
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


@@ -38,12 +36,6 @@
 # a poor job of adding correct whitespace. Use unescape_xml() only.
 _MOSES_DETOKENIZER = MosesDetokenizer()

-def reset_elmo_states(model):
-    ''' Reset ELMo hidden states if ELMo is detected '''
-    if model.elmo:
-        model.sent_encoder._text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.reset_states()
-    return
-
 def copy_iter(elems):
     '''Simple iterator yielding copies of elements.'''
     for elem in elems:
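
The deleted reset_elmo_states helper is subsumed by the reset_states method added in src/modules.py, which also covers BiLM states. A sketch of the equivalent call at an old call site (assuming the model exposes its encoder as sent_encoder, as elsewhere in this diff):

# Old (removed): utils.reset_elmo_states(model)  -- reset ELMo states only
# New: let the encoder reset everything it owns (ELMo and BiLM states).
model.sent_encoder.reset_states()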
