- Overview
- Installation
- Data
- Pretrained weights
- Train
- Evaluation
- Implementation details
- Acknowledgments
- Contacts
Official PyTorch implementation of "DiMSUM: Diffusion Mamba - A Scalable and Unified Spatial-Frequency Method for Image Generation" (NeurIPS'24)
Hoang Phan⁴ · Dimitris N. Metaxas² · Anh Tran¹
¹VinAI Research · ²Rutgers University · ³Cornell University · ⁴New York University
[Page] [Paper] [HuggingFace]
*Equal contribution
†Work done while at VinAI Research
We propose DiMSUM, a hybrid Mamba-Transformer diffusion model that jointly leverages spatial and frequency information for high-quality image synthesis. In extensive experiments on standard benchmarks, our method achieves state-of-the-art results, with FIDs of 4.62 on CelebA-HQ 256, 3.76 on LSUN Church 256, and 2.11 on ImageNet-1K 256. It also converges faster in training than ZigMa and other diffusion methods: on ImageNet-1K 256 it outperforms both DiT and SiT while requiring fewer than a third of the training iterations.
Details of the model architecture and experimental results can be found in our paper:
@inproceedings{phung2024dimsum,
title={DiMSUM: Diffusion Mamba - A Scalable and Unified Spatial-Frequency Method for Image Generation},
author={Phung, Hao and Dao, Quan and Dao, Trung and Phan, Hoang and Metaxas, Dimitris and Tran, Anh},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year= {2024},
}
Please CITE our paper and give us a ⭐ whenever this repository is used to help produce published results or incorporated into other software.
News
- [Feb 17th, 2025] Uploaded the ImageNet-1K model to Hugging Face for easy access.
- Python 3.10.13:
  conda create -n dimsum python=3.10.13
  conda activate dimsum
- torch 2.1.1 + cu118:
  pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
- Requirements:
  pip install -r requirements.txt
- Install causal_conv1d and mamba:
  conda install conda-forge::cudatoolkit-dev
  cd causal_conv1d && pip install -e . && cd ..
  cd mamba && pip install -e . && cd ..
- Add DiMSUM to the Python path (run from the repository root):
  export PYTHONPATH=$PYTHONPATH:$(pwd)
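After installation, one way to confirm that the pinned PyTorch packages ended up in the active environment is a small check with the standard-library `importlib.metadata`. This is a convenience sketch, not part of the repository: the `check_versions` helper is hypothetical, and it compares only the public version (e.g. `2.1.1` in `2.1.1+cu118`).

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the installation steps (torch / torchvision / torchaudio).
EXPECTED = {"torch": "2.1.1", "torchvision": "0.16.1", "torchaudio": "2.1.1"}


def check_versions(expected: dict[str, str]) -> dict[str, str]:
    """Map each package to 'ok', 'missing', or the installed (mismatched) version."""
    status = {}
    for pkg, want in expected.items():
        try:
            got = version(pkg)
        except PackageNotFoundError:
            status[pkg] = "missing"
            continue
        # Wheels from the cu118 index report versions like "2.1.1+cu118";
        # strip the local "+..." suffix before comparing.
        status[pkg] = "ok" if got.split("+")[0] == want else got
    return status


if __name__ == "__main__":
    for pkg, state in check_versions(EXPECTED).items():
        print(f"{pkg}: {state}")
```

Anything other than `ok` means the corresponding install step should be repeated inside the `dimsum` environment.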
For CelebA-HQ (256) and LSUN Church, please follow this repo for dataset preparation.
For ImageNet, we split the dataset into multiple partitions and preprocess them in parallel on different GPUs so that the preprocessing stage finishes within 2 hours. Run the following command once per partition:
python preprocess_latent_imagenet_dat.py --data-path <path-to-your-imagenet> --features-path <save-path-for-latent-imagenet> --total-batch <number-of-imagenet-partition> --batch-idx <index-of-partition>
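A minimal launcher sketch for this parallel preprocessing, assuming one partition per GPU selected through `CUDA_VISIBLE_DEVICES` and that it is run from the repository root. `build_jobs` and `run_all` are hypothetical helpers; only the script name and its four flags come from the command above, and the paths are placeholders.

```python
import os
import subprocess
import sys


def build_jobs(data_path: str, features_path: str, num_gpus: int) -> list[tuple[dict[str, str], list[str]]]:
    """Create one (env, argv) pair per partition: partition i runs on GPU i."""
    jobs = []
    for idx in range(num_gpus):
        env = {"CUDA_VISIBLE_DEVICES": str(idx)}  # pin this job to a single GPU
        cmd = [
            sys.executable, "preprocess_latent_imagenet_dat.py",
            "--data-path", data_path,
            "--features-path", features_path,
            "--total-batch", str(num_gpus),  # number of partitions
            "--batch-idx", str(idx),         # which partition this job handles
        ]
        jobs.append((env, cmd))
    return jobs


def run_all(jobs: list[tuple[dict[str, str], list[str]]]) -> None:
    """Start all jobs concurrently and raise if any partition exits non-zero."""
    procs = [subprocess.Popen(cmd, env={**os.environ, **env}) for env, cmd in jobs]
    codes = [p.wait() for p in procs]
    if any(codes):
        raise RuntimeError(f"preprocessing failed, exit codes: {codes}")


# Example (placeholder paths):
# run_all(build_jobs("/path/to/imagenet", "/path/to/latent-imagenet", num_gpus=8))
```

Collecting the exit codes in one place means a failed partition is reported instead of silently leaving a gap in the latent dataset.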
For evaluation, first resize the real images and extract them from the dataset as JPEG files.
For LMDB data (e.g. celeba_256 and lsun_church), run this command:
python eval_toolbox/resize_lmdb.py --dataset celeba_256 --datadir ./data/celeba_256/celeba-lmdb/ --image_size 256 --save_dir real_samples/
For a folder of JPEG/PNG images, run this command instead:
python eval_toolbox/resize.py main input_data_dir real_samples/dataname
An ImageNet-1K model is currently available on Hugging Face at haopt/dimsum-L2-imagenet256. To use it, set --ckpt haopt/dimsum-L2-imagenet256 in the sampling bash scripts described in Evaluation.
Alternatively, download the checkpoints manually from the links below:
Dataset | #Params | FID | Checkpoint |
---|---|---|---|
CelebA-HQ 256 | 460M | 4.62 | celeb256_225ep.pt |
LSUN Church 256 | 460M | 3.76 | church_395ep.pt |
ImageNet-1K 256 (CFG) | 460M | 2.11 | imnet256_510ep.pt |
Uncomment the command lines for the desired dataset (and comment out the others), then run:
bash scripts/train.sh
To sample images from pretrained checkpoints, run:
bash scripts/sample.sh
To evaluate, select the relevant command and run:
bash scripts/eval.sh
- DiMSUM architecture: dimsum/models_dim.py.
- Conditional Mamba: mamba/mamba_ssm/ops/selective_scan_interface.py and causal-conv1d/csrc/causal_conv1d.cpp.
- Frequency transformations: dimsum/wavelet_layer.py and dimsum/dct_layer.py.
- Mamba Scanning strategies (e.g. sweep8, jpeg8): dimsum/scanning_orders.py.
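For intuition about the frequency and scanning components listed above, here is a NumPy sketch of two generic building blocks: a single-level orthonormal 2D Haar wavelet transform (with its exact inverse) and a JPEG-style zigzag scan order over an H×W token grid. These are textbook versions written for illustration, assuming an even-sized, single-channel input; they are not the code in dimsum/wavelet_layer.py or dimsum/scanning_orders.py, and the LH/HL subband naming varies between libraries.

```python
import numpy as np


def haar_dwt2(x: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Single-level orthonormal 2D Haar DWT of an (H, W) array with even H and W."""
    a, b = x[0::2, 0::2], x[0::2, 1::2]  # top-left, top-right of each 2x2 block
    c, d = x[1::2, 0::2], x[1::2, 1::2]  # bottom-left, bottom-right
    ll = (a + b + c + d) / 2  # low-pass in both directions
    lh = (a - b + c - d) / 2  # left-minus-right differences (vertical edges)
    hl = (a + b - c - d) / 2  # top-minus-bottom differences (horizontal edges)
    hh = (a - b - c + d) / 2  # diagonal detail
    return ll, lh, hl, hh


def haar_idwt2(ll: np.ndarray, lh: np.ndarray, hl: np.ndarray, hh: np.ndarray) -> np.ndarray:
    """Exact inverse of haar_dwt2: rebuild the (H, W) array from its four subbands."""
    h, w = ll.shape
    x = np.empty((2 * h, 2 * w), dtype=np.result_type(ll, lh, hl, hh))
    x[0::2, 0::2] = (ll + lh + hl + hh) / 2
    x[0::2, 1::2] = (ll - lh + hl - hh) / 2
    x[1::2, 0::2] = (ll + lh - hl - hh) / 2
    x[1::2, 1::2] = (ll - lh - hl + hh) / 2
    return x


def zigzag_order(h: int, w: int) -> list[int]:
    """Row-major token indices of an (h, w) grid, visited in JPEG-style zigzag order."""
    order = []
    for s in range(h + w - 1):  # s = i + j enumerates the anti-diagonals
        rows = range(max(0, s - w + 1), min(s, h - 1) + 1)  # valid row indices on diagonal s
        if s % 2 == 0:
            rows = reversed(rows)  # even diagonals run bottom-left to top-right
        order.extend(i * w + (s - i) for i in rows)
    return order
```

Usage: `haar_dwt2(img)` halves each spatial dimension and separates low- from high-frequency content, while `tokens.reshape(-1)[zigzag_order(h, w)]` reorders a flattened (h, w) grid before it is fed to a sequence model. Flipping or transposing the grid before flattening yields further traversal variants.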
This project builds on Vim, LFM, SiT, DiT, and ZigMa. We thank the authors for releasing their code.
If you have any problems, please open an issue in this repository or contact the authors by email.