youlj109/fret

FRET: Feature Redundancy Elimination for Test Time Adaptation

This repo is the official PyTorch implementation of 'Feature Redundancy Elimination for Test Time Adaptation'.
This codebase is mainly based on TSD and AETTA.

Dependencies

We use python==3.8.13 and the following packages:

torch==1.12.0+cu113
torchvision==0.13.0+cu113
numpy==1.24.4
pandas==2.0.3
tqdm==4.66.2
timm==0.9.16
scikit-learn==1.3.2 
pillow==10.3.0

We also share our Python environment, which contains all required packages, in the ./FRET.yml file.
You can create it with conda:

conda env create -f FRET.yml -n FRET
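To confirm that the active environment matches the pinned versions above, a small check such as the following may help. It is a minimal sketch and not part of the FRET codebase: the helper name find_mismatches is hypothetical, and it assumes the pip distribution names shown (e.g. Pillow, scikit-learn). On Python 3.8, distribution-name matching in importlib.metadata can be strict, so a package installed under a differently spelled name may be reported as missing.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, List

# Versions pinned in this README, keyed by (assumed) pip distribution name.
EXPECTED: Dict[str, str] = {
    "torch": "1.12.0+cu113",
    "torchvision": "0.13.0+cu113",
    "numpy": "1.24.4",
    "pandas": "2.0.3",
    "tqdm": "4.66.2",
    "timm": "0.9.16",
    "scikit-learn": "1.3.2",
    "Pillow": "10.3.0",
}


def find_mismatches(expected: Dict[str, str],
                    lookup: Callable[[str], str] = version) -> List[str]:
    """Return one message per package that is missing or has a different version."""
    problems: List[str] = []
    for name, want in expected.items():
        try:
            got = lookup(name)
        except PackageNotFoundError:
            problems.append(f"{name}: not installed (expected {want})")
            continue
        if got != want:
            problems.append(f"{name}: found {got}, expected {want}")
    return problems


if __name__ == "__main__":
    # Prints nothing when every pinned package matches exactly.
    for line in find_mismatches(EXPECTED):
        print(line)
```

An exact string comparison is used on purpose, so a CPU-only torch build (without the +cu113 suffix) is also flagged.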

Dataset

Download the PACS and OfficeHome datasets used in our paper from:
PACS
OfficeHome
Then organize them as follows:

|-your_data_dir
  |-PACS
    |-art_painting
    |-cartoon
    |-photo
    |-sketch
  |-OfficeHome
    |-Art
    |-Clipart
    |-Product
    |-RealWorld
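Before training, it may be worth checking that your data directory follows this layout. The sketch below uses a hypothetical helper (missing_domain_dirs is not part of the repository); it only checks that the domain folders listed above exist and does not inspect the images inside them.

```python
from pathlib import Path
from typing import Dict, List

# Folder names copied from the layout shown above.
EXPECTED_LAYOUT: Dict[str, List[str]] = {
    "PACS": ["art_painting", "cartoon", "photo", "sketch"],
    "OfficeHome": ["Art", "Clipart", "Product", "RealWorld"],
}


def missing_domain_dirs(data_dir: str,
                        layout: Dict[str, List[str]] = EXPECTED_LAYOUT) -> List[str]:
    """Return the dataset/domain folders (relative to data_dir) that do not exist."""
    root = Path(data_dir)
    missing: List[str] = []
    for dataset, domains in layout.items():
        for domain in domains:
            path = root / dataset / domain
            if not path.is_dir():
                missing.append(str(path.relative_to(root)))
    return missing


# Usage (hypothetical path): print(missing_domain_dirs("your_data_dir"))
# Pass a smaller layout dict to check only the datasets you downloaded.
```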

To download the CIFAR10/CIFAR10-C and CIFAR100/CIFAR100-C datasets, run the following commands:

$ . download_cifar10c.sh        # download CIFAR10/CIFAR10-C datasets
$ . download_cifar100c.sh       # download CIFAR100/CIFAR100-C datasets

You can also download the VLCS, DomainNet, and ImageNet-C datasets from the links below.

Train source model

Please use train.py to train the source model. For example:

cd code/
python train.py --dataset PACS \
                --data_dir your_data_dir \
                --opt_type Adam \
                --lr 5e-5 \
                --max_epoch 50 \
                --net resnet18 \
                --test_envs 0

Replace PACS in --dataset PACS with another dataset, such as office-home, VLCS, DomainNet, CIFAR-10, or CIFAR-100.
Set --net to use a different backbone, such as resnet50 or ViT-B16.
Change the value of --test_envs (0 in the example) to select a different target domain.
For CIFAR-10 and CIFAR-100, there is no need to set --data_dir or --test_envs.
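When training one source model per target domain, building the train.py arguments as a list can be less error-prone than editing the command by hand. This is a sketch under stated assumptions: train_command is a hypothetical helper, not part of the repository, and the assumption that PACS uses test_envs indices 0 to 3 (one per domain) is ours rather than something this README confirms.

```python
import shlex
from typing import List, Optional


def train_command(dataset: str, net: str, test_env: Optional[int] = None,
                  data_dir: Optional[str] = None, lr: str = "5e-5",
                  max_epoch: int = 50, opt_type: str = "Adam") -> List[str]:
    """Build the train.py argument list from the example above (hypothetical helper)."""
    cmd = ["python", "train.py", "--dataset", dataset,
           "--opt_type", opt_type, "--lr", lr,
           "--max_epoch", str(max_epoch), "--net", net]
    # Domain datasets need a data directory and a target domain;
    # for CIFAR-10/CIFAR-100 both can be omitted.
    if data_dir is not None:
        cmd += ["--data_dir", data_dir]
    if test_env is not None:
        cmd += ["--test_envs", str(test_env)]
    return cmd


if __name__ == "__main__":
    # Assumption: PACS uses test_envs 0-3, one index per domain.
    for env in range(4):
        print(shlex.join(train_command("PACS", "resnet18", test_env=env,
                                       data_dir="your_data_dir")))
```

Each list can also be passed directly to subprocess.run(cmd, cwd="code", check=True) to launch the runs one after another; passing a list avoids shell line-continuation mistakes entirely.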

Test time adaptation

For domain datasets such as PACS and OfficeHome, run the following command:

python unsupervise_adapt.py --dataset PACS \
                            --data_dir your_data_dir \
                            --adapt_alg FRET \
                            --pretrain_dir your_pretrain_model_dir \
                            --lr 1e-4 \
                            --net resnet18 \
                            --test_envs 0

For corrupted datasets such as CIFAR10-C and CIFAR100-C, run the following command:

python unsupervise_adapt_corrupted.py --dataset CIFAR-10 \
                                      --data_dir your_data_dir \
                                      --adapt_alg FRET \
                                      --pretrain_dir your_pretrain_model_dir \
                                      --lr 1e-4 \
                                      --net resnet18
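The two entry points differ mainly in which script is called and whether --test_envs is passed. A hypothetical helper that picks the script from the dataset name might look like the sketch below; adapt_command is not part of the repository, and treating exactly CIFAR-10 and CIFAR-100 as the "corrupted" datasets is an assumption based on the examples above.

```python
from typing import List, Optional

# Assumption: dataset names routed to unsupervise_adapt_corrupted.py.
CORRUPTED = {"CIFAR-10", "CIFAR-100"}


def adapt_command(dataset: str, data_dir: str, pretrain_dir: str,
                  net: str = "resnet18", adapt_alg: str = "FRET",
                  lr: str = "1e-4", test_env: Optional[int] = None) -> List[str]:
    """Build the test-time adaptation command for either entry point (hypothetical helper)."""
    corrupted = dataset in CORRUPTED
    script = "unsupervise_adapt_corrupted.py" if corrupted else "unsupervise_adapt.py"
    cmd = ["python", script, "--dataset", dataset, "--data_dir", data_dir,
           "--adapt_alg", adapt_alg, "--pretrain_dir", pretrain_dir,
           "--lr", lr, "--net", net]
    if not corrupted:
        # Domain datasets need the target domain that was held out during training.
        if test_env is None:
            raise ValueError("domain datasets need test_env (target domain index)")
        cmd += ["--test_envs", str(test_env)]
    return cmd
```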

Replace FRET in --adapt_alg FRET with another test-time adaptation method, e.g. TSD, BN, or Tent.
--pretrain_dir is the path to the source model, e.g. ./train_outputs/model.pkl.
For FRET, default hyperparameters are set in our code. For better results, you might consider tuning --lam_FRET1, --lam_FRET2, --lam_FRET3, --filter_K, and --FRET_K. For guidance on selecting these hyperparameters, please refer to our paper.
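To compare FRET against the baselines under the same settings, the runs can be generated from a small grid of flag values. This is an illustrative sketch: grid_args and the base command are hypothetical, and only the flag names come from this README. No FRET hyperparameter values are suggested here; pick them following the paper.

```python
import itertools
import shlex
from typing import Dict, Iterator, List, Sequence


def grid_args(grid: Dict[str, Sequence[str]]) -> Iterator[List[str]]:
    """Yield one flat [flag, value, ...] list per combination in the grid."""
    flags = list(grid)
    for values in itertools.product(*(grid[f] for f in flags)):
        args: List[str] = []
        for flag, value in zip(flags, values):
            args += [flag, value]
        yield args


if __name__ == "__main__":
    base = ["python", "unsupervise_adapt.py", "--dataset", "PACS",
            "--data_dir", "your_data_dir",
            "--pretrain_dir", "your_pretrain_model_dir",
            "--lr", "1e-4", "--net", "resnet18", "--test_envs", "0"]
    # FRET-specific flags such as --filter_K could be added as extra grid keys.
    grid = {"--adapt_alg": ["FRET", "TSD", "BN", "Tent"]}
    for extra in grid_args(grid):
        print(shlex.join(base + extra))
```

Using itertools.product keeps the grid explicit: adding a second key multiplies the number of runs, so it is easy to see how large a sweep will be before launching it.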

Tested Environment

We tested our code in the environment described below.

OS: Ubuntu 18.04.6 LTS
GPU: NVIDIA GeForce RTX 4090
GPU Driver Version: 535.129.03
CUDA Version: 12.2
