This repository is modified from the prior ICCV-2019 repository, and additionally includes the defense code and other code for profiling purposes.
This repository contains a PyTorch implementation of BFA and its defense, as discussed in the papers:
- "Bit-Flip Attack: Crushing Neural Network with Progressive Bit Search", published in ICCV 2019.
- "Defending and Harnessing the Bit-Flip based Adversarial Weight Attack", published in CVPR 2020.
If you find this project useful, please cite our work:
@inproceedings{he2019bfa,
title={Bit-Flip Attack: Crushing Neural Network with Progressive Bit Search},
author={Rakin, Adnan Siraj and He, Zhezhi and Fan, Deliang},
booktitle={Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
pages={1211-1220},
year={2019}
}
@inproceedings{he2020defend,
title={Defending and Harnessing the Bit-Flip based Adversarial Weight Attack},
author={He, Zhezhi and Rakin, Adnan Siraj and Li, Jingtao and Chakrabarti, Chaitali and Fan, Deliang},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2020}
}
- Bit-Flips Attack and Defense
This repository includes a Bit-Flip Attack (BFA) algorithm, which searches for and identifies the vulnerable bits within a quantized deep neural network.
Dependencies:
- Python 3.6 (Anaconda)
- PyTorch >=1.01
- TensorboardX
For more specific dependencies, please refer to environment.yml and environment_setup.md.
Please modify "alpha", PYTHON=, TENSORBOARD= and data_path= in the example bash script (BFA_imagenet.sh) before running the code. This per-host configuration is useful for running the same code on different nodes.
HOST=$(hostname)
echo "Current host is: $HOST"
# Automatically check the host and apply its configuration
case $HOST in
"alpha") # alpha is the hostname (check your current host in terminal by 'hostname')
PYTHON="/home/elliot/anaconda3/envs/pytorch041/bin/python" # python environment path
TENSORBOARD='/home/elliot/anaconda3/envs/pytorch041/bin/tensorboard' # tensorboard environment path
data_path='/home/elliot/data/imagenet' # imagenet/cifar10 dataset path
;;
esac
Note: BFA evaluation can only be performed on a single GPU (i.e., data_parallel leads to a bug). Note: keep the bit-width of weight quantization at 8 bits.
Here I show BFA on ResNet-18, taken from the PyTorch pretrained model zoo. BFA can be performed by simply running the following command in the terminal:
$ bash BFA_imagenet.sh
# CUDA_VISIBLE_DEVICES=2 bash BFA_imagenet.sh # to specify the GPU id, e.g., 2
Example output log of BFA on ResNet-18:
**Test** Prec@1 69.498 Prec@5 88.976 Error@1 30.502
k_top is set to 10
Attack sample size is 128
**********************************
attacked module: conv1
attacked weight index: [42 2 4 5]
weight before attack: 21.0
weight after attack: -107.0
Iteration: [001/020] Attack Time 1.824 (1.824) [2020-05-06 21:14:43]
loss before attack: 0.6131
loss after attack: 0.8230
bit flips: 1
hamming_dist: 1
**Test** Prec@1 67.538 Prec@5 87.756 Error@1 32.462
iteration Time 61.966 (61.966)
**********************************
attacked module: layer2.0.downsample.0
attacked weight index: [33 50 0 0]
weight before attack: -1.0
weight after attack: 127.0
Iteration: [002/020] Attack Time 1.315 (1.569) [2020-05-06 21:15:47]
loss before attack: 0.8230
loss after attack: 1.4941
bit flips: 2
hamming_dist: 2
**Test** Prec@1 59.754 Prec@5 82.390 Error@1 40.246
iteration Time 62.318 (62.142)
**********************************
As shown, identifying one vulnerable bit throughout the entire model takes only ~2 seconds (i.e., Attack Time) when using 128 sample images for BFA.
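The weight values in the log can be checked by hand. A minimal pure-Python sketch (the helper names to_unsigned, hamming and flipped_bit are hypothetical, not from this repository) that maps each logged 8-bit weight to its two's complement code and confirms that each attack step changed exactly one bit:

```python
def to_unsigned(w: int, n_bits: int = 8) -> int:
    """Map a signed n-bit integer to its two's complement code (0 .. 2**n_bits - 1)."""
    return w % (1 << n_bits)


def hamming(a: int, b: int, n_bits: int = 8) -> int:
    """Number of differing bits between two signed n-bit weights."""
    return bin(to_unsigned(a, n_bits) ^ to_unsigned(b, n_bits)).count("1")


def flipped_bit(a: int, b: int, n_bits: int = 8) -> int:
    """Index of the highest differing bit (0 = LSB); the flipped bit when hamming(a, b) == 1."""
    return (to_unsigned(a, n_bits) ^ to_unsigned(b, n_bits)).bit_length() - 1


# (weight before attack, weight after attack) pairs copied from the ResNet-18 log above
pairs = [(21, -107), (-1, 127)]
for before, after in pairs:
    print(before, after, hamming(before, after), flipped_bit(before, after))
# prints:
# 21 -107 1 7
# -1 127 1 7
```

Both flips land on bit 7, the sign bit (MSB) of the 8-bit two's complement code, so a single flip shifts the integer weight by 128 (21 to -107, -1 to 127).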
Taking MobileNet-V2 as an example, the step-by-step tutorial is as follows:
- The first step is to find a pretrained PyTorch model online.
- Create the model definition file ./models/vanilla_models/vanilla_mobilenet_imagenet.py and copy the model into it. Then add the following lines to ./models/__init__.py:
############# Mobilenet for ImageNet #######
from .vanilla_models.vanilla_mobilenet_imagenet import mobilenet_v2
Also make sure the pretrained-model option is enabled by setting pretrained=True in ./models/vanilla_models/vanilla_mobilenet_imagenet.py:
def mobilenet_v2(pretrained=True, progress=True, **kwargs):
...
- Run bash eval_imagenet.sh, and you can see that the top-1 accuracy on the validation dataset is 71.878%:
**Test** Prec@1 71.878 Prec@5 90.286 Error@1 28.122
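The log reports Prec@1, Prec@5 and Error@1 in percent, and Error@1 equals 100 - Prec@1 (e.g., 100 - 71.878 = 28.122). A minimal pure-Python sketch of top-k accuracy on made-up scores, written for illustration and independent of the repository's evaluation code (the function name topk_accuracy is hypothetical):

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        # indices of the k largest scores in this row
        top = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += label in top
    return 100.0 * hits / len(labels)


# Toy scores for 4 samples and 3 classes (made-up numbers, not from the log)
scores = [[2.0, 1.0, 0.1], [0.2, 0.3, 3.0], [1.5, 2.5, 0.0], [0.9, 0.1, 0.5]]
labels = [0, 1, 1, 2]
prec1 = topk_accuracy(scores, labels, 1)
prec2 = topk_accuracy(scores, labels, 2)
error1 = 100.0 - prec1
print(prec1, prec2, error1)  # prints: 50.0 100.0 50.0
```

With 1000 ImageNet classes the same function with k=5 gives Prec@5; k=2 is used here only because the toy example has 3 classes.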
- To check the accuracy with 8-bit weight quantization, create a quantized copy of MobileNet-V2 in models/quan_mobilenet_imagenet.py. The following modifications are made sequentially:
- Import the quantized convolution and fully-connected layers:
from .quantization import *
- Change all nn.Conv2d and nn.Linear layers to quan_Conv2d and quan_Linear.
- Add code for proper model loading:
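# Modification for proper model loading
# pretrained_dict: state dict of the pretrained checkpoint, assumed loaded earlier.
# Keep only checkpoint entries whose keys also exist in the quantized model;
# model entries absent from the checkpoint keep their initialized values.
model_dict = model.state_dict()
pretrained_dict = {
    k: v
    for k, v in pretrained_dict.items() if k in model_dict
}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)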
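Because model_dict starts from the quantized model's own state dict, the merged dictionary always has exactly the model's keys, so load_state_dict (strict by default in PyTorch) reports neither missing nor unexpected keys. A plain-dict sketch of the same merge, where every key name and value is a hypothetical placeholder:

```python
# Hypothetical checkpoint: two keys shared with the model, one the model lacks
pretrained_dict = {
    "features.0.0.weight": "pretrained conv weight",
    "classifier.1.weight": "pretrained fc weight",
    "aux_head.weight": "key absent from the quantized model",
}
# Hypothetical quantized-model state dict, including a quantizer-only entry
model_dict = {
    "features.0.0.weight": "random init",
    "classifier.1.weight": "random init",
    "features.0.0.step_size": "quantizer init",
}

# Keep only checkpoint entries the quantized model actually has, then overwrite
filtered = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(filtered)
print(sorted(model_dict))
```

The shared weights are overwritten by the checkpoint, the quantizer-only entry keeps its initial value, and the extra checkpoint key is silently dropped.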
- Import the model in ./models/__init__.py:
from .quan_mobilenet_imagenet import mobilenet_v2_quan
- To evaluate the accuracy of the quantized version, run bash eval_imagenet_quan.sh, which gives:
**Test** Prec@1 71.138 Prec@5 90.012 Error@1 28.862
- To perform BFA on MobileNet-V2, simply change the configuration in BFA_imagenet.sh:
model=mobilenet_v2_quan
attack_sample_size=10 # reduce the attack samples to 10; otherwise the GPU runs out of memory
The BFA result is:
**Test** Prec@1 71.138 Prec@5 90.012 Error@1 28.862
k_top is set to 10
Attack sample size is 10
**********************************
attacked module: features.1.conv.0.0
attacked weight index: [6 0 1 2]
weight before attack: -41.0
weight after attack: 87.0
Iteration: [001/020] Attack Time 1.004 (1.004) [2020-05-07 04:26:19]
loss before attack: 1.1194
loss after attack: 13.1416
bit flips: 1
hamming_dist: 1
**Test** Prec@1 0.238 Prec@5 0.866 Error@1 99.762
iteration Time 64.102 (64.102)
**********************************
A single bit-flip on the 8-bit MobileNet-V2 degrades the top-1 accuracy from 71.138% to 0.238%.
The random attack is performed over all possible weight bits (regardless of whether a bit is the MSB or the LSB). Taking the above MobileNet-V2 as an example, you just need to add one more line, --random_bfa, in BFA_imagenet.sh to enable random bit flips:
...
--attack_sample_size ${attack_sample_size} \
--random_bfa
...
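To illustrate what "random" means here, a minimal sketch that flips one uniformly chosen bit of one uniformly chosen 8-bit weight. This is an assumption for illustration, not the repository's implementation; the function name random_bit_flip is hypothetical and a flat list of integers stands in for the weight tensors:

```python
import random


def random_bit_flip(weights, n_bits=8, rng=random):
    """Flip one uniformly chosen bit of one uniformly chosen weight, in place.

    `weights` is a flat list of signed n-bit integers; returns (index, bit, old, new).
    """
    idx = rng.randrange(len(weights))
    bit = rng.randrange(n_bits)  # any position, LSB (0) to MSB (n_bits - 1)
    old = weights[idx]
    code = (old % (1 << n_bits)) ^ (1 << bit)  # flip the bit in the two's complement code
    # map the code back to the signed range [-2**(n_bits-1), 2**(n_bits-1) - 1]
    new = code - (1 << n_bits) if code >= (1 << (n_bits - 1)) else code
    weights[idx] = new
    return idx, bit, old, new


rng = random.Random(0)  # seeded for reproducibility
w = [21, -1, -41, 0, 64]
print(random_bit_flip(w, rng=rng), w)
```

Unlike BFA, which ranks bits by their effect on the loss, this picks positions blindly, which is why random flips are used as a baseline.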
Taking ResNet-20 on CIFAR-10 as an example:
- Define a binarized ResNet-20 in models/quan_resnet_cifar.py.
- To use the weight binarization function, comment out the multi-bit quantization and uncomment the binarization modules.
- Perform the model training, where the binarized model is initialized in models/__init__.py as resnet20_quan. Then run bash train_CIFAR.sh in the terminal (don't forget the path configuration!).
- With the binarized model trained and stored at <path-to-model>/model_best.pth.tar, make sure of the following change in BFA_CIFAR.sh:
pretrained_model='<path-to-model>/model_best.pth.tar'
Piecewise weight clustering should not be applied to the binarized NN.
- Make sure models/quantization.py uses the multi-bit quantization, in contrast to the binarized counterpart. To change the bit-width, edit the code in models/quantization.py: under the definitions of quan_Conv2d and quan_Linear, set self.N_bits = 8 if you want 8-bit quantization.
- In train_CIFAR.sh, enable (i.e., uncomment) the following option:
--clustering --lambda_coeff 1e-3
Then train the model with bash train_CIFAR.sh.
- For the BFA evaluation, please refer to the binarization case.
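The --clustering option trains with piecewise weight clustering. As an assumption for illustration (the exact penalty lives in the training code, which is not shown here), a sketch of a penalty that pulls each layer's positive weights toward their mean and its negative weights toward theirs, scaled by lambda_coeff. The function names are hypothetical and plain lists stand in for tensors:

```python
import math
from typing import Iterable, List


def _l2_dev_from_mean(values: List[float]) -> float:
    """L2 norm of the deviation of `values` from their mean (0 for an empty list)."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values))


def piecewise_clustering_penalty(layers: Iterable[List[float]], lambda_coeff: float = 1e-3) -> float:
    """Sum over layers of the spread of positive and of negative weights around their own means."""
    total = 0.0
    for weights in layers:
        pos = [w for w in weights if w > 0]
        neg = [w for w in weights if w < 0]
        total += _l2_dev_from_mean(pos) + _l2_dev_from_mean(neg)
    return lambda_coeff * total


# Weights already clustered at +0.5 / -0.5 incur no penalty
print(piecewise_clustering_penalty([[0.5, 0.5, -0.5, -0.5]]))  # prints: 0.0
```

In actual training, a term like this would be computed on tensors and added to the cross-entropy loss so that autograd can differentiate it.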
We directly adopt post-training quantization on the pretrained DNN models provided by the PyTorch model zoo.
Note: to save the model in INT8, additional data conversion is needed.
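The README does not spell out the quantizer at this point. A minimal sketch of symmetric, uniform, per-layer post-training weight quantization, assumed here for illustration; the step-size rule and the function names are assumptions, not read from models/quantization.py:

```python
from typing import List, Tuple


def quantize_weights(weights: List[float], n_bits: int = 8) -> Tuple[List[int], float]:
    """Symmetric uniform quantization of a weight list to signed n-bit integers.

    Assumption: step = max|w| / (2**(n_bits - 1) - 1), so the largest-magnitude
    weight maps to +/-127 for 8 bits.
    """
    half_levels = 2 ** (n_bits - 1) - 1
    max_abs = max(abs(w) for w in weights)
    step = max_abs / half_levels if max_abs > 0 else 1.0
    # round to the nearest level and clamp into [-2**(n_bits-1), 2**(n_bits-1) - 1]
    q = [max(-half_levels - 1, min(half_levels, round(w / step))) for w in weights]
    return q, step


def dequantize(q: List[int], step: float) -> List[float]:
    """Recover approximate real-valued weights from the integers and the step size."""
    return [v * step for v in q]


q, step = quantize_weights([0.254, -0.1, 0.0, 0.05])
print(q)  # prints: [127, -50, 0, 25]
```

Only the integers and one floating-point step per layer need to be stored, which is why saving an INT8 model requires converting the weights rather than just casting them.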
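Consider that the quantized weight $w$ is an integer ranging from $-2^{N_q-1}$ to $2^{N_q-1}-1$ when using $N_q$-bit quantization. For example, the value range is -128 to 127 with 8-bit representation. In this work, we use two's complement as its binary format ($\mathbf{b}=[b_{N_q-1},\dots,b_1,b_0]\in\{0,1\}^{N_q}$), where the back-and-forth conversion can be described as:
$$w = -2^{N_q-1}\cdot b_{N_q-1} + \sum_{i=0}^{N_q-2} 2^{i}\cdot b_i, \qquad \sum_{i=0}^{N_q-1} 2^{i}\cdot b_i = w \bmod 2^{N_q}$$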
Warning: the correctness of the code also depends on the dtype used for the quantized weight when converting it back and forth between the signed integer and its two's complement (unsigned integer) form. By default, we use .short() (16-bit signed integers) to prevent overflow.
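A minimal pure-Python sketch of the two's complement conversion discussed above, plus a demonstration of the overflow the warning refers to. The function names are hypothetical, and Python's array module ('b' = signed 8-bit, 'h' = signed 16-bit) stands in for tensor dtypes; this is not the repository's implementation:

```python
from array import array


def to_twos_complement(w: int, n_bits: int = 8) -> int:
    """Signed n-bit integer -> two's complement code stored as an unsigned integer."""
    return w + (1 << n_bits) if w < 0 else w


def from_twos_complement(u: int, n_bits: int = 8) -> int:
    """Code -> signed integer: w = -2**(n-1) * b_{n-1} + sum_{i < n-1} 2**i * b_i."""
    msb = (u >> (n_bits - 1)) & 1
    return -msb * (1 << (n_bits - 1)) + (u & ((1 << (n_bits - 1)) - 1))


codes = [to_twos_complement(w) for w in (-128, -1, 0, 127)]
print(codes)                                     # prints: [128, 255, 0, 127]
print([from_twos_complement(u) for u in codes])  # prints: [-128, -1, 0, 127]

# The unsigned codes reach 255, which does not fit in a signed 8-bit type ...
try:
    array("b", codes)  # 'b' = signed char (8-bit)
except OverflowError as err:
    print("int8 overflow:", err)
# ... but fits in a signed 16-bit type ('h' = signed short), analogous to .short()
print(array("h", codes).tolist())                # prints: [128, 255, 0, 127]
```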
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
The software is for educational and academic research purposes only.