
Commit c8f68c4

yukatherin authored and Alex Wang committed
s2s decoder update (make more params active; add projection layer) (#384)
* quick fix for mt
* [veryminor] add mt_attention parameter to defaults.conf
* readme: mt->s2s
* [s2sdecoder] add bottleneck layer
* [s2sdecoder] make all parameters active
* clean beamsearch
* pull/384: address comments
* fix merge
* Add s2s_ prefix to opts; add projection documentation in code
* Rename s2s configs + refactor param dict construction; fix scheduled sampling bug; fix projected dim bug
* Redo decoder param construction
* Remove unneeded code
* Remove hasattr calls; fix bug when no attn
* Fix wrong param name
1 parent 743de33 commit c8f68c4

File tree

5 files changed (+68 lines, -58 lines)

config/defaults.conf

Lines changed: 10 additions & 3 deletions
@@ -222,12 +222,19 @@ classifier_loss_fn = "" // Classifier loss function. Used only in some speciali
 classifier_span_pooling = "x,y" // Span pooling type (for edge probing only).
     // Options: 'attn' or one of the 'combination' arguments accepted by AllenNLP's
     // EndpointSpanExtractor.
+
+s2s {
+    d_hid_dec = 1024 // The hidden size of the decoder in seq2seq tasks.
+    n_layers_dec = 1 // The number of decoder layers in seq2seq tasks.
+    target_embedding_dim = 300 // The size of target word embeddings in seq2seq tasks.
+    attention = "bilinear" // Attention used in s2s. Current implemented options are "bilinear" and "none".
+    output_proj_input_dim = 1024 // Dimension of bottleneck layer in s2s decoder output projection. If
+                                 // output_proj_input_dim == d_hid_dec, will not add projection.
+}
+
 edgeprobe_cnn_context = 0 // expanded context for edge probing via CNN.
     // 0 looks at only the current word, 1 adds +/-
     // words (kernel width 3), etc.
-d_hid_dec = 300 // The hidden size of the decoder in seq2seq tasks.
-n_layers_dec = 1 // The number of decoder layers in seq2seq tasks.
-mt_attention = "bilinear" // Attention used in s2s. Current implemented options are "bilinear" and "none".
 
 // Training
 eval_val_interval = 500 // Comparable to val_interval, used during train_for_eval. Can be set separately per task.
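
Why the bottleneck option matters, in rough numbers: with the default output_proj_input_dim == d_hid_dec (1024) no extra layer is added, but shrinking it roughly halves the output projection for a 20k target vocabulary. A back-of-the-envelope sketch, not repo code (the 20k vocabulary matches max_targ_v_size of the wmt17 tasks in src/tasks.py; proj_dim = 512 is an assumed setting):

# Parameter count of the decoder output projection, with and without the bottleneck.
d_hid_dec, proj_dim, vocab_size = 1024, 512, 20000
direct = d_hid_dec * vocab_size + vocab_size                                            # Linear(d_hid_dec, V)
bottleneck = (d_hid_dec * proj_dim + proj_dim) + (proj_dim * vocab_size + vocab_size)   # Linear(d_hid_dec, proj) + Linear(proj, V)
print(f"direct: {direct:,}, bottleneck: {bottleneck:,}")  # direct: 20,500,000, bottleneck: 10,784,800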

src/beamsearch.py renamed to src/generate_s2s.py

Lines changed: 0 additions & 2 deletions
@@ -6,8 +6,6 @@
 from . import bleu_scoring
 import numpy as np
 
-""" Beam search was confirmed to be WRONG. Use greedy search"""
-
 
 def _get_word(decoder_vocab, word_idx):
     return decoder_vocab._index_to_token['targets'][word_idx]
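
The removed docstring notes that beam search was found to be broken and that greedy search is used instead. For orientation, a minimal greedy-decoding sketch (illustrative only, not the repo's generation code; step_logits_fn is a hypothetical helper that runs one decoder step and returns vocabulary logits):

import torch

def greedy_decode(step_logits_fn, start_idx, end_idx, max_steps=30):
    """Take the argmax token at every step instead of keeping a beam."""
    out, prev = [], start_idx
    for _ in range(max_steps):
        prev = int(torch.argmax(step_logits_fn(prev)))  # greedy choice
        if prev == end_idx:
            break
        out.append(prev)
    return out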

src/models.py

Lines changed: 16 additions & 23 deletions
@@ -26,7 +26,6 @@
 from .utils import get_batch_utilization, get_elmo_mixing_weights
 from . import config
 from . import edge_probing
-#from . import beamsearch
 
 from .tasks import CCGTaggingTask, ClassificationTask, CoLATask, EdgeProbingTask, GroundedSWTask, \
     GroundedTask, LanguageModelingTask, MTTask, MultiNLIDiagnosticTask, PairClassificationTask, \
@@ -345,33 +344,37 @@ def build_module(task, model, d_sent, d_emb, vocab, embedder, args):
         module = edge_probing.EdgeClassifierModule(task, d_sent, task_params)
         setattr(model, '%s_mdl' % task.name, module)
     elif isinstance(task, (RedditSeq2SeqTask, Wiki103Seq2SeqTask)):
-        attention = args.mt_attention
-        log.info("using {} attention".format(attention))
+        log.info("using {} attention".format(args.s2s['attention']))
         decoder_params = Params({'input_dim': d_sent,
                                  'target_embedding_dim': 300,
+                                 'decoder_hidden_size': args.s2s['d_hid_dec'],
+                                 'output_proj_input_dim': args.s2s['output_proj_input_dim'],
                                  'max_decoding_steps': args.max_seq_len,
                                  'target_namespace': 'tokens',
-                                 'attention': attention,
+                                 'attention': args.s2s['attention'],
                                  'dropout': args.dropout,
                                  'scheduled_sampling_ratio': 0.0})
-        decoder = Seq2SeqDecoder.from_params(vocab, decoder_params)
+        decoder = Seq2SeqDecoder(vocab, **decoder_params)
         setattr(model, '%s_decoder' % task.name, decoder)
     elif isinstance(task, MTTask):
-        attention = args.mt_attention
-        log.info("using {} attention".format(attention))
+        log.info("using {} attention".format(args.s2s['attention']))
         decoder_params = Params({'input_dim': d_sent,
                                  'target_embedding_dim': 300,
-                                 'max_decoding_steps': 200,
+                                 'decoder_hidden_size': args.s2s['d_hid_dec'],
+                                 'output_proj_input_dim': args.s2s['output_proj_input_dim'],
+                                 'max_decoding_steps': args.max_seq_len,
                                  'target_namespace': task._label_namespace if hasattr(task, '_label_namespace') else 'targets',
-                                 'attention': attention,
+                                 'attention': args.s2s['attention'],
                                  'dropout': args.dropout,
                                  'scheduled_sampling_ratio': 0.0})
-        decoder = Seq2SeqDecoder.from_params(vocab, decoder_params)
+        decoder = Seq2SeqDecoder(vocab, **decoder_params)
         setattr(model, '%s_decoder' % task.name, decoder)
+
     elif isinstance(task, SequenceGenerationTask):
         decoder, hid2voc = build_decoder(task, d_sent, vocab, embedder, args)
         setattr(model, '%s_decoder' % task.name, decoder)
         setattr(model, '%s_hid2voc' % task.name, hid2voc)
+
     elif isinstance(task, (GroundedTask, GroundedSWTask)):
         task.img_encoder = CNNEncoder(model_name='resnet', path=task.path)
         pooler = build_image_sent_module(task, d_sent, task_params)
@@ -491,10 +494,10 @@ def build_decoder(task, d_inp, vocab, embedder, args):
     ''' Build a task specific decoder '''
     rnn = s2s_e.by_name('lstm').from_params(
         Params({'input_size': embedder.get_output_dim(),
-                'hidden_size': args.d_hid_dec,
-                'num_layers': args.n_layers_dec, 'bidirectional': False}))
+                'hidden_size': args.s2s['d_hid_dec'],
+                'num_layers': args.s2s['n_layers_dec'], 'bidirectional': False}))
     decoder = SentenceEncoder(vocab, embedder, 0, rnn)
-    hid2voc = nn.Linear(args.d_hid_dec, args.max_word_v_size)
+    hid2voc = nn.Linear(args.s2s['d_hid_dec'], args.max_word_v_size)
     return decoder, hid2voc
 
 
@@ -813,16 +816,6 @@ def _seq_gen_forward(self, batch, task, predict):
         out.update(decoder.forward(sent, sent_mask, batch['targs']))
         task.scorer1(out['loss'].item())
 
-        # Commented out for final run (still needs this for further debugging).
-        # We don't want to write predictions during training.
-        #if not self.training and not isinstance(task, Wiki103_Seq2Seq):
-        #    # bleu scoring
-        #    bleu_score, unk_ratio_macroavg = beamsearch.generate_and_compute_bleu(decoder, sent, sent_mask, batch['targs']['words'], preds_file_path=task.preds_file_path, task=task)
-        #    task.scorer2(bleu_score)
-        #    task.scorer3(unk_ratio_macroavg)
-
-        return out
-
         if 'targs' in batch:
             pass
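
To see how the new s2s options flow into the decoder, a condensed sketch of the construction path above (illustrative: the plain dict stands in for the AllenNLP Params object, and d_sent, args, vocab are assumed to be in scope as in build_module):

decoder_params = {
    'input_dim': d_sent,                                   # encoder output size
    'target_embedding_dim': 300,
    'decoder_hidden_size': args.s2s['d_hid_dec'],
    'output_proj_input_dim': args.s2s['output_proj_input_dim'],
    'max_decoding_steps': args.max_seq_len,
    'target_namespace': 'targets',
    'attention': args.s2s['attention'],
    'dropout': args.dropout,
    'scheduled_sampling_ratio': 0.0,
}
decoder = Seq2SeqDecoder(vocab, **decoder_params)          # plain kwargs: every key must match an __init__ argument, so misspelled or stale options fail fast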

src/seq2seq_decoder.py

Lines changed: 31 additions & 30 deletions
@@ -31,10 +31,14 @@ class Seq2SeqDecoder(Model):
     def __init__(self,
                  vocab: Vocabulary,
                  input_dim: int,
+                 decoder_hidden_size: int,
+                 max_decoding_steps: int,
+                 output_proj_input_dim: int,
                  target_namespace: str = "targets",
                  target_embedding_dim: int = None,
                  attention: str = "none",
                  dropout: float = 0.0,
+                 scheduled_sampling_ratio: float = 0.0,
                  ) -> None:
         super(Seq2SeqDecoder, self).__init__(vocab)
         self._max_decoding_steps = max_decoding_steps
@@ -50,26 +54,39 @@
         # Decoder output dim needs to be the same as the encoder output dim since we initialize the
         # hidden state of the decoder with that of the final hidden states of the encoder. Also, if
         # we're using attention with ``DotProductSimilarity``, this is needed.
-        self._decoder_hidden_dim = input_dim
+        self._encoder_output_dim = input_dim
+        self._decoder_hidden_dim = decoder_hidden_size
+        if self._encoder_output_dim != self._decoder_hidden_dim:
+            self._projection_encoder_out = Linear(self._encoder_output_dim, self._decoder_hidden_dim)
+        else:
+            self._projection_encoder_out = lambda x: x
         self._decoder_output_dim = self._decoder_hidden_dim
-        # target_embedding_dim = target_embedding_dim #or self._source_embedder.get_output_dim()
+        self._output_proj_input_dim = output_proj_input_dim
         self._target_embedding_dim = target_embedding_dim
         self._target_embedder = Embedding(num_classes, self._target_embedding_dim)
 
-        self._sent_pooler = Pooler.from_params(input_dim, input_dim, False)
+        # Used to get an initial hidden state from the encoder states
+        self._sent_pooler = Pooler.from_params(d_inp=input_dim, d_proj=decoder_hidden_size, project=True)
 
         if attention == "bilinear":
-            self._decoder_attention = BilinearAttention(input_dim, input_dim)
+            self._decoder_attention = BilinearAttention(decoder_hidden_size, input_dim)
             # The output of attention, a weighted average over encoder outputs, will be
             # concatenated to the input vector of the decoder at each time step.
             self._decoder_input_dim = input_dim + target_embedding_dim
         elif attention == "none":
+            self._decoder_attention = None
            self._decoder_input_dim = target_embedding_dim
         else:
             raise Exception("attention not implemented {}".format(attention))
 
         self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_hidden_dim)
-        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)
+        # Allow for a bottleneck layer between encoder outputs and distribution over vocab
+        # The bottleneck layer consists of a linear transform and helps to reduce number of parameters
+        if self._output_proj_input_dim != self._decoder_output_dim:
+            self._projection_bottleneck = Linear(self._decoder_output_dim, self._output_proj_input_dim)
+        else:
+            self._projection_bottleneck = lambda x: x
+        self._output_projection_layer = Linear(self._output_proj_input_dim, num_classes)
         self._dropout = torch.nn.Dropout(p=dropout)
 
     def _initalize_hidden_context_states(self, encoder_outputs, encoder_outputs_mask):
@@ -80,10 +97,9 @@ def _initalize_hidden_context_states(self, encoder_outputs, encoder_outputs_mask
         encoder_outputs: torch.FloatTensor, [bs, T, h]
         encoder_outputs_mask: torch.LongTensor, [bs, T, 1]
         """
-        # very important - feel free to check it a third time
-        # idempotent / safe to run in place. encoder_outputs_mask should never
-        # change
-        if hasattr(self, "_decoder_attention") and self._decoder_attention:
+
+        if self._decoder_attention is not None:
+            encoder_outputs = self._projection_encoder_out(encoder_outputs)
             encoder_outputs.data.masked_fill_(1 - encoder_outputs_mask.byte().data, -float('inf'))
 
         decoder_hidden = encoder_outputs.new_zeros(encoder_outputs_mask.size(0), self._decoder_hidden_dim)
@@ -132,8 +148,10 @@ def forward(self,  # type: ignore
             decoder_hidden, decoder_context = self._decoder_cell(
                 decoder_input, (decoder_hidden, decoder_context))
 
+            # output projection
+            proj_input = self._projection_bottleneck(decoder_hidden)
             # (batch_size, num_classes)
-            output_projections = self._output_projection_layer(decoder_hidden)
+            output_projections = self._output_projection_layer(proj_input)
 
             # list of (batch_size, 1, num_classes)
             step_logit = output_projections.unsqueeze(1)
@@ -204,7 +222,7 @@ def _prepare_decode_step_input(
         # (batch_size, target_embedding_dim)
         embedded_input = self._target_embedder(input_indices)
 
-        if hasattr(self, "_decoder_attention") and self._decoder_attention:
+        if self._decoder_attention is not None:
             # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
             # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
             # complain.
@@ -221,9 +239,9 @@
             # (batch_size, input_sequence_length)
             input_weights = self._decoder_attention(
                 decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
-            # (batch_size, encoder_output_dim)
+            # (batch_size, input_dim)
             attended_input = weighted_sum(encoder_outputs, input_weights)
-            # (batch_size, encoder_output_dim + target_embedding_dim)
+            # (batch_size, input_dim + target_embedding_dim)
             return torch.cat((attended_input, embedded_input), -1)
         else:
             return embedded_input
@@ -259,20 +277,3 @@ def _get_loss(logits: torch.LongTensor,
         relevant_mask = target_mask[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
         loss = sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
         return loss
-
-    @classmethod
-    def from_params(cls, vocab, params: Params) -> 'SimpleSeq2Seq':
-        input_dim = params.pop("input_dim")
-        max_decoding_steps = params.pop("max_decoding_steps")
-        target_namespace = params.pop("target_namespace", "targets")
-        target_embedding_dim = params.pop("target_embedding_dim")
-        attention = params.pop("attention", "none")
-        dropout = params.pop_float("dropout", 0.0)
-        params.assert_empty(cls.__name__)
-        return cls(vocab,
-                   input_dim=input_dim,
-                   target_embedding_dim=target_embedding_dim,
-                   max_decoding_steps=max_decoding_steps,
-                   target_namespace=target_namespace,
-                   attention=attention,
-                   dropout=dropout)
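
For readers skimming the diff, a minimal standalone PyTorch sketch of the projection/bottleneck pattern introduced above (hypothetical class name and assumed sizes; it is not the repo's class, just the same idea in isolation):

import torch
from torch import nn

class VocabProjection(nn.Module):
    """Project decoder states down to proj_dim before the (large) vocab projection,
    and skip the extra layer entirely when proj_dim == hidden_dim."""

    def __init__(self, hidden_dim: int, proj_dim: int, vocab_size: int):
        super().__init__()
        if proj_dim != hidden_dim:
            self.bottleneck = nn.Linear(hidden_dim, proj_dim)
        else:
            self.bottleneck = nn.Identity()  # no-op, mirrors the `lambda x: x` above
        self.output = nn.Linear(proj_dim, vocab_size)

    def forward(self, decoder_hidden: torch.Tensor) -> torch.Tensor:
        return self.output(self.bottleneck(decoder_hidden))

# (batch_size, vocab_size) logits from a (batch_size, hidden_dim) decoder state
logits = VocabProjection(1024, 512, 20000)(torch.randn(8, 1024))
print(logits.shape)  # torch.Size([8, 20000])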

src/tasks.py

Lines changed: 11 additions & 0 deletions
@@ -1773,6 +1773,17 @@ def get_metrics(self, reset=False):
         return {'perplexity': math.exp(avg_nll), 'bleu_score': 0, 'unk_ratio_macroavg': unk_ratio_macroavg}
 
 
+@register_task('wmt_debug', rel_path='wmt_debug/', max_targ_v_size=5000)
+class MTDebug(MTTask):
+    def __init__(self, path, max_seq_len, max_targ_v_size, name='wmt_debug'):
+        ''' Demo task for MT with 10k training examples.'''
+        super().__init__(path=path, max_seq_len=max_seq_len,
+                         max_targ_v_size=max_targ_v_size, name=name)
+        self.files_by_split = {"train": os.path.join(path, "train.txt"),
+                               "val": os.path.join(path, "valid.txt"),
+                               "test": os.path.join(path, "test.txt")}
+
+
 @register_task('wmt17_en_ru', rel_path='wmt17_en_ru/', max_targ_v_size=20000)
 class MTTaskEnRu(MTTask):
     def __init__(self, path, max_seq_len, max_targ_v_size, name='mt_en_ru'):
