From 9f6dd86900a1828dabb3186114867d294cec149a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=8B=8F=E5=89=91=E6=9E=97=28Jianlin=20Su=29?=
Date: Fri, 17 Jul 2020 21:19:35 +0800
Subject: [PATCH] Update task_iflytek_bert_of_theseus.py

---
 examples/task_iflytek_bert_of_theseus.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/examples/task_iflytek_bert_of_theseus.py b/examples/task_iflytek_bert_of_theseus.py
index 08619877..6e290004 100644
--- a/examples/task_iflytek_bert_of_theseus.py
+++ b/examples/task_iflytek_bert_of_theseus.py
@@ -149,7 +149,8 @@ def on_epoch_end(self, epoch, logs=None):
 predecessor = build_transformer_model(
     config_path=config_path,
     checkpoint_path=checkpoint_path,
-    return_keras_model=False
+    return_keras_model=False,
+    prefix='Predecessor-'
 )
 
 # 加载预训练模型(3层)
@@ -157,12 +158,10 @@ def on_epoch_end(self, epoch, logs=None):
     config_path=config_path,
     checkpoint_path=checkpoint_path,
     return_keras_model=False,
-    num_hidden_layers=3
+    num_hidden_layers=3,
+    prefix='Successor-'
 )
 
-for layer in successor.model.layers:
-    layer.name = 'Successor-' + layer.name
-
 # 判别模型
 x_in = Input(shape=K.int_shape(predecessor.output)[1:])
 x = Lambda(lambda x: x[:, 0])(x_in)
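
Note: the patch drops the post-hoc renaming loop over successor.model.layers and instead passes a name prefix when each model is built, so the predecessor (12-layer) and successor (3-layer) layers get distinct names up front. Below is a minimal sketch of the resulting pattern, not the full example script; it assumes a bert4keras version whose build_transformer_model accepts the prefix argument (which this patch relies on), and the config/checkpoint paths are placeholders.

from bert4keras.models import build_transformer_model

config_path = 'bert_config.json'      # placeholder path
checkpoint_path = 'bert_model.ckpt'   # placeholder path

# Full-depth predecessor; every layer name is prefixed with 'Predecessor-'
predecessor = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False,
    prefix='Predecessor-'
)

# 3-layer successor; every layer name is prefixed with 'Successor-'
successor = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False,
    num_hidden_layers=3,
    prefix='Successor-'
)

# With distinct prefixes the two models' layer names cannot collide, which is
# what lets the BERT-of-Theseus setup keep both models in one graph while
# swapping successor modules in for predecessor modules.
assert not ({l.name for l in predecessor.model.layers} &
            {l.name for l in successor.model.layers})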