Cross-Covariance Image Transformer (XCiT)

PyTorch implementation and pretrained models for XCiT. See the paper: XCiT: Cross-Covariance Image Transformers

Linear complexity in time and memory

Our XCiT models have linear complexity w.r.t. the number of patches/tokens: $\mathcal{O}(N d^2)$

![ims_xcit](https://pythonawesome.com/content/images/2021/06/ims_xcit.png)
*Peak memory (inference) and milliseconds/image (inference)*
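
The linear complexity comes from cross-covariance attention (XCA), which computes attention over the d x d cross-covariance of the L2-normalized queries and keys instead of the N x N token-to-token matrix. Below is a minimal PyTorch sketch of the idea; layer names and hyper-parameters are illustrative and not the repo's exact implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class XCA(nn.Module):
    """Minimal sketch of cross-covariance attention (illustrative, not the repo's code)."""

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        # learnable per-head temperature for the softmax over feature channels
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 4, 1)                 # 3 x B x heads x d x N
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = F.normalize(q, dim=-1)                       # L2-normalize along the token axis
        k = F.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature   # B x heads x d x d
        attn = attn.softmax(dim=-1)

        out = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)  # back to B x N x C
        return self.proj(out)
```

Because the attention matrix is d x d per head, the cost grows linearly in the number of tokens N, which is what makes high-resolution inputs affordable.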

Scaling to high resolution inputs

XCiT can scale to high-resolution inputs thanks to both its cheaper compute requirements and its better adaptability to higher resolutions at test time (see Figure 3 in the paper).

Detection and instance segmentation for ultra-high-resolution images (6000x4000)


XCiT+DINO: High Res. Self-Attention Visualization :t-rex:

Our XCiT models with self-supervised training using DINO can obtain high resolution attention maps.

Self-Attention visualization per head

Below, we show the attention maps for each of the 8 heads separately; every head specializes in a different semantic aspect of the scene, for the foreground as well as the background.


Getting Started

First, clone the repo

git clone https://github.com/facebookresearch/XCiT.git

Then, install the required packages, including PyTorch 1.7.1, torchvision 0.8.2, and timm 0.4.8:

pip install -r requirements.txt

Download and extract the ImageNet dataset. Afterwards, set the --data-path argument to the corresponding extracted ImageNet path.
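
The training script follows a DeiT-style setup, so --data-path is typically expected to point at a standard torchvision ImageFolder layout, roughly like the sketch below (class folder and file names are placeholders):

```
/path/to/imagenet/
  train/
    class_0/
      img_0.JPEG
      ...
    class_1/
      ...
  val/
    class_0/
      img_1.JPEG
      ...
```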

For full details about all the available arguments, you can use

python main.py --help

For detection and segmentation downstream tasks, please check:

COCO Object detection and Instance segmentation: XCiT Detection

ADE20k Semantic segmentation: XCiT Semantic Segmentation


Model Zoo

We provide weights for XCiT models pre-trained on ImageNet-1k (a short loading sketch follows the tables below).

§: distillation

Models with 16x16 patch size

| Arch | params | top-1 @224 | weights | top-1 @224 § | weights | top-1 @384 § | weights |
|------|--------|------------|---------|--------------|---------|--------------|---------|
| xcit_nano_12_p16 | 3M | 69.9% | download | 72.2% | download | 75.4% | download |
| xcit_tiny_12_p16 | 7M | 77.1% | download | 78.6% | download | 80.9% | download |
| xcit_tiny_24_p16 | 12M | 79.4% | download | 80.4% | download | 82.6% | download |
| xcit_small_12_p16 | 26M | 82.0% | download | 83.3% | download | 84.7% | download |
| xcit_small_24_p16 | 48M | 82.6% | download | 83.9% | download | 85.1% | download |
| xcit_medium_24_p16 | 84M | 82.7% | download | 84.3% | download | 85.4% | download |
| xcit_large_24_p16 | 189M | 82.9% | download | 84.9% | download | 85.8% | download |

Models with 8x8 patch size

| Arch | params | top-1 @224 | weights | top-1 @224 § | weights | top-1 @384 § | weights |
|------|--------|------------|---------|--------------|---------|--------------|---------|
| xcit_nano_12_p8 | 3M | 73.8% | download | 76.3% | download | 77.8% | download |
| xcit_tiny_12_p8 | 7M | 79.7% | download | 81.2% | download | 82.4% | download |
| xcit_tiny_24_p8 | 12M | 81.9% | download | 82.6% | download | 83.7% | download |
| xcit_small_12_p8 | 26M | 83.4% | download | 84.2% | download | 85.1% | download |
| xcit_small_24_p8 | 48M | 83.9% | download | 84.9% | download | 85.6% | download |
| xcit_medium_24_p8 | 84M | 83.7% | download | 85.1% | download | 85.8% | download |
| xcit_large_24_p8 | 189M | 84.4% | download | 85.4% | download | 86.0% | download |

XCiT + DINO Self-supervised models

| Arch | params | k-nn | linear | download |
|------|--------|------|--------|----------|
| xcit_small_12_p16 | 26M | 76.0% | 77.8% | backbone |
| xcit_small_12_p8 | 26M | 77.1% | 79.2% | backbone |
| xcit_medium_24_p16 | 84M | 76.4% | 78.8% | backbone |
| xcit_medium_24_p8 | 84M | 77.9% | 80.3% | backbone |
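
As a rough sketch of how one of the classification checkpoints above could be loaded for inference. It assumes the repo's xcit.py exposes factory functions such as xcit_small_12_p16 and that the checkpoint stores its weights under a "model" key; both assumptions should be verified against the repo.

```python
import torch
from xcit import xcit_small_12_p16  # assumes the repo's xcit.py is on the Python path

model = xcit_small_12_p16(pretrained=False)
checkpoint = torch.load("xcit_small_12_p16_224.pth", map_location="cpu")
# model-zoo checkpoints commonly nest the weights under a "model" key
state_dict = checkpoint.get("model", checkpoint)
model.load_state_dict(state_dict)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # dummy 224x224 input
print(logits.shape)  # expected: torch.Size([1, 1000]) for ImageNet-1k
```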

Training

To train on a single node, use the following command:

python -m torch.distributed.launch --nproc_per_node=[NUM_GPUS] --use_env main.py --model [MODEL_KEY] --batch-size [BATCH_SIZE] --drop-path [STOCHASTIC_DEPTH_RATIO] --output_dir [OUTPUT_PATH]

For example, the XCiT-S12/16 model can be trained using the following command

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model xcit_small_12_p16 --batch-size 128 --drop-path 0.05 --output_dir /experiments/xcit_small_12_p16/ --epochs [NUM_EPOCHS]

For multi-node training via SLURM, you can alternatively use:

python run_with_submitit.py --partition [PARTITION_NAME] --nodes 2 --ngpus 8 --model xcit_small_12_p16 --batch-size 64 --drop-path 0.05 --job_dir /experiments/xcit_small_12_p16/ --epochs 400

More details for the hyper-parameters used to train the different models can be found in Table B.1 in the paper.
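
The --drop-path argument above sets the stochastic depth rate: during training, each residual branch is randomly skipped for a fraction of the samples in a batch, and the survivors are rescaled. A minimal sketch of the idea follows (timm provides this as DropPath; the function below is illustrative, not the library code):

```python
import torch

def drop_path(x: torch.Tensor, drop_prob: float = 0.05, training: bool = True) -> torch.Tensor:
    """Stochastic depth: zero out the residual branch for a random subset of samples."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # one Bernoulli draw per sample, broadcast over the remaining dimensions
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    return x * mask / keep_prob  # rescale so the expected value is unchanged
```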

Evaluation

To evaluate an XCiT model using the checkpoints above, or models you trained yourself, use the following command:

python main.py --eval --model <MODEL_KEY> --input-size <IMG_SIZE> [--full_crop] --pretrained <PATH/URL>

By default, we use the --full_crop flag, which evaluates the model with a crop ratio of 1.0 instead of 0.875, following CaiT.
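
To make the crop-ratio distinction concrete, here is a sketch of the corresponding torchvision evaluation transform (standard ImageNet normalization is assumed; the exact values used by main.py may differ):

```python
from torchvision import transforms

def build_eval_transform(img_size: int = 224, crop_ratio: float = 1.0):
    # crop_ratio=1.0 (--full_crop) resizes straight to the target size;
    # crop_ratio=0.875 resizes larger first and then centre-crops 87.5% of it
    resize_size = int(img_size / crop_ratio)
    return transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
```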

For example, the command to evaluate the XCiT-S12/16 model using 224x224 images:

python main.py --eval --model xcit_small_12_p16 --input-size 224 --full_crop --pretrained https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth

Acknowledgement

This repository is built using the Timm library and the DeiT repository. The self-supervised training is based on the DINO repository.

GitHub

https://github.com/facebookresearch/xcit

Source: https://pythonawesome.com/pytorch-implementation-and-pretrained-models-for-xcit-models/