
Commit 127d6e5

bug fix
1 parent 09cddcb commit 127d6e5

6 files changed: +296 -146 lines changed


NetModel/QANet_keras.py

+23 -47
@@ -4,7 +4,7 @@
 from KerasLayer.context2query_attention import context2query_attention
 from KerasLayer.multihead_attention import Attention as MultiHeadAttention
 from KerasLayer.position_embedding import Position_Embedding as PositionEmbedding
-from keras import layers
+from KerasLayer.QAoutputBlock import QAoutputBlock
 from keras.optimizers import *
 from keras.callbacks import *
 from KerasLayer.layer_dropout import LayerDropout
@@ -63,16 +63,8 @@ def feed_forward_block(FeedForward_layers, x, dropout=0.0, l=1., L=1.):
     x = LayerDropout(dropout * (l / L))([x, residual])
     return x
 
-def output_block(x1, x2, ans_limit=50):
-    outer = tf.matmul(tf.expand_dims(x1, axis=2), tf.expand_dims(x2, axis=1))
-    outer = tf.matrix_band_part(outer, 0, ans_limit)
-    output1 = tf.argmax(tf.reduce_max(outer, axis=2), axis=1)
-    output2 = tf.argmax(tf.reduce_max(outer, axis=1), axis=1)
-    return [output1, output2]
-
-
 def QANet(word_dim=300, char_dim=64, cont_limit=400, ques_limit=50, char_limit=16, word_mat=None, char_mat=None,
-          char_input_size=1000, filters=128, num_head=8, dropout=0.1, train=True, ans_limit=30):
+          char_input_size=1000, filters=128, num_head=8, dropout=0.1, ans_limit=30):
     # Input Embedding Layer
     contw_input = Input((cont_limit,))
     quesw_input = Input((ques_limit,))
@@ -85,17 +77,6 @@ def QANet(word_dim=300, char_dim=64, cont_limit=400, ques_limit=50, char_limit=1
     cont_len = Lambda(lambda x: tf.expand_dims(tf.reduce_sum(tf.cast(x, tf.int32), axis=1), axis=1))(c_mask)
     ques_len = Lambda(lambda x: tf.expand_dims(tf.reduce_sum(tf.cast(x, tf.int32), axis=1), axis=1))(q_mask)
 
-    # # slice
-    # c_maxlen = tf.reduce_max(cont_len)
-    # q_maxlen = tf.reduce_max(ques_len)
-    # contw_input_ = Lambda(lambda x:tf.slice(x, [0, 0], [-1, c_maxlen]))(contw_input)
-    # quesw_input_ = Lambda(lambda x:tf.slice(x, [0, 0], [-1, q_maxlen]))(quesw_input)
-    # c_mask_ = Lambda(lambda x:tf.slice(x, [0, 0], [-1, c_maxlen]))(c_mask)
-    # q_mask_ = Lambda(lambda x:tf.slice(x, [0, 0], [-1, q_maxlen]))(q_mask)
-    # contc_input_ = Lambda(lambda x:tf.slice(x, [0, 0, 0], [-1,c_maxlen, char_limit]))(contc_input)
-    # quesc_input_ = Lambda(lambda x:tf.slice(x, [0, 0, 0], [-1,q_maxlen, char_limit]))(quesc_input)
-    # print(contw_input_,quesw_input_)
-
     # embedding word
     WordEmbedding = Embedding(word_mat.shape[0], word_dim, weights=[word_mat], mask_zero=False, trainable=False)
     xw_cont = WordEmbedding(contw_input)
@@ -203,21 +184,16 @@ def QANet(word_dim=300, char_dim=64, cont_limit=400, ques_limit=50, char_limit=1
     x_end = Lambda(lambda x: mask_logits(x[0], x[1], axis=0, time_dim=1))([x_end, cont_len])
     x_end = Lambda(lambda x: K.softmax(x), name='end')(x_end)
 
-    if train:
-        return Model(inputs=[contw_input, quesw_input, contc_input, quesc_input],
-                     outputs=[x_start, x_end])
-    else:
-        x_final = Lambda(lambda x: output_block(x[0], x[1], ans_limit))([x_start, x_end])
-        return Model(inputs=[contw_input, quesw_input, contc_input, quesc_input],
-                     outputs=x_final)
+    x_start_fin, x_end_fin = QAoutputBlock(ans_limit)([x_start,x_end])
+    return Model(inputs=[contw_input, quesw_input, contc_input, quesc_input], outputs=[x_start, x_end, x_start_fin, x_end_fin])
 
 embedding_matrix = np.random.random((10000,300))
 embedding_matrix_char = np.random.random((1000,64))
 model=QANet(word_mat=embedding_matrix,char_mat=embedding_matrix_char)
-model.summary()
+# model.summary()
 
-# optimizer=Adam(lr=0.001,beta_1=0.8,beta_2=0.999,epsilon=1e-7)
-# model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['acc'])
+optimizer=Adam(lr=0.001,beta_1=0.8,beta_2=0.999,epsilon=1e-7)
+model.compile(optimizer=optimizer, loss=['categorical_crossentropy','categorical_crossentropy','mae','mae'], loss_weights=[1, 1, 0, 0])
 #
 # # call backs
 # class LRSetting(Callback):
@@ -228,19 +204,19 @@ def QANet(word_dim=300, char_dim=64, cont_limit=400, ques_limit=50, char_limit=1
 # check_point = ModelCheckpoint('model/QANetv02.h5', monitor='val_loss', verbose=0, save_best_only=True,save_weights_only=True, mode='auto', period=1)
 # early_stop = EarlyStopping(monitor='val_loss', patience=10, verbose=1, mode='auto')
 #
-# # load data
-# char_dim=200
-# cont_limit=400
-# ques_limit=50
-# char_limit=16
-#
-# context_word = np.random.randint(0, 10000, (300, cont_limit))
-# question_word = np.random.randint(0, 10000, (300, ques_limit))
-# context_char = np.random.randint(0, 96, (300, cont_limit, char_limit))
-# question_char = np.random.randint(0, 96, (300, ques_limit, char_limit))
-# context_length = np.random.randint(5, cont_limit, (300, 1))
-# question_length = np.random.randint(5, ques_limit, (300, 1))
-# start_label = np.random.randint(0, 2, (300, cont_limit))
-# end_label = np.random.randint(0, 2, (300, cont_limit))
-#
-# model.fit([context_word,question_word,context_char,question_char,context_length,question_length],[start_label,end_label],batch_size=8)
+# load data
+char_dim=64
+cont_limit=400
+ques_limit=50
+char_limit=16
+
+context_word = np.random.randint(0, 10000, (300, cont_limit))
+question_word = np.random.randint(0, 10000, (300, ques_limit))
+context_char = np.random.randint(0, 96, (300, cont_limit, char_limit))
+question_char = np.random.randint(0, 96, (300, ques_limit, char_limit))
+start_label = np.random.randint(0, 2, (300, cont_limit))
+end_label = np.random.randint(0, 2, (300, cont_limit))
+start_label_fin = np.argmax(start_label,axis=-1)
+end_label_fin = np.argmax(end_label,axis=-1)
+
+model.fit([context_word,question_word,context_char,question_char],[start_label, end_label, start_label_fin, end_label_fin],batch_size=8)
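
Note: the new KerasLayer/QAoutputBlock.py is one of the other five changed files and is not shown in this excerpt. Below is a minimal sketch of what such a layer could look like, assuming it simply wraps the logic of the removed output_block() in a custom Keras layer; the class layout, method names, and the Keras 2.x / TF 1.x API calls are assumptions, not the committed code.

# Sketch only: reconstructs the removed output_block() as a custom Keras layer.
import tensorflow as tf
from keras.layers import Layer

class QAoutputBlock(Layer):
    def __init__(self, ans_limit=30, **kwargs):
        self.ans_limit = ans_limit
        super(QAoutputBlock, self).__init__(**kwargs)

    def call(self, x, mask=None):
        x1, x2 = x  # start / end probability distributions, each (batch, cont_limit)
        # Outer product of start and end probabilities: (batch, cont_limit, cont_limit)
        outer = tf.matmul(tf.expand_dims(x1, axis=2), tf.expand_dims(x2, axis=1))
        # Keep only spans with end >= start and span length <= ans_limit
        outer = tf.matrix_band_part(outer, 0, self.ans_limit)
        # Indices of the highest-scoring valid span
        start_pos = tf.argmax(tf.reduce_max(outer, axis=2), axis=1)
        end_pos = tf.argmax(tf.reduce_max(outer, axis=1), axis=1)
        return [start_pos, end_pos]

    def compute_output_shape(self, input_shape):
        return [(input_shape[0][0],), (input_shape[0][0],)]

This also matches the compile call in the diff: categorical cross-entropy trains the start and end distributions, while the zero-weighted 'mae' entries merely register the two argmax outputs so span predictions come out of the same model without contributing to the training loss.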
