This repository contains the code to reproduce the results of the eLife publication "Learning cortical representations through perturbed and adversarial dreaming" available here.
To install requirements:
pip install -r requirements.txt
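As a quick sanity check of the environment (a minimal sketch, assuming the requirements include PyTorch, which the .pth checkpoints used below suggest):

import torch  # hedged assumption: PyTorch is part of requirements.txt

# The scripts below save and load .pth files, so verify torch imports correctly.
print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())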
To train the model on, for example, the CIFAR-10 dataset for 50 epochs with all phases (Wake, NREM, REM), execute:
python main_PAD.py --dataset 'cifar10' --niter 50 --batchSize 64 --outf 'model_wnr' --nz 256 --is_continue 1 --W 1.0 --N 1.0 --R 1.0
Setting one of the phase parameters (W, N, R) to zero removes that phase from training. At each epoch, the encoder and generator networks, as well as the training losses, are saved to the file trained.pth.
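A minimal sketch for inspecting this checkpoint, assuming trained.pth is written inside the --outf folder and is a regular torch-serialized object (the exact keys depend on main_PAD.py):

import torch

# Hedged sketch: list what main_PAD.py stored in the checkpoint.
# The path assumes trained.pth sits inside the --outf folder ('model_wnr').
ckpt = torch.load("model_wnr/trained.pth", map_location="cpu")
if isinstance(ckpt, dict):
    for key, value in ckpt.items():
        print(key, type(value).__name__)
else:
    print(type(ckpt).__name__)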
Once the previous command has been executed for the different conditions, display the Figure 3 samples for early and late training by executing:
python fig3_generate_samples.py --dataset 'cifar10' --nz 256 --outf 'model_wnr'
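If the ablated conditions were trained into separate output folders, the same command can be repeated per condition; a hedged sketch (folder names other than model_wnr are placeholders for whatever names were used above):

import subprocess

# Hedged sketch: generate Figure 3 samples for every trained condition.
# 'model_wn' and 'model_w' are hypothetical folder names for ablation runs.
for folder in ["model_wnr", "model_wn", "model_w"]:
    subprocess.run(
        ["python", "fig3_generate_samples.py",
         "--dataset", "cifar10", "--nz", "256", "--outf", folder],
        check=True,
    )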
To compute the FID score, use the trained networks of the full PAD model obtained above for 4 runs (e.g. model_wnr, model_wnr1, etc.) and execute:
python fig3_compute_FID.py --dataset 'cifar10' --outf 'model_wnr' --n_samples 500 --split 20
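A minimal sketch of looping this computation over the four run folders; the flags are those of the command above, only the loop is added:

import subprocess

# Sketch: compute the FID for each of the four full-model runs
# (folder names follow the nomenclature given above).
for folder in ["model_wnr", "model_wnr1", "model_wnr2", "model_wnr3"]:
    subprocess.run(
        ["python", "fig3_compute_FID.py",
         "--dataset", "cifar10", "--outf", folder,
         "--n_samples", "500", "--split", "20"],
        check=True,
    )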
Once computed, display the FID bar graph by executing:
python fig3_plot_FID.py
To compute linear classification accuracy, execute:
python linear_classif.py --dataset 'cifar10' --niterC 20 --outf 'model_wnr' --nz 256
Classification accuracy can be computed at each epoch by running the following loop in a batch script:
dset='cifar10'
folder='model_wnr'
# Train up to epoch $i (continuing from the saved networks), then evaluate
# linear classification accuracy on the learned representations.
for i in {1..50}
do
    python main_PAD.py --dataset $dset --niter $i --batchSize 64 --outf $folder --nz 256 --is_continue 1 --W 1.0 --N 1.0 --R 1.0 --epsilon 0.0
    python linear_classif.py --dataset $dset --niterC 20 --outf $folder --nz 256
done
Several runs can be saved using the following output folder names: model_wnr, model_wnr1, model_wnr2, model_wnr3, etc. Following this nomenclature, Figure 4 results can be displayed by executing:
python fig4_plot_accuracies.py --dataset 'cifar10'
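To gather the per-epoch accuracies for all runs, the batch script above can be repeated per run folder; a hedged Python sketch using the same arguments:

import subprocess

# Hedged sketch: repeat the per-epoch training + linear evaluation for each
# run folder; arguments mirror the batch script above.
for folder in ["model_wnr", "model_wnr1", "model_wnr2", "model_wnr3"]:
    for epoch in range(1, 51):
        subprocess.run(
            ["python", "main_PAD.py", "--dataset", "cifar10", "--niter", str(epoch),
             "--batchSize", "64", "--outf", folder, "--nz", "256", "--is_continue", "1",
             "--W", "1.0", "--N", "1.0", "--R", "1.0", "--epsilon", "0.0"],
            check=True,
        )
        subprocess.run(
            ["python", "linear_classif.py", "--dataset", "cifar10",
             "--niterC", "20", "--outf", folder, "--nz", "256"],
            check=True,
        )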
To compute accuracies with different levels of occlusion, execute in a batch script:
dset='cifar10'
folder='model_wnr'
# Evaluate linear classification accuracy for occlusion (drop) probabilities
# from 0% to 100% in steps of 10%.
for proba in {0..100..10}
do
    python linear_classif_occ.py --dataset $dset --niterC 20 --outf $folder --acc_file 'accuracies_levels.pth' --tile_size 4 --nz 256 --drop $proba
done
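The per-level accuracies end up in the file given by --acc_file (accuracies_levels.pth). A minimal sketch for inspecting it before plotting, assuming it is a torch-serialized object whose exact layout depends on linear_classif_occ.py (the path may also be relative to the --outf folder):

import torch

# Hedged sketch: print whatever linear_classif_occ.py stored per occlusion level.
acc = torch.load("accuracies_levels.pth", map_location="cpu")
print(type(acc).__name__)
print(acc)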
To display Figure 5, execute:
python fig5_plot_accuracies.py
PCA results are displayed by loading the trained networks and executing:
python fig6_plot_PCA.py --dataset 'cifar10' --outf 'model_wnr' --nz 256 --n_samples 1000 --n_samples_per_class 50
Intra-/inter-class and clean/occluded ratios are computed by loading the trained networks for each condition (4 runs) and executing:
python fig6_compute_distances.py --dataset 'cifar10' --n_samples 1000
These metrics are saved in the file cifar10_clustering_metrics.pth.
Once computed for each dataset, Fig. 6e and 6f are plotted from the metrics file by executing:
python fig6_plot_distances.py
with the argument occlusions set to False for Fig. 6e and to True for Fig. 6f.
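Assuming occlusions is exposed as a standard command-line option of fig6_plot_distances.py (a hedged sketch; adapt the call if the script expects the flag in another form), both panels can be produced in one go:

import subprocess

# Hedged sketch: produce Fig. 6e (occlusions False) and Fig. 6f (occlusions True).
# The exact flag syntax is an assumption about how the script parses booleans.
for occlusions in ["False", "True"]:
    subprocess.run(
        ["python", "fig6_plot_distances.py", "--occlusions", occlusions],
        check=True,
    )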