This repo is the official PyTorch implementation of 'Feature Redundancy Elimination for Test Time Adaptation'.
This codebase is mainly based on TSD and AETTA.
We use python==3.8.13. Other required packages include:
torch==1.12.0+cu113
torchvision==0.13.0+cu113
numpy==1.24.4
pandas==2.0.3
tqdm==4.66.2
timm==0.9.16
scikit-learn==1.3.2
pillow==10.3.0
We also share our Python environment, which contains all required packages. Please refer to the ./FRET.yml file.
You can import our environment using conda:
conda env create -f FRET.yml -n FRET
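After creating and activating the environment (conda activate FRET), a short check script can confirm that the installed packages match the pinned versions listed above. The sketch below is not part of the repo; the helper names are hypothetical, and it relies only on the standard-library importlib.metadata module (available since Python 3.8). A conda-installed torch may report its version without the +cu113 suffix, so treat a mismatch as a prompt to look rather than a hard error.

```python
from importlib import metadata  # standard library since Python 3.8
from typing import Callable, Dict, List

# Pinned versions copied from the package list in this README.
PINNED: Dict[str, str] = {
    "torch": "1.12.0+cu113",
    "torchvision": "0.13.0+cu113",
    "numpy": "1.24.4",
    "pandas": "2.0.3",
    "tqdm": "4.66.2",
    "timm": "0.9.16",
    "scikit-learn": "1.3.2",
    "pillow": "10.3.0",
}


def installed_version(name: str) -> str:
    """Return the installed version of a distribution, or "" if it is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return ""


def version_mismatches(
    pinned: Dict[str, str],
    lookup: Callable[[str], str] = installed_version,
) -> List[str]:
    """Describe every package whose installed version differs from the pin."""
    problems = []
    for name, wanted in pinned.items():
        found = lookup(name)
        if not found:
            problems.append(f"{name}: not installed (expected {wanted})")
        elif found != wanted:
            problems.append(f"{name}: found {found}, expected {wanted}")
    return problems


if __name__ == "__main__":
    for line in version_mismatches(PINNED) or ["all pinned packages match"]:
        print(line)
```

Passing a custom lookup makes the comparison easy to test without touching the real environment.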
The PACS and OfficeHome datasets used in our paper can be downloaded from:
PACS
OfficeHome
Download them from the above links and organize them as follows.
|-your_data_dir
  |-PACS
    |-art_painting
    |-cartoon
    |-photo
    |-sketch
  |-OfficeHome
    |-Art
    |-Clipart
    |-Product
    |-RealWorld
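Before launching any run, you can verify that your folders match the tree above. This is a hypothetical helper, not part of the repo: EXPECTED_LAYOUT simply mirrors the tree, and your_data_dir is the same placeholder used in the commands below.

```python
import sys
from pathlib import Path
from typing import Dict, List

# Expected dataset -> domain sub-folders, mirroring the directory tree above.
EXPECTED_LAYOUT: Dict[str, List[str]] = {
    "PACS": ["art_painting", "cartoon", "photo", "sketch"],
    "OfficeHome": ["Art", "Clipart", "Product", "RealWorld"],
}


def missing_dirs(data_dir: str, layout: Dict[str, List[str]] = EXPECTED_LAYOUT) -> List[str]:
    """Return the dataset/domain folders from layout that are absent under data_dir."""
    root = Path(data_dir)
    missing = []
    for dataset, domains in layout.items():
        for domain in domains:
            if not (root / dataset / domain).is_dir():
                missing.append(f"{dataset}/{domain}")
    return missing


if __name__ == "__main__":
    # Usage: python check_layout.py your_data_dir
    target = sys.argv[1] if len(sys.argv) > 1 else "your_data_dir"
    absent = missing_dirs(target)
    print("layout OK" if not absent else "missing: " + ", ".join(absent))
```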
To download the CIFAR10/CIFAR10-C and CIFAR100/CIFAR100-C datasets, run the following commands:
$ . download_cifar10c.sh  # download CIFAR10/CIFAR10-C datasets
$ . download_cifar100c.sh  # download CIFAR100/CIFAR100-C datasets
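The standard CIFAR-10-C and CIFAR-100-C releases store each corruption as one .npy array that holds all five severity levels back to back (10,000 test images per level), plus a shared labels.npy. If you want to inspect a single severity outside the provided scripts, a slice like the one below is enough. It assumes that standard file layout; where the download scripts place the files is not stated here, so the paths in the comment are placeholders.

```python
import numpy as np

# CIFAR-10-C / CIFAR-100-C: 5 severity levels x 10,000 test images per corruption file.
IMAGES_PER_SEVERITY = 10000


def severity_slice(images: np.ndarray, labels: np.ndarray, severity: int):
    """Select the images and labels for one severity level (1-5) of a CIFAR-C corruption."""
    if not 1 <= severity <= 5:
        raise ValueError("severity must be in 1..5")
    start = (severity - 1) * IMAGES_PER_SEVERITY
    end = severity * IMAGES_PER_SEVERITY
    return images[start:end], labels[start:end]


# Example with hypothetical paths:
# images = np.load("CIFAR-10-C/gaussian_noise.npy")  # shape (50000, 32, 32, 3), uint8
# labels = np.load("CIFAR-10-C/labels.npy")          # shape (50000,)
# x5, y5 = severity_slice(images, labels, severity=5)
```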
You can also download the VLCS, DomainNet, and ImageNet-C datasets from the links below.
- VLCS: https://drive.google.com/uc?id=1skwblH1_okBwxWxmRsp9_qi15hyPpxg8
- DomainNet: http://ai.bu.edu/M3SDA/
- ImageNet-C: https://zenodo.org/record/2235448
Please use train.py to train the source model. For example:
cd code/
python train.py --dataset PACS \
--data_dir your_data_dir \
--opt_type Adam \
--lr 5e-5 \
--max_epoch 50 \
--net resnet18 \
--test_envs 0
Change --dataset PACS for other datasets, such as office-home, VLCS, DomainNet, CIFAR-10, and CIFAR-100.
Set --net to use different backbones, such as resnet50 and ViT-B16.
Change the value of --test_envs (0 in the example above) to select a different target domain.
For CIFAR-10 and CIFAR-100, there is no need to set --data_dir and --test_envs.
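To train one source model per target domain, the train.py example above can be repeated for each --test_envs value. The sketch below only builds those commands with the same flags as the example and runs them on request; the helper names are hypothetical, and it assumes --test_envs takes a 0-based domain index, so PACS would use indices 0 to 3 (one per domain folder in the layout above). Run it from code/.

```python
import subprocess
from typing import List


def train_command(dataset: str, data_dir: str, test_env: int, net: str = "resnet18") -> List[str]:
    """Build the train.py command from the README example for one target domain."""
    return [
        "python", "train.py",
        "--dataset", dataset,
        "--data_dir", data_dir,
        "--opt_type", "Adam",
        "--lr", "5e-5",
        "--max_epoch", "50",
        "--net", net,
        "--test_envs", str(test_env),
    ]


def leave_one_domain_out(dataset: str, data_dir: str, num_domains: int) -> List[List[str]]:
    """One command per domain index, each holding that domain out as the target."""
    return [train_command(dataset, data_dir, env) for env in range(num_domains)]


def run_all(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing run


if __name__ == "__main__":
    # Set dry_run=False to actually train.
    run_all(leave_one_domain_out("PACS", "your_data_dir", num_domains=4))
```

Keeping dry_run on by default lets you review the four commands before spending GPU time.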
For domain datasets such as PACS and OfficeHome, run the following command to perform test-time adaptation:
python unsupervise_adapt.py --dataset PACS \
--data_dir your_data_dir \
--adapt_alg FRET \
--pretrain_dir your_pretrain_model_dir \
--lr 1e-4 \
--net resnet18 \
--test_envs 0
For corrupted datasets such as CIFAR10-C and CIFAR100-C, run the following command:
python unsupervise_adapt_corrupted.py --dataset CIFAR-10 \
--data_dir your_data_dir \
--adapt_alg FRET \
--pretrain_dir your_pretrain_model_dir \
--lr 1e-4 \
--net resnet18
Change --adapt_alg FRET to use other test-time adaptation methods, e.g., TSD, BN, or Tent.
--pretrain_dir denotes the path to the source model, e.g., ./train_outputs/model.pkl.
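The two adaptation examples above differ only in the script name and in whether --test_envs is passed. A small hypothetical helper can encode that rule, which makes comparing methods a matter of swapping --adapt_alg. It assumes, following the examples, that CIFAR-10 and CIFAR-100 are the datasets routed to unsupervise_adapt_corrupted.py; ./train_outputs/model.pkl is used only as the illustrative path from the text.

```python
from typing import List, Optional

# Datasets handled by unsupervise_adapt_corrupted.py in the README examples.
CORRUPTED_DATASETS = {"CIFAR-10", "CIFAR-100"}


def adapt_command(
    dataset: str,
    data_dir: str,
    pretrain_dir: str,
    adapt_alg: str = "FRET",
    net: str = "resnet18",
    lr: str = "1e-4",
    test_env: Optional[int] = None,
) -> List[str]:
    """Build a test-time adaptation command following the two README examples."""
    corrupted = dataset in CORRUPTED_DATASETS
    script = "unsupervise_adapt_corrupted.py" if corrupted else "unsupervise_adapt.py"
    cmd = [
        "python", script,
        "--dataset", dataset,
        "--data_dir", data_dir,
        "--adapt_alg", adapt_alg,
        "--pretrain_dir", pretrain_dir,
        "--lr", lr,
        "--net", net,
    ]
    if not corrupted:
        # Domain datasets need a target domain; the corrupted example omits it.
        if test_env is None:
            raise ValueError(f"{dataset} needs test_env (the target domain index)")
        cmd += ["--test_envs", str(test_env)]
    return cmd


if __name__ == "__main__":
    # Compare methods on PACS with target domain 0.
    for alg in ["FRET", "TSD", "BN", "Tent"]:
        cmd = adapt_command("PACS", "your_data_dir", "./train_outputs/model.pkl",
                            adapt_alg=alg, test_env=0)
        print(" ".join(cmd))
```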
For FRET, default hyperparameters are set in our code. For better results, you might consider tuning --lam_FRET1, --lam_FRET2, --lam_FRET3, --filter_K, and --FRET_K. For guidance on selecting these hyperparameters, please refer to our paper.
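When tuning the FRET hyperparameters, a small grid of candidate values can be expanded into command lines with itertools.product. The values in GRID below are hypothetical placeholders, not recommendations (the paper is the reference for sensible ranges), and --lam_FRET3 and --FRET_K can be added to GRID the same way. The base command is the CIFAR-10 adaptation example above.

```python
import itertools
from typing import Dict, List, Sequence

# Base command: the CIFAR-10 adaptation example from this README.
BASE = [
    "python", "unsupervise_adapt_corrupted.py",
    "--dataset", "CIFAR-10",
    "--data_dir", "your_data_dir",
    "--adapt_alg", "FRET",
    "--pretrain_dir", "your_pretrain_model_dir",
    "--lr", "1e-4",
    "--net", "resnet18",
]

# Hypothetical candidate values, for illustration only.
GRID: Dict[str, Sequence[str]] = {
    "--lam_FRET1": ["0.5", "1.0"],
    "--lam_FRET2": ["0.5", "1.0"],
    "--filter_K": ["10", "20"],
}


def grid_commands(base: List[str], grid: Dict[str, Sequence[str]]) -> List[List[str]]:
    """Return one command per combination of the flag values in grid."""
    flags = list(grid)
    commands = []
    for values in itertools.product(*(grid[flag] for flag in flags)):
        # Interleave flags and values: ["--a", "1", "--b", "2", ...]
        extra = [token for pair in zip(flags, values) for token in pair]
        commands.append(base + extra)
    return commands


if __name__ == "__main__":
    for cmd in grid_commands(BASE, GRID):
        print(" ".join(cmd))
```

With two values per flag this yields 2 x 2 x 2 = 8 runs, so the grid grows quickly as flags are added.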
We tested our code in the environment described below.
OS: Ubuntu 18.04.6 LTS
GPU: NVIDIA GeForce RTX 4090
GPU Driver Version: 535.129.03
CUDA Version: 12.2