This repository is the official implementation of the paper "[MASK] is All You Need".
We present Discrete Interpolants, a framework that bridges Diffusion Models and Masked Generative Models in discrete state-space and scales them up in the vision domain.
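For intuition, below is a minimal sketch of a masked training objective with a linear schedule, in the spirit of the `dynamic=linear` and `dynamic.mask_ce=1` options used in the training commands further down. This is a hypothetical illustration, not the repository's code: the model signature, the [MASK]-token convention, and the exact loss weighting are assumptions.

```python
import torch
import torch.nn.functional as F

def masked_interpolant_loss(model, x0, vocab_size):
    """x0: LongTensor of clean VQ indices, shape (B, H, W). Hypothetical sketch."""
    mask_id = vocab_size  # assume the [MASK] token is appended after the codebook
    b = x0.shape[0]
    t = torch.rand(b, device=x0.device)          # t ~ U(0, 1), one per sample
    keep_prob = t.view(b, 1, 1)                  # linear schedule: keep a token w.p. t
    is_masked = torch.rand(x0.shape, device=x0.device) >= keep_prob
    xt = torch.where(is_masked, torch.full_like(x0, mask_id), x0)

    logits = model(xt, t)                        # assumed to return (B, H, W, vocab_size)
    ce = F.cross_entropy(
        logits.reshape(-1, vocab_size), x0.reshape(-1), reduction="none"
    ).reshape(x0.shape)
    return (ce * is_masked).sum() / is_masked.sum().clamp(min=1)
```

In this sketch, t = 0 corresponds to a fully masked sequence and t = 1 to the clean tokens, which is the sense in which masked generative modeling and discrete diffusion can be read as two ends of one interpolant.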
Please cite our paper:
@InProceedings{hu2024mask,
  title     = {[MASK] is All You Need},
  author    = {Vincent Tao Hu and Björn Ommer},
  booktitle = {arXiv},
  year      = {2024}
}
Feb. 4th, 2025: Training code released.
Dec. 10th, 2024: arXiv paper released.
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --num_processes 4 --num_machines 1 --main_process_ip 127.0.0.1 --main_process_port 8868 train_ds_vq.py model=uvit_s2deep_it data=coco14_cond_indices dynamic=linear dynamic.mask_ce=1 input_tensor_type=bwh tokenizer=sd_vq_f8 optim.wd=0.00 "optim.betas=[0.9, 0.9]" data.train_steps=1_000_000 ckpt_every=20_000 data.sample_fid_every=100_000 data.sample_fid_n=20_000 data.batch_size=64 optim.name=adam optim.lr=2e-4 lrschedule.warmup_steps=5000 dstep_num=500 mixed_precision=bf16 accum=4
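For reference, assuming `data.batch_size=64` is the per-process batch size (the usual accelerate convention), this COCO run trains with an effective batch size of 64 × 4 processes × 4 accumulation steps (`accum=4`) = 1024.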
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --num_processes 4 --num_machines 1 --main_process_ip 127.0.0.1 --main_process_port 8868 train_acc_vq.py model=uvit_h2_it dynamic=linear input_tensor_type=bwh tokenizer=sd_vq_f8 data=imagenet256_cond_indices data.batch_size=64 data.sample_vis_n=16 data.sample_fid_every=50_000 ckpt_every=20_000 data.train_steps=1500_000 data.sample_fid_n=5_000 optim.name=adamw optim.lr=1e-4 optim.wd=0.0 lrschedule.warmup_steps=1 mixed_precision=bf16 accum=1
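Under the same per-process assumption, this ImageNet run has an effective batch size of 64 × 4 = 256 (`accum=1`). Assuming `sd_vq_f8` is the Stable Diffusion VQGAN tokenizer with a spatial downsampling factor of 8, each 256×256 image is encoded into a 32×32 grid of codebook indices, i.e. 1024 tokens per image, presumably the (batch, width, height) layout that `input_tensor_type=bwh` refers to.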
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --num_processes 4 --num_machines 1 --main_process_ip 127.0.0.1 --main_process_port 8868 sample_ds_vq.py model=dit_xl2_it dynamic=linear input_tensor_type=bwh tokenizer=sd_vq_f8 data=imagenet256_cond_indices data.batch_size=64 data.sample_vis_n=16 data.sample_fid_every=40_000 data.sample_fid_n=5_000 optim.name=adamw optim.lr=1e-4 optim.wd=0.0 lrschedule.warmup_steps=0 data.train_steps=1_400_000 ckpt_every=20_000 mixed_precision=bf16 accum=1 ckpt="./outputs/v1.3_vqacc_note_bf16_imagenet256_cond_indices_dit_xl2_it_linear_sd_vq_f8_bs64acc1_wd0.0_gc1.0_4g_mcml-hgx-h100-008_4980788/2024-12-08_11-58-30/checkpoints/1100000.pt" num_fid_samples=50000 offline.lbs=100 dynamic.disint.scheduler=linear dynamic.disint.sampler=maskgit maskgit_randomize=linear top_k=0 top_p=0 offline.save_samples_to_disk=1 sm_t=1.3 use_cfg=1 cfg_scale=2 dstep_num=20
You should get an FID around 8.26.
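The sampling command above selects a MaskGIT-style sampler (`dynamic.disint.sampler=maskgit`) with classifier-free guidance (`use_cfg=1`, `cfg_scale=2`) over `dstep_num=20` steps. Below is a hypothetical, simplified sketch of such a sampler; the model interface, the unconditional-label convention, the cosine remasking schedule, and the reading of `sm_t` as a softmax temperature are all assumptions, not the repository's actual API.

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def maskgit_sample(model, batch, num_tokens, vocab_size, label,
                   steps=20, cfg_scale=2.0, temp=1.3):
    """Iterative confidence-based unmasking with classifier-free guidance (sketch)."""
    mask_id = vocab_size                                   # assumed [MASK] token id
    x = torch.full((batch, num_tokens), mask_id,
                   dtype=torch.long, device=label.device)  # start fully masked
    null_label = torch.full_like(label, -1)                # assumed "unconditional" label
    for i in range(steps):
        # Classifier-free guidance: mix conditional and unconditional logits.
        logits_c = model(x, label)                         # assumed (B, N, vocab_size)
        logits_u = model(x, null_label)
        logits = logits_u + cfg_scale * (logits_c - logits_u)
        probs = F.softmax(logits / temp, dim=-1)
        pred = torch.multinomial(probs.flatten(0, 1), 1).view(batch, num_tokens)
        conf = probs.gather(-1, pred.unsqueeze(-1)).squeeze(-1)
        # Never re-mask tokens that were committed in earlier steps.
        conf = torch.where(x == mask_id, conf, torch.full_like(conf, float("inf")))
        x = torch.where(x == mask_id, pred, x)             # fill all masked slots
        # Cosine schedule: this many tokens stay masked going into the next step.
        n_mask = int(num_tokens * math.cos(math.pi / 2 * (i + 1) / steps))
        if n_mask > 0:
            lowest = conf.topk(n_mask, dim=-1, largest=False).indices
            x.scatter_(1, lowest, mask_id)                 # re-mask low-confidence slots
    return x
```

Each step fills in every masked position and then re-masks the least confident predictions, so that progressively fewer tokens remain masked until the final step commits the full sequence.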
This work is licensed under the Apache License, Version 2.0 (see the LICENSE file).
By downloading and using the code and model, you agree to the terms of the LICENSE.