Skip to content

Commit

Permalink
修复keras下搭建孪生网络可能存在的bug
Browse files Browse the repository at this point in the history
这是keras的bug!
  • Loading branch information
bojone authored Dec 15, 2020
1 parent 6fa4354 commit b0fe7a7
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions bert4keras/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import tensorflow as tf
from bert4keras.backend import keras, K
from bert4keras.backend import keras, K, is_tf_keras
from bert4keras.backend import sequence_masking
from bert4keras.backend import recompute_grad
from keras import initializers, activations
Expand All @@ -28,7 +28,7 @@ def new_func(self, input_shape):
return new_func


if keras.__version__[-2:] != 'tf' and keras.__version__ < '2.3':
if (not is_tf_keras) and keras.__version__ < '2.3':

class Layer(keras.layers.Layer):
"""重新定义Layer,赋予“层中层”功能
Expand Down Expand Up @@ -91,6 +91,34 @@ def __init__(self, **kwargs):
self.supports_masking = True # 本项目的自定义层均可mask


if (not is_tf_keras) or tf.__version__ < '1.15':

if not is_tf_keras:
NodeBase = keras.engine.base_layer.Node
else:
from tensorflow.python.keras.engine import base_layer
NodeBase = base_layer.Node

class Node(NodeBase):
"""修改Node来修复keras下孪生网络的bug
注意:这是keras的bug,并不是bert4keras的bug,但keras已经不更新了,
所以只好在这里进行修改。tf 1.15+自带的keras已经修改了这个
bug。
"""
@property
def arguments(self):
return self._arguments.copy()

@arguments.setter
def arguments(self, value):
self._arguments = value

if not is_tf_keras:
keras.engine.base_layer.Node = Node
else:
base_layer.Node = Node


class Embedding(keras.layers.Embedding):
"""拓展Embedding层
"""
Expand Down

0 comments on commit b0fe7a7

Please sign in to comment.