Skip to content

Commit 0fdc147

Browse files
Fix bugs on saving model.
1 parent 574a95a commit 0fdc147

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11

22
# PVT-tensorflow2
3+
[![Python 3.7](https://img.shields.io/badge/Python-3.7-3776AB)](https://www.python.org/downloads/release/python-360/)
4+
[![TensorFlow 2.4](https://img.shields.io/badge/TensorFlow-2.4-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.2.0)
5+
36
A Tensorflow2.x implementation of Pyramid Vision Transformer as described in [Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions](https://arxiv.org/abs/2102.12122)
47

58
## Update Log
9+
[2021-06-29]
10+
* Fix bug on saving model
11+
612
[2021-03-20]
713
* Add PVT-tiny,PVT-small,PVT-medium,PVT-large.
814

model/PVT.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,14 @@ def build(self, input_shape):
101101
def call(self, x):
102102
return x+self.pos_embed
103103

104+
def get_config(self):
105+
106+
config = super().get_config().copy()
107+
config.update({
108+
'img_len': self.img_len,
109+
})
110+
return config
111+
104112
def get_pvt(img_size,num_classes,block_depth,mlp_ratio,drop_path_rate,first_level_patch_size,embed_dims,num_heads,sr_ratio,attention_drop_rate,drop_rate):
105113
block_drop_path_rate = np.linspace(0, drop_path_rate, sum(block_depth))
106114
block_depth_index = 0

train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,18 @@ def main(args):
4949
os.makedirs(args.checkpoints)
5050
# lr_cb = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=20, verbose=1, min_lr=0)
5151
lr_cb = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
52-
model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(filepath=args.checkpoints+'/best_weight_{epoch}_{accuracy:.3f}_{val_accuracy:.3f}',
53-
monitor='val_accuracy',mode='max',
54-
verbose=1,save_best_only=True)
52+
model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(filepath=args.checkpoints+'/best_weight_{epoch}_{accuracy:.3f}_{val_accuracy:.3f}.h5',
53+
monitor='val_accuracy',mode='max',
54+
verbose=1,save_best_only=True,save_weights_only=True)
5555
cbs=[lr_cb,
56-
# model_checkpoint_cb
56+
model_checkpoint_cb
5757
]
5858
model.compile(optimizer,loss_object,metrics=["accuracy"],)
5959
model.fit(train_generator,
6060
validation_data=val_generator,
6161
epochs=args.epochs,
6262
callbacks=cbs,
63-
verbose=2,
63+
verbose=1,
6464
)
6565

6666
if __name__== "__main__":

0 commit comments

Comments
 (0)