A simplified PyTorch implementation of image classification. It supports multi-GPU training and validation, automatic mixed precision training, knowledge distillation, hyperparameter optimization with Optuna, and several datasets such as CIFAR10 and MNIST.
torch == 1.8.1
torchvision
torchmetrics == 1.2.0
albumentations
loguru
tqdm
timm == 0.6.12 (optional)
optuna == 4.0.0 (optional)
optuna-integration == 4.0.0 (optional)
- ResNets [1]
- MobileNetV2 [2]
- timm [3]
This repo provides modified ResNets and MobileNetV2 for training on datasets of small-resolution images, e.g. CIFAR10 (32x32) or MNIST (28x28). You can also train on datasets of normal-size images such as ImageNet. Besides ResNets and MobileNetV2, you may also use timm [3], which provides hundreds of pretrained models. For example, to train mobilenetv3_small from timm, change the config file to
config.model = 'timm'
config.timm_model = 'mobilenetv3_small_100'
or use command-line arguments
python main.py --model timm --timm_model mobilenetv3_small_100
Details of the configurations can also be found in this file.
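To illustrate how the two ways of selecting a model relate, here is a minimal sketch of command-line arguments overriding config defaults. The `Config` class, its default value, and the `apply_cli_overrides` helper are hypothetical and are not taken from this repo; only the `model` and `timm_model` names come from the text above.

```python
import argparse


class Config:
    """Hypothetical stand-in for the repo's config object."""

    def __init__(self):
        self.model = 'resnet18'   # assumed default, for illustration only
        self.timm_model = None


def apply_cli_overrides(config, argv=None):
    """Overwrite config fields with any command-line values that were given."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=None)
    parser.add_argument('--timm_model', type=str, default=None)
    args = parser.parse_args(argv)

    for name, value in vars(args).items():
        if value is not None:  # keep the config default when a flag is absent
            setattr(config, name, value)
    return config


config = apply_cli_overrides(
    Config(), ['--model', 'timm', '--timm_model', 'mobilenetv3_small_100'])
print(config.model, config.timm_model)
```

With this pattern, a flag that is not passed leaves the value from the config file untouched, so the config file acts as the set of defaults.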
Most timm models downsample the input by a factor of 32. To retain more detail and get better performance on datasets of small-resolution images, you may need to reduce the downsampling rate of the timm model.
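The arithmetic behind this advice can be sketched as follows. The helper `feature_map_size` is hypothetical, and it assumes each stride-2 stage rounds the spatial size up (as a 3x3 convolution with padding 1 does); the total stride of 8 is only an example of a reduced rate, not a value used by this repo.

```python
def feature_map_size(input_size, total_stride):
    """Spatial size of the last feature map for a given total stride.

    Assumes every stride-2 stage rounds up, so the result is
    ceil(input_size / total_stride).
    """
    return -(-input_size // total_stride)  # ceiling division on integers


for name, size in [('CIFAR10', 32), ('MNIST', 28), ('ImageNet', 224)]:
    at_32 = feature_map_size(size, 32)
    at_8 = feature_map_size(size, 8)   # hypothetical reduced downsampling
    print(f'{name}: input {size}, stride 32 -> {at_32}x{at_32}, '
          f'stride 8 -> {at_8}x{at_8}')
```

Under these assumptions a 32x32 or 28x28 input collapses to a single 1x1 location at a total stride of 32, while a 224x224 input still keeps a 7x7 map, which is why small-resolution datasets tend to benefit from less downsampling.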
If you want to test other datasets from torchvision, you may refer to this site. Note that this site is outdated, since the torchvision version (0.9.1) is tied to torch (1.8.1). To use datasets from a newer version of torchvision, you need to update this codebase to be compatible with newer torch. If you don't want to update the codebase, you can also download the image files and build your own dataset following the style of the Custom dataset.
Currently, only the original knowledge distillation method proposed by Geoffrey Hinton [7] is supported.
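For readers unfamiliar with the method, the following is a minimal plain-Python sketch of a Hinton-style distillation loss for a single sample. It is not this repo's implementation: the function names, the `temperature` and `alpha` defaults, and the way the two terms are weighted are all assumptions for illustration. The soft term is written as a KL divergence, which differs from soft-target cross-entropy only by a constant that does not depend on the student.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over temperature-scaled logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, label, temperature=4.0, alpha=0.5):
    """alpha * T^2 * KL(teacher || student) + (1 - alpha) * cross-entropy."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)

    # Soft-target term: match the teacher's softened distribution.
    soft = sum(pt * math.log(pt / ps)
               for pt, ps in zip(p_teacher, p_student) if pt > 0)
    soft *= temperature ** 2  # keeps gradient scale comparable across T

    # Hard-target term: ordinary cross-entropy with the true label (T = 1).
    hard = -math.log(softmax(student_logits)[label])

    return alpha * soft + (1 - alpha) * hard


print(kd_loss([0.0, 1.0, 2.0], [2.0, 1.0, 0.0], label=0))
```

When the student's logits equal the teacher's, the soft term vanishes and only the cross-entropy term remains.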
This repo provides batch-wise mixup augmentation [8]. You can control the probability of applying mixup through the config.mixup parameter. If you want to perform mixup on individual images, you will need to implement it yourself.
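As a rough illustration of what "batch-wise" means here, the sketch below draws one mixing coefficient and one permutation for the whole batch, and applies them with a given probability. It uses plain Python lists instead of tensors, and the function name, the `alpha` parameter, and the choice to mix one-hot labels directly are assumptions, not this repo's code.

```python
import random


def mixup_batch(images, labels, prob=1.0, alpha=1.0, rng=random):
    """Batch-wise mixup on lists of flat images and one-hot labels.

    With probability `prob`, every sample is blended with another sample
    from the same batch using a single lambda ~ Beta(alpha, alpha).
    """
    if rng.random() >= prob:
        return images, labels  # skip mixup for this batch

    lam = rng.betavariate(alpha, alpha)   # one coefficient for the batch
    perm = list(range(len(images)))
    rng.shuffle(perm)                     # one pairing for the batch

    def blend(a, b):
        return [lam * x + (1 - lam) * y for x, y in zip(a, b)]

    mixed_images = [blend(images[i], images[j]) for i, j in enumerate(perm)]
    mixed_labels = [blend(labels[i], labels[j]) for i, j in enumerate(perm)]
    return mixed_images, mixed_labels


rng = random.Random(0)
images = [[0.0, 0.0], [1.0, 1.0]]
labels = [[1.0, 0.0], [0.0, 1.0]]
print(mixup_batch(images, labels, prob=1.0, rng=rng))
```

A per-image variant would instead sample a separate lambda (and decide whether to mix at all) for each sample inside the loop.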
This repo also supports hyperparameter optimization using Optuna [9]. For example, to search hyperparameters for the CIFAR10 dataset using MobileNetV2, simply run
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 optuna_search.py
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py
Model | pretrained | kd | optuna | mixup | Epoch | Accuracy(%) |
---|---|---|---|---|---|---|
ResNet50 | 1.0 | 200 | 95.99 | |||
ResNet50 | 1.0 | 400 | 96.62 (teacher) | |||
ResNet18 | 1.0 | 200 | 95.34 (base) | |||
ResNet18 | 0.0 | 200 | 94.25 ⬇️ | |||
ResNet18 | 0.5 | 200 | 95.08 ⬇️ | |||
ResNet18 | ✅ | 1.0 | 200 | 95.91 ⬆️ | ||
ResNet18 | ✅ | 1.0 | 400 | 95.95 ⬆️ | ||
ResNet18 | ✅ | 1.0 | 200 | 95.69 ⬆️ | ||
ResNet18 | 1.0 | 400 | 96.03 ⬆️ | |||
ResNet18 | ✅ | 1.0 | 400 | 96.12 ⬆️ | ||
MobileNetv2 | 1.0 | 200 | 94.88 (base) | |||
MobileNetv2 | ✅ | 1.0 | 200 | 95.21 ⬆️ | ||
MobileNetv2 | ✅ | 1.0 | 400 | 95.37 ⬆️ | ||
MobileNetv2 | ✅ | 1.0 | 200 | 94.83 ⬇️ | ||
MobileNetv2 | 1.0 | 400 | 95.29 ⬆️ | |||
MobileNetv2 | ✅ | 1.0 | 400 | 95.12 ⬆️ | ||
MobileNetv2 | - | - | config | - | 100 | 96.39 ⬆️ |
Model | pretrained | kd | optuna | mixup | Epoch | Accuracy(%) |
---|---|---|---|---|---|---|
ResNet50 | 1.0 | 400 | 79.52 (teacher) | |||
ResNet18 | 1.0 | 200 | 75.68 (base) | |||
ResNet18 | ✅ | 1.0 | 200 | 78.89 ⬆️ | ||
ResNet18 | ✅ | 1.0 | 400 | 78.56 ⬇️ | ||
ResNet18 | ✅ | 1.0 | 200 | 75.82 ⬆️ | ||
ResNet18 | 1.0 | 400 | 76.53 ⬆️ | |||
ResNet18 | ✅ | 1.0 | 400 | 76.85 ⬆️ | ||
MobileNetv2 | 1.0 | 200 | 76.90 (base) | |||
MobileNetv2 | ✅ | 1.0 | 200 | 78.41 ⬆️ | ||
MobileNetv2 | ✅ | 1.0 | 400 | 78.37 ⬇️ | ||
MobileNetv2 | ✅ | 1.0 | 200 | 76.81 ⬇️ | ||
MobileNetv2 | 1.0 | 400 | 77.30 ⬆️ | |||
MobileNetv2 | ✅ | 1.0 | 400 | 77.85 ⬆️ | ||
MobileNetv2 | - | - | config | - | 100 | 82.01 ⬆️ |
Model | pretrained | optuna | h_flip | mixup | Epoch | Accuracy(%) |
---|---|---|---|---|---|---|
ResNet18 | 0.5 | 1.0 | 200 | 99.65 (base) | ||
ResNet18 | 0.0 | 1.0 | 200 | 99.65 | ||
ResNet18 | ✅ | 0.0 | 1.0 | 200 | 99.65 | |
ResNet18 | ✅ | 0.5 | 1.0 | 200 | 99.68 ⬆️ | |
ResNet18 | 0.0 | 1.0 | 400 | 99.67 ⬆️ | ||
ResNet18 | 0.5 | 1.0 | 400 | 99.69 ⬆️ | ||
MobileNetv2 | 0.5 | 1.0 | 200 | 99.67 (base) | ||
MobileNetv2 | 0.0 | 1.0 | 200 | 99.64 ⬇️ | ||
MobileNetv2 | ✅ | 0.0 | 1.0 | 200 | 99.68 ⬆️ | |
MobileNetv2 | ✅ | 0.5 | 1.0 | 200 | 99.62 ⬇️ | |
MobileNetv2 | 0.0 | 1.0 | 400 | 99.64 ⬇️ | ||
MobileNetv2 | 0.5 | 1.0 | 400 | 99.65 ⬇️ | ||
MobileNetv2 | - | config | - | - | 100 | 99.73 ⬆️ |
Model | pretrained | optuna | h_flip | mixup | Epoch | Accuracy(%) |
---|---|---|---|---|---|---|
ResNet18 | 0.5 | 1.0 | 200 | 94.33 (base) | ||
ResNet18 | 0.0 | 1.0 | 200 | 94.30 ⬇️ | ||
ResNet18 | ✅ | 0.0 | 1.0 | 200 | 94.59 ⬆️ | |
ResNet18 | ✅ | 0.5 | 1.0 | 200 | 94.55 ⬆️ | |
ResNet18 | 0.0 | 1.0 | 400 | 94.20 ⬇️ | ||
ResNet18 | 0.5 | 1.0 | 400 | 94.41 ⬆️ | ||
MobileNetv2 | 0.5 | 1.0 | 200 | 94.81 (base) | ||
MobileNetv2 | 0.0 | 1.0 | 200 | 94.96 ⬆️ | ||
MobileNetv2 | ✅ | 0.0 | 1.0 | 200 | 95.28 ⬆️ | |
MobileNetv2 | ✅ | 0.5 | 1.0 | 200 | 95.20 ⬆️ | |
MobileNetv2 | 0.0 | 1.0 | 400 | 95.05 ⬆️ | ||
MobileNetv2 | 0.5 | 1.0 | 400 | 95.21 ⬆️ | ||
MobileNetv2 | - | config | - | - | 100 | 95.53 ⬆️ |