@@ -31,10 +31,14 @@ class Seq2SeqDecoder(Model):
     def __init__(self,
                  vocab: Vocabulary,
                  input_dim: int,
+                 decoder_hidden_size: int,
+                 max_decoding_steps: int,
+                 output_proj_input_dim: int,
                  target_namespace: str = "targets",
                  target_embedding_dim: int = None,
                  attention: str = "none",
                  dropout: float = 0.0,
+                 scheduled_sampling_ratio: float = 0.0,
                  ) -> None:
         super(Seq2SeqDecoder, self).__init__(vocab)
         self._max_decoding_steps = max_decoding_steps
@@ -50,26 +54,39 @@ def __init__(self,
         # Decoder output dim needs to be the same as the encoder output dim since we initialize the
         # hidden state of the decoder with that of the final hidden states of the encoder. Also, if
         # we're using attention with ``DotProductSimilarity``, this is needed.
-        self._decoder_hidden_dim = input_dim
+        self._encoder_output_dim = input_dim
+        self._decoder_hidden_dim = decoder_hidden_size
+        if self._encoder_output_dim != self._decoder_hidden_dim:
+            self._projection_encoder_out = Linear(self._encoder_output_dim, self._decoder_hidden_dim)
+        else:
+            self._projection_encoder_out = lambda x: x
         self._decoder_output_dim = self._decoder_hidden_dim
-        # target_embedding_dim = target_embedding_dim #or self._source_embedder.get_output_dim()
+        self._output_proj_input_dim = output_proj_input_dim
         self._target_embedding_dim = target_embedding_dim
         self._target_embedder = Embedding(num_classes, self._target_embedding_dim)
 
-        self._sent_pooler = Pooler.from_params(input_dim, input_dim, False)
+        # Used to get an initial hidden state from the encoder states
+        self._sent_pooler = Pooler.from_params(d_inp=input_dim, d_proj=decoder_hidden_size, project=True)
 
         if attention == "bilinear":
-            self._decoder_attention = BilinearAttention(input_dim, input_dim)
+            self._decoder_attention = BilinearAttention(decoder_hidden_size, input_dim)
             # The output of attention, a weighted average over encoder outputs, will be
             # concatenated to the input vector of the decoder at each time step.
             self._decoder_input_dim = input_dim + target_embedding_dim
         elif attention == "none":
+            self._decoder_attention = None
             self._decoder_input_dim = target_embedding_dim
         else:
             raise Exception("attention not implemented {}".format(attention))
 
         self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_hidden_dim)
-        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)
+        # Allow for a bottleneck layer between encoder outputs and distribution over vocab
+        # The bottleneck layer consists of a linear transform and helps to reduce number of parameters
+        if self._output_proj_input_dim != self._decoder_output_dim:
+            self._projection_bottleneck = Linear(self._decoder_output_dim, self._output_proj_input_dim)
+        else:
+            self._projection_bottleneck = lambda x: x
+        self._output_projection_layer = Linear(self._output_proj_input_dim, num_classes)
         self._dropout = torch.nn.Dropout(p=dropout)
 
     def _initalize_hidden_context_states(self, encoder_outputs, encoder_outputs_mask):
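The bottleneck introduced above trades one large vocabulary projection for two smaller linear layers; the encoder-output projection follows the same pattern (a `Linear` when the sizes differ, identity otherwise). A rough parameter-count comparison, as a minimal sketch with made-up sizes (none of these numbers come from this diff):

```python
from torch.nn import Linear

decoder_hidden_dim = 1024     # hypothetical decoder output size
output_proj_input_dim = 512   # hypothetical bottleneck size
num_classes = 20000           # hypothetical target vocabulary size

direct = Linear(decoder_hidden_dim, num_classes)
bottleneck = Linear(decoder_hidden_dim, output_proj_input_dim)
projection = Linear(output_proj_input_dim, num_classes)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(direct))                          # 20500000
print(count(bottleneck) + count(projection))  # 10784800, roughly half
```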
@@ -80,10 +97,9 @@ def _initalize_hidden_context_states(self, encoder_outputs, encoder_outputs_mask
         encoder_outputs: torch.FloatTensor, [bs, T, h]
         encoder_outputs_mask: torch.LongTensor, [bs, T, 1]
         """
-        # very important - feel free to check it a third time
-        # idempotent / safe to run in place. encoder_outputs_mask should never
-        # change
-        if hasattr(self, "_decoder_attention") and self._decoder_attention:
+
+        if self._decoder_attention is not None:
+            encoder_outputs = self._projection_encoder_out(encoder_outputs)
             encoder_outputs.data.masked_fill_(1 - encoder_outputs_mask.byte().data, -float('inf'))
 
         decoder_hidden = encoder_outputs.new_zeros(encoder_outputs_mask.size(0), self._decoder_hidden_dim)
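For reference, a small standalone sketch of the in-place masking applied above, with hypothetical shapes and a boolean mask (equivalent to the `1 - mask.byte()` form and also accepted by current PyTorch):

```python
import torch

bs, T, h = 2, 4, 3                       # hypothetical batch, length, hidden size
encoder_outputs = torch.randn(bs, T, h)  # [bs, T, h]
encoder_outputs_mask = torch.tensor([[1, 1, 1, 0],
                                     [1, 1, 0, 0]]).unsqueeze(-1)  # [bs, T, 1]

# Positions where the mask is 0 are overwritten in place with -inf.
encoder_outputs.masked_fill_(encoder_outputs_mask == 0, -float('inf'))
print(encoder_outputs[0, -1])  # tensor([-inf, -inf, -inf])
```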
@@ -132,8 +148,10 @@ def forward(self, # type: ignore
             decoder_hidden, decoder_context = self._decoder_cell(
                 decoder_input, (decoder_hidden, decoder_context))
 
+            # output projection
+            proj_input = self._projection_bottleneck(decoder_hidden)
             # (batch_size, num_classes)
-            output_projections = self._output_projection_layer(decoder_hidden)
+            output_projections = self._output_projection_layer(proj_input)
 
             # list of (batch_size, 1, num_classes)
             step_logit = output_projections.unsqueeze(1)
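Tracing one decode step through the new projection path, as a self-contained sketch in plain PyTorch; all sizes are hypothetical and the layers here are stand-ins for the module's own attributes:

```python
import torch
from torch.nn import Linear, LSTMCell

bs = 8
decoder_input_dim, decoder_hidden_dim = 300, 1024   # hypothetical
output_proj_input_dim, num_classes = 512, 20000     # hypothetical

decoder_cell = LSTMCell(decoder_input_dim, decoder_hidden_dim)
projection_bottleneck = Linear(decoder_hidden_dim, output_proj_input_dim)
output_projection_layer = Linear(output_proj_input_dim, num_classes)

decoder_input = torch.randn(bs, decoder_input_dim)
decoder_hidden = torch.zeros(bs, decoder_hidden_dim)
decoder_context = torch.zeros(bs, decoder_hidden_dim)

# one step of the cell, then bottleneck, then projection over the vocabulary
decoder_hidden, decoder_context = decoder_cell(
    decoder_input, (decoder_hidden, decoder_context))
proj_input = projection_bottleneck(decoder_hidden)         # (batch_size, 512)
output_projections = output_projection_layer(proj_input)   # (batch_size, num_classes)
step_logit = output_projections.unsqueeze(1)               # (batch_size, 1, num_classes)
print(step_logit.shape)  # torch.Size([8, 1, 20000])
```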
@@ -204,7 +222,7 @@ def _prepare_decode_step_input(
         # (batch_size, target_embedding_dim)
         embedded_input = self._target_embedder(input_indices)
 
-        if hasattr(self, "_decoder_attention") and self._decoder_attention:
+        if self._decoder_attention is not None:
             # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
             # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
             # complain.
@@ -221,9 +239,9 @@ def _prepare_decode_step_input(
             # (batch_size, input_sequence_length)
             input_weights = self._decoder_attention(
                 decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
-            # (batch_size, encoder_output_dim)
+            # (batch_size, input_dim)
             attended_input = weighted_sum(encoder_outputs, input_weights)
-            # (batch_size, encoder_output_dim + target_embedding_dim)
+            # (batch_size, input_dim + target_embedding_dim)
             return torch.cat((attended_input, embedded_input), -1)
         else:
             return embedded_input
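A shape walk-through of the attended input built above, using `torch.bmm` as a plain-PyTorch stand-in for AllenNLP's `weighted_sum` in the 2D-weights case; the dimensions are illustrative only:

```python
import torch

bs, T = 2, 5
input_dim, target_embedding_dim = 1024, 300                # hypothetical
encoder_outputs = torch.randn(bs, T, input_dim)
input_weights = torch.softmax(torch.randn(bs, T), dim=-1)  # (batch_size, T)
embedded_input = torch.randn(bs, target_embedding_dim)

# (batch_size, input_dim): weighted average over encoder outputs
attended_input = torch.bmm(input_weights.unsqueeze(1), encoder_outputs).squeeze(1)
# (batch_size, input_dim + target_embedding_dim): decoder input for this time step
decoder_input = torch.cat((attended_input, embedded_input), -1)
print(decoder_input.shape)  # torch.Size([2, 1324])
```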
@@ -259,20 +277,3 @@ def _get_loss(logits: torch.LongTensor,
         relevant_mask = target_mask[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
         loss = sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
         return loss
-
-    @classmethod
-    def from_params(cls, vocab, params: Params) -> 'SimpleSeq2Seq':
-        input_dim = params.pop("input_dim")
-        max_decoding_steps = params.pop("max_decoding_steps")
-        target_namespace = params.pop("target_namespace", "targets")
-        target_embedding_dim = params.pop("target_embedding_dim")
-        attention = params.pop("attention", "none")
-        dropout = params.pop_float("dropout", 0.0)
-        params.assert_empty(cls.__name__)
-        return cls(vocab,
-                   input_dim=input_dim,
-                   target_embedding_dim=target_embedding_dim,
-                   max_decoding_steps=max_decoding_steps,
-                   target_namespace=target_namespace,
-                   attention=attention,
-                   dropout=dropout)