A simplified PyTorch implementation of audio classification. It supports multi-GPU training and validation, automatic mixed precision training, knowledge distillation, etc.
torch == 1.8.1
torchvision
torchaudio
torchmetrics == 1.2.0
loguru
tqdm
timm == 0.6.12 (optional)
This repo also supports timm, which provides hundreds of pretrained models. For example, if you want to train mobilenetv3_small from timm, you may change the config file to
config.model = 'timm'
config.timm_model = 'mobilenetv3_small_100'
or use command-line arguments
python main.py --model timm --timm_model mobilenetv3_small_100
Details of the configurations can also be found in this file.
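The relationship between the config fields and the command-line flags above can be sketched as follows. This is only a minimal illustration, assuming `argparse` is used and that flags given on the command line override the config file; the helper `build_config` and the default value `resnet18` are hypothetical, and the repo's actual parsing logic may differ.

```python
import argparse
from types import SimpleNamespace


def build_config(argv=None):
    """Return a config object whose fields are overridden by CLI flags."""
    # Hypothetical defaults; the real config file may define other fields and values.
    config = SimpleNamespace(model="resnet18", timm_model=None)

    parser = argparse.ArgumentParser()
    # Flag names mirror the config fields shown above.
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--timm_model", type=str, default=None)
    args = parser.parse_args(argv)

    # Only flags that were actually passed replace the config defaults.
    for key, value in vars(args).items():
        if value is not None:
            setattr(config, key, value)
    return config


config = build_config(["--model", "timm", "--timm_model", "mobilenetv3_small_100"])
print(config.model, config.timm_model)
```

Passing an explicit list to `build_config` mimics `python main.py --model timm --timm_model mobilenetv3_small_100`; calling it with an empty list leaves the defaults untouched.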
If you want to test datasets from torchaudio, you may refer to this site. Note that this site is outdated, since the version of torchaudio (0.9.1) is bound to torch (1.8.1). If you want to test datasets from a newer version of torchaudio, you need to update this codebase to be compatible with a newer torch. If you don't want to update the codebase, you can also download the audio files and build your own dataset following the style of the ESC50 dataset.
Currently, only the original knowledge distillation method proposed by Geoffrey Hinton is supported.
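For readers unfamiliar with this method, the loss can be sketched in plain Python for a single sample. This is an illustration of the general technique, not this repo's code: the function names, the `temperature` and `alpha` parameters, and their default values are assumptions. The soft term is the KL divergence between the temperature-softened teacher and student distributions, scaled by the squared temperature, and the hard term is the usual cross-entropy with the ground-truth label.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of logits divided by a temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_total = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_total for z in scaled]


def distillation_loss(student_logits, teacher_logits, label,
                      temperature=4.0, alpha=0.5):
    """Hinton-style distillation loss for one sample (hypothetical defaults)."""
    log_p_teacher = log_softmax(teacher_logits, temperature)
    log_p_student = log_softmax(student_logits, temperature)

    # Soft term: KL(teacher || student) on the softened distributions.
    soft = sum(math.exp(lt) * (lt - ls)
               for lt, ls in zip(log_p_teacher, log_p_student))

    # Hard term: cross-entropy with the true label at temperature 1.
    hard = -log_softmax(student_logits)[label]

    # The temperature**2 factor keeps the soft-term gradients on a comparable scale.
    return alpha * temperature ** 2 * soft + (1.0 - alpha) * hard


loss = distillation_loss([1.0, 2.0, 0.5], [0.5, 3.0, 0.2], label=1)
print(loss)
```

In a PyTorch training loop the same two terms are usually computed over a batch with `torch.nn.functional.kl_div` and `torch.nn.functional.cross_entropy`.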
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py
Model | pretrained | fold-1 acc(%) | fold-2 acc(%) | fold-3 acc(%) | fold-4 acc(%) | fold-5 acc(%) | paper acc(%) | mean acc(%) |
---|---|---|---|---|---|---|---|---|
L3Net | | 81.25 | 80.75 | 78.50 | 82.50 | 81.50 | 79.3 | 80.90 |
ResNet18 | | 73.25 | 74.75 | 74.00 | 75.25 | 73.25 | n.a. | 74.10 |
ResNet18 | ✅ | 85.50 | 85.50 | 86.50 | 88.00 | 84.75 | n.a. | 86.05 |
MobileNetv2 | | 76.25 | 78.00 | 74.00 | 77.75 | 69.75 | n.a. | 75.15 |
MobileNetv2 | ✅ | 90.00 | 87.25 | 87.75 | 88.75 | 88.50 | n.a. | 88.45 |
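The mean accuracy column is consistent with the arithmetic mean of the five per-fold accuracies. A small sketch that recomputes it from the table values (the dictionary layout and the names `fold_acc` and `mean_acc` are only for illustration):

```python
from statistics import mean

# Per-fold accuracies (%) copied from the table above.
fold_acc = {
    "L3Net": [81.25, 80.75, 78.50, 82.50, 81.50],
    "ResNet18": [73.25, 74.75, 74.00, 75.25, 73.25],
    "ResNet18 (pretrained)": [85.50, 85.50, 86.50, 88.00, 84.75],
    "MobileNetv2": [76.25, 78.00, 74.00, 77.75, 69.75],
    "MobileNetv2 (pretrained)": [90.00, 87.25, 87.75, 88.75, 88.50],
}

# Average over the five folds, rounded to two decimals as in the table.
mean_acc = {name: round(mean(accs), 2) for name, accs in fold_acc.items()}

for name, acc in mean_acc.items():
    print(f"{name}: {acc:.2f}")
```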