Introduction

A simplified PyTorch implementation of image classification, supporting multi-GPU training and validation, automatic mixed precision training, knowledge distillation, hyperparameter optimization using Optuna, and different datasets such as CIFAR10, MNIST, etc.

Requirements

torch == 1.8.1
torchvision
torchmetrics == 1.2.0
albumentations
loguru
tqdm
timm == 0.6.12 (optional)
optuna == 4.0.0 (optional)
optuna-integration == 4.0.0 (optional)
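Assuming all of the above are installed from PyPI (the exact torch/torchvision wheels depend on your Python and CUDA versions), a possible install command is:

```
pip install torch==1.8.1 torchvision==0.9.1 torchmetrics==1.2.0 albumentations loguru tqdm
pip install timm==0.6.12 optuna==4.0.0 optuna-integration==4.0.0  # optional extras
```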

Supported models

This repo provides modified ResNets[^1] and MobileNetV2[^2] for training on datasets of small-resolution images, e.g. CIFAR10[^4] (32x32) or MNIST[^5] (28x28). You can also train on datasets of normal-size images, like ImageNet, using this repo. Besides ResNets and MobileNetV2, you may also refer to timm[^3], which provides hundreds of pretrained models. For example, if you want to train mobilenetv3_small from timm, you may change the config file to

```
config.model = 'timm'
config.timm_model = 'mobilenetv3_small_100'
```

or use command-line arguments

```
python main.py --model timm --timm_model mobilenetv3_small_100
```

Details of the configurations can also be found in this file.

Since most timm models downsample the input 32 times, you may need to reduce the downsampling rate of a timm model when training on small-resolution datasets, so that more spatial detail is retained and better performance can be achieved; one possible approach is sketched below.
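For illustration, one common way to do this for 32x32 inputs is to shrink the stem's stride and drop the initial max pool; this is only a sketch of the idea, not necessarily how this repo implements its modified models:

```python
import torch
import torch.nn as nn
import timm

# Sketch: adapt a timm ResNet-18 stem for 32x32 inputs by replacing the
# stride-2 7x7 conv with a stride-1 3x3 conv and removing the stride-2
# max pool (assumed approach, not the repo's exact implementation).
model = timm.create_model('resnet18', num_classes=10)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()

x = torch.randn(2, 3, 32, 32)
print(model(x).shape)  # -> torch.Size([2, 10])
```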

Supported datasets

If you want to test other datasets from torchvision, you may refer to this site. Note that this site may be out of date for this codebase, since the pinned torchvision version (0.9.1) is bound to torch (1.8.1); to use datasets from a newer torchvision, you would need to update this codebase to be compatible with newer torch. Alternatively, if you don't want to update the codebase, you can download the image files and build your own dataset following the style of Custom dataset, as sketched below.
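For reference, a minimal ImageFolder-style dataset might look like the following sketch (the class name and layout here are assumptions; follow the repo's Custom dataset example for the exact expected format):

```python
import os
from PIL import Image
from torch.utils.data import Dataset

class FolderDataset(Dataset):
    """Minimal sketch: expects root/<class_name>/<image>, one folder per class."""
    def __init__(self, root, transform=None):
        self.transform = transform
        classes = sorted(d for d in os.listdir(root)
                         if os.path.isdir(os.path.join(root, d)))
        self.class_to_idx = {c: i for i, c in enumerate(classes)}
        self.samples = [
            (os.path.join(root, c, f), self.class_to_idx[c])
            for c in classes
            for f in sorted(os.listdir(os.path.join(root, c)))
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = Image.open(path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```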

Knowledge Distillation

Currently, only the original knowledge distillation method proposed by Geoffrey Hinton et al.[^7] is supported.
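For reference, a minimal sketch of that loss, combining a temperature-softened KL term with the usual cross entropy (the temperature `T` and weight `alpha` below are illustrative defaults, not this repo's settings):

```python
import torch.nn.functional as F

def hinton_kd_loss(student_logits, teacher_logits, targets, T=4.0, alpha=0.9):
    # Soft-target term: KL between temperature-softened distributions,
    # scaled by T^2 to keep gradient magnitudes comparable (per the paper).
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean',
    ) * (T * T)
    # Hard-target term: standard cross entropy against ground-truth labels.
    hard = F.cross_entropy(student_logits, targets)
    return alpha * soft + (1.0 - alpha) * hard
```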

MixUp

This repo provides batch-wise mixup augmentation.[^8] You can control the probability of applying mixup through the parameter `config.mixup`. If you want to perform mixup on individual images instead, you will need to implement it yourself. A sketch of the batch-wise version follows.
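In outline, batch-wise mixup draws a single Beta-distributed coefficient and blends every image with a shuffled partner from the same batch; a minimal sketch (function and argument names are illustrative, not this repo's API):

```python
import torch
import torch.nn.functional as F

def mixup_batch(images, labels, alpha=1.0):
    # One lambda for the whole batch, drawn from Beta(alpha, alpha).
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(images.size(0), device=images.device)
    mixed = lam * images + (1.0 - lam) * images[perm]
    return mixed, labels, labels[perm], lam

# The loss is the lambda-weighted sum of both targets' cross entropies:
# loss = lam * F.cross_entropy(out, y_a) + (1 - lam) * F.cross_entropy(out, y_b)
```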

Hyperparameter Optimization

This repo also supports hyperparameter optimization using Optuna.[^9] For example, if you want to search hyperparameters for the CIFAR10 dataset using MobileNetv2, you may simply run

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 optuna_search.py
```
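Under the hood, an Optuna search wraps training in an objective function that returns a validation metric; a minimal sketch (the search space and the `train_and_validate` helper below are hypothetical, the real setup lives in optuna_search.py):

```python
import optuna

def objective(trial):
    # Hypothetical search space; see optuna_search.py for the real one.
    lr = trial.suggest_float('lr', 1e-4, 1e-1, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
    mixup = trial.suggest_float('mixup', 0.0, 1.0)
    # train_and_validate is a hypothetical helper returning validation accuracy.
    return train_and_validate(lr=lr, weight_decay=weight_decay, mixup=mixup)

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)
print(study.best_params)
```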

How to use

DDP training (recommended)

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py
```

DP training

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py
```

Performances

CIFAR10[^4]

| Model | pretrained | kd | optuna | mixup | Epoch | Accuracy(%) |
|-------|------------|----|--------|-------|-------|-------------|
| ResNet50 | | | | 1.0 | 200 | 95.99 |
| ResNet50 | | | | 1.0 | 400 | 96.62 (teacher) |
| ResNet18 | | | | 1.0 | 200 | 95.34 (base) |
| ResNet18 | | | | 0.0 | 200 | 94.25 ⬇️ |
| ResNet18 | | | | 0.5 | 200 | 95.08 ⬇️ |
| ResNet18 | | | | 1.0 | 200 | 95.91 ⬆️ |
| ResNet18 | | | | 1.0 | 400 | 95.95 ⬆️ |
| ResNet18 | | | | 1.0 | 200 | 95.69 ⬆️ |
| ResNet18 | | | | 1.0 | 400 | 96.03 ⬆️ |
| ResNet18 | | | | 1.0 | 400 | 96.12 ⬆️ |
| MobileNetv2 | | | | 1.0 | 200 | 94.88 (base) |
| MobileNetv2 | | | | 1.0 | 200 | 95.21 ⬆️ |
| MobileNetv2 | | | | 1.0 | 400 | 95.37 ⬆️ |
| MobileNetv2 | | | | 1.0 | 200 | 94.83 ⬇️ |
| MobileNetv2 | | | | 1.0 | 400 | 95.29 ⬆️ |
| MobileNetv2 | | | | 1.0 | 400 | 95.12 ⬆️ |
| MobileNetv2 | - | - | config | - | 100 | 96.39 ⬆️ |

CIFAR100

| Model | pretrained | kd | optuna | mixup | Epoch | Accuracy(%) |
|-------|------------|----|--------|-------|-------|-------------|
| ResNet50 | | | | 1.0 | 400 | 79.52 (teacher) |
| ResNet18 | | | | 1.0 | 200 | 75.68 (base) |
| ResNet18 | | | | 1.0 | 200 | 78.89 ⬆️ |
| ResNet18 | | | | 1.0 | 400 | 78.56 ⬇️ |
| ResNet18 | | | | 1.0 | 200 | 75.82 ⬆️ |
| ResNet18 | | | | 1.0 | 400 | 76.53 ⬆️ |
| ResNet18 | | | | 1.0 | 400 | 76.85 ⬆️ |
| MobileNetv2 | | | | 1.0 | 200 | 76.90 (base) |
| MobileNetv2 | | | | 1.0 | 200 | 78.41 ⬆️ |
| MobileNetv2 | | | | 1.0 | 400 | 78.37 ⬇️ |
| MobileNetv2 | | | | 1.0 | 200 | 76.81 ⬇️ |
| MobileNetv2 | | | | 1.0 | 400 | 77.30 ⬆️ |
| MobileNetv2 | | | | 1.0 | 400 | 77.85 ⬆️ |
| MobileNetv2 | - | - | config | - | 100 | 82.01 ⬆️ |

MNIST

| Model | pretrained | optuna | h_flip | mixup | Epoch | Accuracy(%) |
|-------|------------|--------|--------|-------|-------|-------------|
| ResNet18 | | | 0.5 | 1.0 | 200 | 99.65 (base) |
| ResNet18 | | | 0.0 | 1.0 | 200 | 99.65 |
| ResNet18 | | | 0.0 | 1.0 | 200 | 99.65 |
| ResNet18 | | | 0.5 | 1.0 | 200 | 99.68 ⬆️ |
| ResNet18 | | | 0.0 | 1.0 | 400 | 99.67 ⬆️ |
| ResNet18 | | | 0.5 | 1.0 | 400 | 99.69 ⬆️ |
| MobileNetv2 | | | 0.5 | 1.0 | 200 | 99.67 (base) |
| MobileNetv2 | | | 0.0 | 1.0 | 200 | 99.64 ⬇️ |
| MobileNetv2 | | | 0.0 | 1.0 | 200 | 99.68 ⬆️ |
| MobileNetv2 | | | 0.5 | 1.0 | 200 | 99.62 ⬇️ |
| MobileNetv2 | | | 0.0 | 1.0 | 400 | 99.64 ⬇️ |
| MobileNetv2 | | | 0.5 | 1.0 | 400 | 99.65 ⬇️ |
| MobileNetv2 | - | config | - | - | 100 | 99.73 ⬆️ |

Fashion-MNIST[^6]

| Model | pretrained | optuna | h_flip | mixup | Epoch | Accuracy(%) |
|-------|------------|--------|--------|-------|-------|-------------|
| ResNet18 | | | 0.5 | 1.0 | 200 | 94.33 (base) |
| ResNet18 | | | 0.0 | 1.0 | 200 | 94.30 ⬇️ |
| ResNet18 | | | 0.0 | 1.0 | 200 | 94.59 ⬆️ |
| ResNet18 | | | 0.5 | 1.0 | 200 | 94.55 ⬆️ |
| ResNet18 | | | 0.0 | 1.0 | 400 | 94.20 ⬇️ |
| ResNet18 | | | 0.5 | 1.0 | 400 | 94.41 ⬆️ |
| MobileNetv2 | | | 0.5 | 1.0 | 200 | 94.81 (base) |
| MobileNetv2 | | | 0.0 | 1.0 | 200 | 94.96 ⬆️ |
| MobileNetv2 | | | 0.0 | 1.0 | 200 | 95.28 ⬆️ |
| MobileNetv2 | | | 0.5 | 1.0 | 200 | 95.20 ⬆️ |
| MobileNetv2 | | | 0.0 | 1.0 | 400 | 95.05 ⬆️ |
| MobileNetv2 | | | 0.5 | 1.0 | 400 | 95.21 ⬆️ |
| MobileNetv2 | - | config | - | - | 100 | 95.53 ⬆️ |

References

[^1]: Deep Residual Learning for Image Recognition
[^2]: MobileNetV2: Inverted Residuals and Linear Bottlenecks
[^3]: PyTorch Image Models
[^4]: The CIFAR-10 dataset
[^5]: The MNIST database of handwritten digits
[^6]: Fashion-MNIST
[^7]: Distilling the Knowledge in a Neural Network
[^8]: mixup: Beyond Empirical Risk Minimization
[^9]: Optuna: A hyperparameter optimization framework
