This is an educational repo demonstrating how to build a real-time Snake game using a diffusion model. It was inspired by several great papers:
The goal was to create a similar implementation using the Snake game because of its simple logic. It took nearly two months of experiments to get a ready-to-play model.
If you don't have a GPU, you can use runpod.io (a paid service).
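To decide between a local run and a hosted pod, it can help to check whether a GPU is visible first. The sketch below assumes the project runs on PyTorch (the `agent.pth` checkpoint name suggests this, but it is not confirmed here); `describe_device` is a hypothetical helper and not part of the repo.

```python
def describe_device():
    """Return a short description of the compute device PyTorch would use."""
    try:
        # Assumption: PyTorch is among the project's dependencies.
        import torch
    except ImportError:
        return "PyTorch is not installed"
    if torch.cuda.is_available():
        # Report the name of the first visible CUDA device.
        return "cuda: " + torch.cuda.get_device_name(0)
    return "cpu"


if __name__ == "__main__":
    print(describe_device())
```

If this prints `cpu`, training and real-time play are likely to be impractically slow, and a hosted GPU is the more realistic option.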
After several experiments, I selected the EDM diffusion model because it produces high-quality samples with few sampling steps. DDIM requires many more steps to achieve comparable quality.
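One reason EDM copes well with few steps is its noise schedule, which packs the steps densely at low noise levels. The sketch below computes that schedule as defined in the EDM paper (Karras et al., 2022). The default values `sigma_min=0.002`, `sigma_max=80.0`, and `rho=7.0` are the paper's defaults and are an assumption here; the values used in this repo may differ. `karras_sigmas` is an illustrative name, not a function from the repo.

```python
def karras_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels for an EDM-style sampler, from sigma_max down to 0.

    Parameter defaults follow the EDM paper and are assumptions,
    not values read from this repo.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    min_inv_rho = sigma_min ** (1.0 / rho)
    max_inv_rho = sigma_max ** (1.0 / rho)
    sigmas = []
    for i in range(num_steps):
        # t runs from 0 (highest noise) to 1 (lowest noise).
        t = i / (num_steps - 1) if num_steps > 1 else 0.0
        # Interpolate linearly in sigma^(1/rho) space, then map back.
        sigmas.append((max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho)
    # The sampler finishes at sigma = 0, i.e. a clean image.
    sigmas.append(0.0)
    return sigmas


if __name__ == "__main__":
    for sigma in karras_sigmas(5):
        print(f"{sigma:.4f}")
```

With `rho=7.0`, the noise level falls steeply over the first few steps and changes slowly near zero, so a small step budget is spent mostly where fine image detail is resolved.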
Install required dependencies:
pip install -r requirements.txt
First, obtain the training dataset using one of these methods:
- Download the prepared dataset:
bash scripts/download-dataset.sh
- Or generate it manually:
python src/generate_dataset.py --model agent.pth --dataset training_data --record
Then start the training:
python src/train.py --model-type edm --output-prefix models/model --dataset training_data --gen-val-images
The model was trained on runpod.io for 32 epochs, taking approximately 27 hours at a cost of $10.
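For budgeting a shorter or longer run, the figures above can be broken down per epoch and per hour. This is plain arithmetic on the approximate numbers quoted above, so treat the results as rough estimates; `training_budget` is a hypothetical helper.

```python
def training_budget(epochs, total_hours, total_cost_usd):
    """Break a training run's totals down into per-epoch and per-hour figures."""
    return {
        "hours_per_epoch": total_hours / epochs,
        "cost_per_hour_usd": total_cost_usd / total_hours,
        "cost_per_epoch_usd": total_cost_usd / epochs,
    }


if __name__ == "__main__":
    # Figures quoted above: 32 epochs, about 27 hours, about $10.
    budget = training_budget(epochs=32, total_hours=27.0, total_cost_usd=10.0)
    for name, value in budget.items():
        print(f"{name}: {value:.2f}")
```

On these numbers, one epoch takes about 0.84 hours (roughly 51 minutes) and costs about $0.31, at around $0.37 per GPU hour.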
Download the pre-trained model:
git clone https://huggingface.co/juramoshkov/snake-diffusion models
To play the game, either:
- Run Play.ipynb locally to play Snake at about 1 FPS (the actual frame rate depends on your GPU) 🤓
- Use runpod.io:
  - Deploy a Pod (RTX 4090 recommended for best performance)
  - Copy and run the contents of scripts/runpod.sh
  - Open Play.ipynb