diff --git a/bert4keras/layers.py b/bert4keras/layers.py index 835229d7..e2b95cb5 100644 --- a/bert4keras/layers.py +++ b/bert4keras/layers.py @@ -3,7 +3,7 @@ import numpy as np import tensorflow as tf -from bert4keras.backend import keras, K +from bert4keras.backend import keras, K, is_tf_keras from bert4keras.backend import sequence_masking from bert4keras.backend import recompute_grad from keras import initializers, activations @@ -28,7 +28,7 @@ def new_func(self, input_shape): return new_func -if keras.__version__[-2:] != 'tf' and keras.__version__ < '2.3': +if (not is_tf_keras) and keras.__version__ < '2.3': class Layer(keras.layers.Layer): """重新定义Layer,赋予“层中层”功能 @@ -91,6 +91,34 @@ def __init__(self, **kwargs): self.supports_masking = True # 本项目的自定义层均可mask +if (not is_tf_keras) or tf.__version__ < '1.15': + + if not is_tf_keras: + NodeBase = keras.engine.base_layer.Node + else: + from tensorflow.python.keras.engine import base_layer + NodeBase = base_layer.Node + + class Node(NodeBase): + """修改Node来修复keras下孪生网络的bug + 注意:这是keras的bug,并不是bert4keras的bug,但keras已经不更新了, + 所以只好在这里进行修改。tf 1.15+自带的keras已经修改了这个 + bug。 + """ + @property + def arguments(self): + return self._arguments.copy() + + @arguments.setter + def arguments(self, value): + self._arguments = value + + if not is_tf_keras: + keras.engine.base_layer.Node = Node + else: + base_layer.Node = Node + + class Embedding(keras.layers.Embedding): """拓展Embedding层 """