Skip to content

Commit 09cddcb

Browse files
committed
add layer_dropout
1 parent 22b8836 commit 09cddcb

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

KerasLayer/QAoutputBlock.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# ! -*- coding: utf-8 -*-
2+
from keras.engine.topology import Layer
3+
from keras.regularizers import *
4+
import tensorflow as tf
5+
import keras.backend as K
6+
7+
class QAoutputBlock(Layer):
8+
def __init__(self, ans_limit=30, **kwargs):
9+
self.ans_limit = ans_limit
10+
super(QAoutputBlock, self).__init__(**kwargs)
11+
12+
def build(self, input_shape):
13+
super(QAoutputBlock, self).build(input_shape)
14+
15+
def call(self, x, mask=None):
16+
x1 ,x2 = x
17+
outer = tf.matmul(tf.expand_dims(x1, axis=2), tf.expand_dims(x2, axis=1))
18+
outer = tf.matrix_band_part(outer, 0, self.ans_limit)
19+
output1 = tf.reshape(tf.cast(tf.argmax(tf.reduce_max(outer, axis=2), axis=1), tf.float32),(-1,1))
20+
output2 = tf.reshape(tf.cast(tf.argmax(tf.reduce_max(outer, axis=1), axis=1), tf.float32),(-1,1))
21+
22+
return [output1, output2]
23+
24+
def compute_output_shape(self, input_shape):
25+
return [(input_shape[0][0],1), (input_shape[0][0],1)]

KerasLayer/layer_dropout.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# ! -*- coding: utf-8 -*-
2+
from keras.engine.topology import Layer
3+
import tensorflow as tf
4+
import keras.backend as K
5+
6+
class LayerDropout(Layer):
7+
def __init__(self, dropout = 0.0, **kwargs):
8+
self.dropout = dropout
9+
super(LayerDropout, self).__init__(**kwargs)
10+
11+
def build(self, input_shape):
12+
super(LayerDropout, self).build(input_shape)
13+
14+
def call(self, x, mask=None, training=None):
15+
x, residual = x
16+
pred = tf.random_uniform([]) < self.dropout
17+
x_train = tf.cond(pred, lambda: residual, lambda: tf.nn.dropout(x, 1.0 - self.dropout) + residual)
18+
x_test = x + residual
19+
return K.in_train_phase(x_train, x_test, training=training)
20+
21+
def compute_output_shape(self, input_shape):
22+
return input_shape

0 commit comments

Comments
 (0)