The Linear Attention Resurrection in Vision Transformer

This repository is the official PyTorch implementation of L2ViT from the paper:

The Linear Attention Resurrection in Vision Transformer

Introduction

L2ViT (Linear global attention and Local window attention Vision Transformer) integrates an enhanced linear attention and local window self-attention in an alternating sequence, as shown in the architecture figure:

[Figure: L2ViT architecture]

The local window self-attention introduces locality and translation invariance, which have proven beneficial for vision tasks, making L2ViT better at modeling fine-grained, short-distance representations. In contrast, the linear global attention maintains long-range dependencies and constructs a context-rich representation from the whole image, providing a large effective receptive field. The alternating design mixes this complementary information and provides powerful modeling capacity with only linear complexity.
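For intuition, here is a minimal PyTorch sketch of kernelized linear attention, the mechanism class the global branch builds on; the feature map (ELU + 1) and the shapes are illustrative assumptions, not the exact enhanced linear attention used in L2ViT.

import torch
import torch.nn.functional as F

def linear_attention(q, k, v):
    # q, k, v: (batch, heads, tokens, head_dim)
    q = F.elu(q) + 1                                   # positive feature map (assumed choice)
    k = F.elu(k) + 1
    kv = torch.einsum('bhnd,bhne->bhde', k, v)         # aggregate keys/values once: O(N * d^2)
    z = 1.0 / (torch.einsum('bhnd,bhd->bhn', q, k.sum(dim=2)) + 1e-6)   # normalization term
    return torch.einsum('bhnd,bhde,bhn->bhne', q, kv, z)

x = torch.randn(2, 4, 196, 32)        # e.g. 14x14 tokens, 4 heads, head_dim 32
out = linear_attention(x, x, x)       # same shape as input; cost grows linearly with token count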

Results and Pre-trained Models

ImageNet-1K trained models

| name | resolution | acc@1 | #params | FLOPs | model | log |
|:----:|:----------:|:-----:|:-------:|:-----:|:-----:|:---:|
| L2ViT-T | 224x224 | 83.1 | 29M | 4.7G | model | log |
| L2ViT-S | 224x224 | 84.1 | 50M | 9.0G | model | log |
| L2ViT-B | 224x224 | 84.4 | 89M | 15.9G | model | log |

ImageNet-22K trained models

| name | resolution | acc@1 | #params | FLOPs | 22k model | 22k log | 1k model | 1k log |
|:----:|:----------:|:-----:|:-------:|:-----:|:---------:|:-------:|:--------:|:------:|
| L2ViT-B | 224x224 | 86.0 | 89M | 15.9G | model | log | model | log |
| L2ViT-B | 384x384 | 87.0 | 89M | 47.5G | - | - | model | log |

Usage

Installation

pip install -r ./classification/requirements.txt

Dataset Preparation

Download the ImageNet-1K classification dataset and structure the data as follows:

/path/to/imagenet-1k/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
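Assuming the data loader follows the usual torchvision ImageFolder convention (one sub-directory per class, as in the layout above), a quick sanity check of the dataset might look like this; the transform is only a placeholder.

from torchvision import datasets, transforms

root = '/path/to/imagenet-1k'
train_set = datasets.ImageFolder(root + '/train', transform=transforms.ToTensor())
val_set = datasets.ImageFolder(root + '/val', transform=transforms.ToTensor())
print(len(train_set.classes), 'classes,', len(train_set), 'train images,', len(val_set), 'val images')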

Evaluation

L2ViT-B pre-trained on ImageNet-1K:

Single-GPU

python main.py --model L2ViT_Base --eval true \
--resume /path/to/pre-trained-checkpoint --input_size 224 \
--data_path /path/to/imagenet-1k --output_dir output_dir

Multi-GPU

python -m torch.distributed.launch --nproc_per_node=8 main.py \
--model L2ViT_Base --eval true \
--resume /path/to/pre-trained-checkpoint --input_size 224 \
--data_path /path/to/imagenet-1k --output_dir output_dir

This should give

* Acc@1 84.384 Acc@5 96.902 loss 0.802

Training

All models are trained in a multi-GPU setting with a total batch size of 4096 on ImageNet-1k and 1024 on ImageNet-22k.
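The total batch size follows from the flags in the commands below (GPUs per node x per-GPU batch size x update_freq); a quick check, assuming a single 8-GPU node:

# total batch size = GPUs x per-GPU batch size x gradient-accumulation steps (update_freq)
print(8 * 128 * 4)   # ImageNet-1k command below  -> 4096
print(8 * 128 * 1)   # ImageNet-22k command below -> 1024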

Training from scratch on ImageNet-1k.
python -m torch.distributed.launch --nproc_per_node=8 main.py \
     --model L2ViT_Base --drop_path 0.3 --layer_scale_init_value 0 --batch_size 128 \
     --lr 4e-3 --update_freq 4 --epochs 300 --save_ckpt_freq=100 --use_amp=false \
     --model_ema true --model_ema_eval true --data_path /path/to/imagenet-1k \
     --output_dir output_dir
Training from scratch on ImageNet-22k.
python -m torch.distributed.launch --nproc_per_node=8 main.py \
    --drop_path 0.2 --warmup_epochs 5 --weight_decay 0.05 --min_lr 1e-5 --warmup_lr 1e-6 \
    --layer_scale_init_value 0 --batch_size 128 --lr 1e-3 --update_freq 1 --epochs 90 \
    --save_ckpt_freq=10 --use_amp=false --evaluate_freq=10 --data_set=IMNET22k \
    --data_path=/path/to/imagenet-22k --output_dir output_dir
Fine-tune from ImageNet-22K pre-training (224x224)
python -m torch.distributed.launch --nproc_per_node=8 main.py \
    --drop_path 0.2 --warmup_epochs 5 --weight_decay 1e-8 --min_lr 4e-7 --warmup_lr 4e-8 \
    --layer_scale_init_value 0 --batch_size 64 --lr 4e-5 --update_freq 2 --save_ckpt_freq=10 \
    --epochs 30 --use_amp=false --model_ema true --model_ema_eval true \
    --data_path /path/to/imagenet-1k --finetune /path/to/pre-trained-model \
    --output_dir output_dir
Fine-tune from ImageNet-22K pre-training (384x384)
python -m torch.distributed.launch --nproc_per_node=8 main.py \
    --input_size 384 --drop_path 0.2 --warmup_epochs 5 --weight_decay 1e-8 --min_lr 4e-7 \
    --warmup_lr 4e-8 --layer_scale_init_value 0 --batch_size 64 --lr 4e-5 --update_freq 2 \
    --save_ckpt_freq=10 --epochs 30 --use_amp=false --model_ema true --model_ema_eval true \
    --data_path /path/to/imagenet-1k --finetune /path/to/pre-trained-model \
    --output_dir output_dir

Note: our ImageNet-22k dataset organization differs from that of other repositories such as ConvNeXt, and a different ImageNet-22k organization leads to a different 22k-to-1k category mapping. Our code uses the custom mapping in this repository by default, so:

  • If you train from scratch on ImageNet-22k, you need to add your own ImageNet-22k dataset class in datasets.py.

  • If you fine-tune from your own pre-trained ImageNet-22k model, you need to change the mapping file in main.py to make sure the 22k-to-1k category mapping is correct (see the sketch below).
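For illustration only, here is a hypothetical sketch of how such a 22k-to-1k mapping is typically applied to a pre-trained classifier head before 1k fine-tuning; the file names ('map22kto1k.json', 'l2vit_base_22k.pth') and state-dict keys ('head.weight', 'head.bias') are assumptions, not this repository's actual names.

import json
import torch

# Hypothetical mapping file: a list of 1000 indices selecting the 22k classes
# that correspond to the ImageNet-1k classes.
mapping = json.load(open('map22kto1k.json'))
idx = torch.tensor(mapping, dtype=torch.long)

ckpt = torch.load('l2vit_base_22k.pth', map_location='cpu')
state = ckpt.get('model', ckpt)

# Keep only the classifier rows that correspond to the 1k classes.
state['head.weight'] = state['head.weight'][idx]
state['head.bias'] = state['head.bias'][idx]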

Object detection and semantic segmentation

Follow the official installation instructions for:

  • OpenMMLab MMCV 1.4.8
  • OpenMMLab mmdetection 2.23.0+a86720d
  • OpenMMLab mmsegmentation 0.26.0+01ad6bb

You may also choose more recent versions.

After installation, you need to copy the following files into the mmdetection/mmsegmentation directory:

  1. Put object_detection/configs/* into path_to_mmdetection/configs/l2vit/
  2. Put object_detection/mmcv_custom into path_to_mmdetection
  3. Put object_detection/mmdet/models/backbones/* into path_to_mmdetection/mmdet/models/backbones/, then add L2ViT to path_to_mmdetection/mmdet/models/backbones/__init__.py as follows:
...
from .l2vit import L2ViT
__all__ = [..., 'L2ViT']
  4. Finally, load our pre-trained checkpoint for training.

The mmsegmentation setup is analogous to the mmdetection one. You can then train detection and segmentation models using L2ViT as the backbone.
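As a rough orientation, a config placed in path_to_mmdetection/configs/l2vit/ would follow the usual mmdetection pattern shown below; the backbone keyword arguments and channel widths here are placeholders, so take the actual values from the configs shipped in object_detection/configs/.

# Hypothetical mmdetection config sketch (not the shipped config).
_base_ = [
    '../_base_/models/mask_rcnn_r50_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py',
    '../_base_/default_runtime.py',
]

model = dict(
    backbone=dict(
        _delete_=True,
        type='L2ViT',   # registered via the __init__.py edit above
        init_cfg=dict(type='Pretrained', checkpoint='/path/to/pre-trained-checkpoint'),
    ),
    neck=dict(in_channels=[96, 192, 384, 768]),   # assumed stage widths; match your variant
)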

Acknowledgement

Our repository is built on the ConvNeXt repository and the timm library. We sincerely thank the authors for their nicely organized code!

License

This project is released under the MIT license. Please see the LICENSE file for more information.

Citation

If you find this repository helpful, please consider citing:

@Article{
}