from KerasLayer.context2query_attention import context2query_attention
from KerasLayer.multihead_attention import Attention as MultiHeadAttention
from KerasLayer.position_embedding import Position_Embedding as PositionEmbedding
- from keras import layers
+ from KerasLayer.QAoutputBlock import QAoutputBlock
from keras.optimizers import *
from keras.callbacks import *
from KerasLayer.layer_dropout import LayerDropout
@@ -63,16 +63,8 @@ def feed_forward_block(FeedForward_layers, x, dropout=0.0, l=1., L=1.):
    x = LayerDropout(dropout * (l / L))([x, residual])
    return x

- def output_block(x1, x2, ans_limit=50):
-     outer = tf.matmul(tf.expand_dims(x1, axis=2), tf.expand_dims(x2, axis=1))
-     outer = tf.matrix_band_part(outer, 0, ans_limit)
-     output1 = tf.argmax(tf.reduce_max(outer, axis=2), axis=1)
-     output2 = tf.argmax(tf.reduce_max(outer, axis=1), axis=1)
-     return [output1, output2]
-
-
def QANet(word_dim=300, char_dim=64, cont_limit=400, ques_limit=50, char_limit=16, word_mat=None, char_mat=None,
-           char_input_size=1000, filters=128, num_head=8, dropout=0.1, train=True, ans_limit=30):
+           char_input_size=1000, filters=128, num_head=8, dropout=0.1, ans_limit=30):
    # Input Embedding Layer
    contw_input = Input((cont_limit,))
    quesw_input = Input((ques_limit,))
@@ -85,17 +77,6 @@ def QANet(word_dim=300, char_dim=64, cont_limit=400, ques_limit=50, char_limit=1
    cont_len = Lambda(lambda x: tf.expand_dims(tf.reduce_sum(tf.cast(x, tf.int32), axis=1), axis=1))(c_mask)
    ques_len = Lambda(lambda x: tf.expand_dims(tf.reduce_sum(tf.cast(x, tf.int32), axis=1), axis=1))(q_mask)

-     # # slice
-     # c_maxlen = tf.reduce_max(cont_len)
-     # q_maxlen = tf.reduce_max(ques_len)
-     # contw_input_ = Lambda(lambda x: tf.slice(x, [0, 0], [-1, c_maxlen]))(contw_input)
-     # quesw_input_ = Lambda(lambda x: tf.slice(x, [0, 0], [-1, q_maxlen]))(quesw_input)
-     # c_mask_ = Lambda(lambda x: tf.slice(x, [0, 0], [-1, c_maxlen]))(c_mask)
-     # q_mask_ = Lambda(lambda x: tf.slice(x, [0, 0], [-1, q_maxlen]))(q_mask)
-     # contc_input_ = Lambda(lambda x: tf.slice(x, [0, 0, 0], [-1, c_maxlen, char_limit]))(contc_input)
-     # quesc_input_ = Lambda(lambda x: tf.slice(x, [0, 0, 0], [-1, q_maxlen, char_limit]))(quesc_input)
-     # print(contw_input_, quesw_input_)
-
    # embedding word
    WordEmbedding = Embedding(word_mat.shape[0], word_dim, weights=[word_mat], mask_zero=False, trainable=False)
    xw_cont = WordEmbedding(contw_input)
@@ -203,21 +184,16 @@ def QANet(word_dim=300, char_dim=64, cont_limit=400, ques_limit=50, char_limit=1
    x_end = Lambda(lambda x: mask_logits(x[0], x[1], axis=0, time_dim=1))([x_end, cont_len])
    x_end = Lambda(lambda x: K.softmax(x), name='end')(x_end)

-     if train:
-         return Model(inputs=[contw_input, quesw_input, contc_input, quesc_input],
-                      outputs=[x_start, x_end])
-     else:
-         x_final = Lambda(lambda x: output_block(x[0], x[1], ans_limit))([x_start, x_end])
-         return Model(inputs=[contw_input, quesw_input, contc_input, quesc_input],
-                      outputs=x_final)
+     x_start_fin, x_end_fin = QAoutputBlock(ans_limit)([x_start, x_end])
+     return Model(inputs=[contw_input, quesw_input, contc_input, quesc_input], outputs=[x_start, x_end, x_start_fin, x_end_fin])
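# QAoutputBlock(ans_limit) presumably implements the same span decoding as the removed
# output_block above (outer product of the start/end distributions, band-limited to ans_limit,
# then argmax over each axis), wrapped as a Keras layer so the model always exposes both the
# per-position softmax outputs and the decoded start/end indices.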

embedding_matrix = np.random.random((10000, 300))
embedding_matrix_char = np.random.random((1000, 64))
model = QANet(word_mat=embedding_matrix, char_mat=embedding_matrix_char)
- model.summary()
+ # model.summary()

- # optimizer=Adam(lr=0.001,beta_1=0.8,beta_2=0.999,epsilon=1e-7)
- # model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['acc'])
+ optimizer = Adam(lr=0.001, beta_1=0.8, beta_2=0.999, epsilon=1e-7)
+ model.compile(optimizer=optimizer, loss=['categorical_crossentropy', 'categorical_crossentropy', 'mae', 'mae'], loss_weights=[1, 1, 0, 0])
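# The two softmax heads are trained with categorical cross-entropy; the decoded index outputs
# get an 'mae' loss with weight 0, so they contribute nothing to the gradients and are
# effectively prediction-only outputs.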
#
# # call backs
# class LRSetting(Callback):
@@ -228,19 +204,19 @@ def QANet(word_dim=300, char_dim=64, cont_limit=400, ques_limit=50, char_limit=1
# check_point = ModelCheckpoint('model/QANetv02.h5', monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=True, mode='auto', period=1)
# early_stop = EarlyStopping(monitor='val_loss', patience=10, verbose=1, mode='auto')
#
- # # load data
- # char_dim=200
- # cont_limit=400
- # ques_limit=50
- # char_limit=16
- #
- # context_word = np.random.randint(0, 10000, (300, cont_limit))
- # question_word = np.random.randint(0, 10000, (300, ques_limit))
- # context_char = np.random.randint(0, 96, (300, cont_limit, char_limit))
- # question_char = np.random.randint(0, 96, (300, ques_limit, char_limit))
- # context_length = np.random.randint(5, cont_limit, (300, 1))
- # question_length = np.random.randint(5, ques_limit, (300, 1))
- # start_label = np.random.randint(0, 2, (300, cont_limit))
- # end_label = np.random.randint(0, 2, (300, cont_limit))
- #
- # model.fit([context_word,question_word,context_char,question_char,context_length,question_length],[start_label,end_label],batch_size=8)
+ # load data
+ char_dim = 64
+ cont_limit = 400
+ ques_limit = 50
+ char_limit = 16
+
+ context_word = np.random.randint(0, 10000, (300, cont_limit))
+ question_word = np.random.randint(0, 10000, (300, ques_limit))
+ context_char = np.random.randint(0, 96, (300, cont_limit, char_limit))
+ question_char = np.random.randint(0, 96, (300, ques_limit, char_limit))
+ start_label = np.random.randint(0, 2, (300, cont_limit))
+ end_label = np.random.randint(0, 2, (300, cont_limit))
+ start_label_fin = np.argmax(start_label, axis=-1)
+ end_label_fin = np.argmax(end_label, axis=-1)
+
+ model.fit([context_word, question_word, context_char, question_char], [start_label, end_label, start_label_fin, end_label_fin], batch_size=8)
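# Dummy data for a smoke test: start_label/end_label are random 0/1 vectors over cont_limit
# (standing in for one-hot span labels), and the *_fin targets are their argmax positions,
# matching the index outputs of QAoutputBlock.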