Edoardo Debenedetti (ETH Zürich*), Vikash Sehwag (Princeton University) and Prateek Mittal (Princeton University)
*work done as a Master's student at EPFL
Repository for the paper A Light Recipe to Train Robust Vision Transformers. It contains both the code and the checkpoints of the 22 models from the paper, covering 5 different datasets. The codebase can also train any model that is part of, or can be integrated into, `timm`, on both GPUs and TPUs, and it supports any 3-channel image classification dataset available in either `tfds` or `torchvision`.
The paper was accepted at the first IEEE Conference on Secure and Trustworthy Machine Learning (SaTML)!
In this paper, we ask whether Vision Transformers (ViTs) can serve as an underlying architecture for improving the adversarial robustness of machine learning models against evasion attacks. While earlier works have focused on improving Convolutional Neural Networks, we show that ViTs are also highly suitable for adversarial training to achieve competitive performance. We achieve this objective using a custom adversarial training recipe, discovered using rigorous ablation studies on a subset of the ImageNet dataset. The canonical training recipe for ViTs recommends strong data augmentation, in part to compensate for the lack of vision inductive bias of attention modules, when compared to convolutions. We show that this recipe achieves suboptimal performance when used for adversarial training. In contrast, we find that omitting all heavy data augmentation, and adding some additional bag-of-tricks (ε-warmup and larger weight decay), significantly boosts the performance of robust ViTs.
A large part of the codebase is based on `timm`.
This repo works with:
- Python `3.8.10`, and will probably work with newer versions.
- `torch==1.8.1` and `1.10.1`, and it will probably work with PyTorch `1.9.x` and newer versions.
- `torchvision==0.9.1` and `0.11.2`, and it will probably work with torchvision `0.10.x`.
- The other requirements are in `requirements.txt`, hence they can be installed with `pip install -r requirements.txt`.
In case you want to use Weights and Biases, after installing the requirements, install `wandb` with `pip install wandb`, and run `wandb login`.
In case you want to read or write your results to a Google Cloud Storage bucket (which is supported by this repo), install the `gcloud` CLI and log in. Then you are ready to use GCS for storing data and experiment results, as well as for downloading checkpoints, by using paths in the form `gs://bucket-name/path-to-dir-or-file`.
All these commands are meant to be run on TPU VMs with 8 TPU cores. They can be easily adapted to work on GPUs by using `torch.distributed.launch` (and by removing the `launch_xla.py --num-devices 8` part). The LR can be scaled as explained in the appendix of our paper (which follows DeiT's convention). More info about how to run the training script on TPUs and GPUs can be found in `timm.bits`'s README.
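For instance, the first training command below could be adapted to a single node with 8 GPUs roughly as follows (a sketch assuming a standard `torch.distributed.launch` setup; adapt it to your environment):

python -m torch.distributed.launch --nproc_per_node=8 train.py $DATA_DIR --dataset $DATASET --experiment $EXPERIMENT --output $OUTPUT --model $MODEL --config configs/xcit-adv-training.yaml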
It's possible to train any model which is already part of `timm`. In case a model is not part of `timm`, it can be registered to `timm`'s model store using the `@register_model` decorator on a function which instantiates the model. You can take a look at how this can be done here.
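As an illustration, a minimal registration could look like the sketch below; the model and its name are placeholders, not part of this repository:

```python
# Minimal sketch of registering a custom model with timm.
# `my_tiny_net` is a placeholder name, not a model from this repository.
import torch.nn as nn
from timm.models import register_model


@register_model
def my_tiny_net(pretrained=False, num_classes=1000, **kwargs):
    # timm exposes the model under the function name, so it can be created
    # with timm.create_model("my_tiny_net") or passed via --model my_tiny_net.
    return nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(16, num_classes),
    )
```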
Moreover, it's possible to use arbitrary loss functions, as long as they have the same structure as `AdvTrainingLoss` and `TRADESLoss`. Such a loss function must be an `nn.Module`, whose `forward` method takes as input the model, `x`, `y`, and the current epoch (used, e.g., for the warm-up of ε).
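As a rough illustration of that interface, a single-step adversarial training loss might look like the sketch below; the attack and the return value are illustrative, so mirror the actual structure of `AdvTrainingLoss` when integrating:

```python
# Sketch of a loss module with the same call signature as AdvTrainingLoss/TRADESLoss:
# forward(model, x, y, epoch). The FGSM-style attack assumes inputs in [0, 1];
# the return value is illustrative, mirror AdvTrainingLoss when integrating.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SingleStepAdvLoss(nn.Module):
    def __init__(self, eps=8 / 255, warmup_epochs=10):
        super().__init__()
        self.eps = eps
        self.warmup_epochs = warmup_epochs

    def forward(self, model, x, y, epoch):
        # Linearly warm up the perturbation budget over the first epochs.
        eps = self.eps * min(1.0, (epoch + 1) / self.warmup_epochs)
        delta = torch.zeros_like(x, requires_grad=True)
        loss = F.cross_entropy(model(x + delta), y)
        grad = torch.autograd.grad(loss, delta)[0]
        x_adv = (x + eps * grad.sign()).clamp(0.0, 1.0).detach()
        return F.cross_entropy(model(x_adv), y)
```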
To log the results to W&B it is enough to add the flag `--log-wandb`. The W&B experiment will have the name passed to the `--experiment` flag.
python launch_xla.py --num-devices 8 train.py $DATA_DIR --dataset $DATASET --experiment $EXPERIMENT --output $OUTPUT --model $MODEL --config configs/xcit-adv-training.yaml
python launch_xla.py --num-devices 8 train.py $DATA_DIR --dataset $DATASET --experiment $EXPERIMENT --output $OUTPUT --model $MODEL --config configs/xcit-adv-training.yaml --attack-steps 2
python launch_xla.py --num-devices 8 train.py $DATA_DIR --dataset $DATASET --experiment $EXPERIMENT --output $OUTPUT --model $MODEL --config configs/xcit-adv-training.yaml --attack-steps 2 --attack-eps 8
python launch_xla.py --num-devices 8 train.py $DATA_DIR --dataset $DATASET --experiment $EXPERIMENT --output $OUTPUT --model $MODEL --config configs/xcit-adv-finetuning-hi-res.yaml --finetune $CHECKPOINT
python launch_xla.py --num-devices 8 train.py $DATA_DIR --dataset $DATASET --experiment $EXPERIMENT --output $OUTPUT --model $MODEL --config configs/xcit-adv-finetuning-hi-res.yaml --finetune $CHECKPOINT --mean 0.5 0.5 0.5 --std 0.5 0.5 0.5 --normalize-model
You should first download the dataset you want to finetune on with:
python3 -c "from torchvision.datasets import CIFAR10; CIFAR10('<download_dir>', download=True)"
python3 -c "from torchvision.datasets import CIFAR100; CIFAR100('<download_dir>', download=True)"
Then the command to start the training is:
python launch_xla.py --num-devices 8 train.py $DATA_DIR --dataset $DATASET --experiment $EXPERIMENT --output $OUTPUT --model $MODEL --config configs/xcit-adv-finetuning.yaml --finetune $CHECKPOINT --mean DATASET_MEAN --std DATASET_STD --normalize-model
The models are different, as we need to adapt the patch embedding layer to work on smaller resolutions. In particular, the models are:
- XCiT-S: `xcit_small_12_p4_32`
- XCiT-M: `xcit_medium_12_p4_32`
- XCiT-L: `xcit_large_12_p4_32`
- ResNet-50: `resnet_50_32`
Moreover, for CIFAR-10 you should specify `--mean 0.4914 0.4822 0.4465 --std 0.2471 0.2435 0.2616`, and for CIFAR-100 you should specify `--mean 0.5071 0.4867 0.4408 --std 0.2675 0.2565 0.2761`.
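Putting it together, a CIFAR-10 fine-tuning run of XCiT-S12 might look as follows (the `torch/cifar10` dataset name is an assumption based on timm's dataset naming; adjust it and the paths to your setup):

python launch_xla.py --num-devices 8 train.py $DATA_DIR --dataset torch/cifar10 --experiment $EXPERIMENT --output $OUTPUT --model xcit_small_12_p4_32 --config configs/xcit-adv-finetuning.yaml --finetune $CHECKPOINT --mean 0.4914 0.4822 0.4465 --std 0.2471 0.2435 0.2616 --normalize-model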
All the checkpoints can be found in this Google Drive folder.
ImageNet (ε = 4/255):

Model | AutoAttack accuracy (%) | Clean accuracy (%) | Checkpoint | Model name |
---|---|---|---|---|
XCiT-S12 | 41.78 | 72.34 | link | xcit_small_12_p16_224 |
XCiT-M12 | 45.24 | 74.04 | link | xcit_medium_12_p16_224 |
XCiT-L12 | 47.60 | 73.76 | link | xcit_large_12_p16_224 |
ConvNeXt-T | 44.44 | 71.64 | link | convnext_tiny |
GELU ResNet-50 | 35.12 | 66.54 | link | resnet_50_gelu |
ImageNet (ε = 8/255):

Model | AutoAttack accuracy (%) | Clean accuracy (%) | Checkpoint | Model name |
---|---|---|---|---|
XCiT-S12 | 25.00 | 63.46 | link | xcit_small_12_p16_224 |
XCiT-M12 | 26.58 | 67.80 | link | xcit_medium_12_p16_224 |
XCiT-L12 | 28.74 | 69.24 | link | xcit_large_12_p16_224 |
ConvNeXt-T | 27.98 | 65.96 | link | convnext_tiny |
GELU ResNet-50 | 17.15 | 58.08 | link | resnet_50_gelu |
CIFAR-10 (ε = 8/255):

Model | AutoAttack accuracy (%) | Clean accuracy (%) | Checkpoint | Model name |
---|---|---|---|---|
XCiT-S12 | 56.14 | 90.06 | link | xcit_small_12_p4_32 |
XCiT-M12 | 57.27 | 91.30 | link | xcit_medium_12_p4_32 |
XCiT-L12 | 57.58 | 91.73 | link | xcit_large_12_p4_32 |
ResNet-50 | 41.56 | 84.80 | link | resnet_50_32 |
CIFAR-100 (ε = 8/255):

Model | AutoAttack accuracy (%) | Clean accuracy (%) | Checkpoint | Model name |
---|---|---|---|---|
XCiT-S12 | 32.19 | 67.34 | link | xcit_small_12_p4_32 |
XCiT-M12 | 34.21 | 69.21 | link | xcit_medium_12_p4_32 |
XCiT-L12 | 35.08 | 70.76 | link | xcit_large_12_p4_32 |
ResNet-50 | 22.01 | 61.28 | link | resnet_50_32 |
Model | AutoAttack accuracy (%) | Clean accuracy (%) | Checkpoint | Model name |
---|---|---|---|---|
XCiT-S12 | 47.91 | 82.86 | link | xcit_small_12_p16_224 |
ResNet-50 | 32.75 | 74.51 | link | resnet_50 |
Model | AutoAttack accuracy (%) | Clean accuracy (%) | Checkpoint | Model name |
---|---|---|---|---|
XCiT-S12 | 61.74 | 87.59 | link | xcit_small_12_p16_224 |
ResNet-50 | 34.49 | 81.38 | link | resnet_50 |
To validate models trained on ImageNet, CIFAR-10, and CIFAR-100 with the full AutoAttack (AA), it is recommended to use this command. To evaluate using APGD-CE only, or to evaluate datasets other than those above (e.g., Caltech101 and Oxford Flowers), use this script instead.
This script will run the full AutoAttack using RobustBench's interface.
python3 validate_robustbench.py --data-dir $DATA_DIR --dataset $DATASET --model $MODEL --batch-size 1024 --checkpoint $CHECKPOINT --eps $EPS
If the model has been trained using a specific mean and std, then they should be specified with the `--mean` and `--std` flags, similarly to training.
Do not use this script to run APGD-CE or AutoAttack on TPU (and XLA in general), as the compilation will take an unreasonable amount of time.
python3 validate.py $DATA_DIR --dataset $DATASET --log-freq 1 --model $MODEL --checkpoint $CHECKPOINT --mean <mean> --std <std> --attack $ATTACK --attack-eps $EPS
If the model has been trained using a specific mean and std, then they should be specified with the `--mean` and `--std` flags, and the `--normalize-model` flag should be specified, similarly to training. Otherwise, the `--no-normalize` flag should be specified. For both Caltech101 and Oxford Flowers, you should specify `--num-classes 102`, and for Caltech101 only, `--split test`. If you just want to run PGD, then you can specify the number of steps with `--attack-steps 200`.
To reproduce the attack effectiveness experiment, you can run the `attack_effectiveness.py` script. The results are written to a CSV file created in the same folder as the checkpoints being tested. We process the generated CSV files with the attack_effectiveness.ipynb notebook. For instance, to validate the XCiT-S12 checkpoints, assuming that the checkpoints are in the `checkpoints` folder and are named `checkpoint-{epoch}-pth.tar`, the script can be launched as:
DATA_DIR=...  # Location of the TFDS data or the torch data
DATASET=tfds/robustbench_image_net  # Or any other dataset, either torch or tfds
MODEL=xcit_small_12_p16_224  # Or any other model
CHECKPOINTS_DIR=checkpoints  # The directory with the checkpoints to validate
EPS=8  # The epsilon to use for the evaluation
python3 attack_effectiveness.py $DATA_DIR --dataset $DATASET --model $MODEL --checkpoints-dir $CHECKPOINTS_DIR --no-normalize
This will run PGD with 1, 2, 5, 10, and 200 steps, for the checkpoints at epochs 10, 20, 30, 40, 50, 60, 70, 80, 90, and 100, with three different seeds. The results will be written to `checkpoints/summary.csv`.
The checkpoints for the models we test in this experiment (XCiT-S12, ResNet-50 GELU, and ConvNeXt-T) can be found at this Google Drive link.
A large amount of the code is adapted from `timm`, in particular from the `bits_and_tpu` branch. The code by Ross Wightman is originally released under the Apache-2.0 License, which can be found here.
The entry point for training is train.py, while src contains a number of utility modules, as well as the model definitions (which are found in src/models).
The datasets directory contains the code to generate the TFDS datasets for:
- CIFAR 10 synthetic data (from https://arxiv.org/abs/2104.09425): to mix with CIFAR-10
- ImageNet Subset: to generate the ImageNet subset of 100 classes used for the ablation.
- RobustBench Imagenet: to generate the subset of 5000 images used in RobustBench as a TFDS dataset.
- ImageNet Perturbations: to generate a dataset of adversarial perturbations targeting several models for the RobustBench subset. We classified these perturbations with SOTA ImageNet models to quantify the perceptual nature of adversarial perturbations.
- ImageNet AdvEx: to generate a dataset of adversarial examples targeting several models for the RobustBench subset.
Additional information about how TFDS datasets are generated can be found in the TFDS documentation.
In order to run the unit tests in the repo, install pytest via `pip install pytest`, and run:
python -m pytest .
As mentioned above, this codebase is based on `timm`; we thank @rwightman for sharing it. Moreover, we thank Google's TPU Research Cloud (TRC) program, which provided us with extremely generous computing resources that enabled us to train the models we are sharing.
If you find our work useful, please consider citing it using the following BibTeX entry:
@inproceedings{
debenedetti2023light,
title={A Light Recipe to Train Robust Vision Transformers},
author={Edoardo Debenedetti and Vikash Sehwag and Prateek Mittal},
booktitle={First IEEE Conference on Secure and Trustworthy Machine Learning},
year={2023},
url={https://openreview.net/forum?id=IztT98ky0cKs}
}