juraam/snake-diffusion

Snake Diffusion model

This is an educational repository demonstrating how to build a real-time Snake game using a diffusion model. It was inspired by several great papers:

The goal was to create a similar implementation for the Snake game because of its simple logic. It took nearly 2 months of experiments to get a ready-to-play model.

If you don't have a GPU, you can use runpod.io (a paid service).

Model inference

Model Architecture

After several experiments, I selected the EDM diffusion model because it produces good results with a small number of sampling steps. DDIM requires many more steps to reach comparable quality.
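The step count matters because every sampling step costs one or two network evaluations per generated frame. Below is a minimal, dependency-free sketch of the deterministic EDM sampler (Karras et al., 2022): the noise schedule with the paper's defaults (σ_min = 0.002, σ_max = 80, ρ = 7) and the second-order Heun update. It is not this repo's code. The denoiser is a hypothetical stand-in for the trained network, frames are flattened to lists of floats, and the conditioning on previous frames and actions is left out.

```python
import random
from typing import Callable, List

# A denoiser maps (noisy frame, sigma) to an estimate of the clean frame.
# Frames are flattened to plain lists of floats to keep the sketch dependency-free.
Denoiser = Callable[[List[float], float], List[float]]


def karras_sigmas(n_steps: int, sigma_min: float = 0.002,
                  sigma_max: float = 80.0, rho: float = 7.0) -> List[float]:
    """EDM noise levels from sigma_max down to sigma_min, followed by a final 0."""
    if n_steps < 2:
        raise ValueError("n_steps must be at least 2")
    hi = sigma_max ** (1.0 / rho)
    lo = sigma_min ** (1.0 / rho)
    sigmas = [(hi + i / (n_steps - 1) * (lo - hi)) ** rho for i in range(n_steps)]
    return sigmas + [0.0]


def heun_sample(denoise: Denoiser, size: int, sigmas: List[float],
                seed: int = 0) -> List[float]:
    """Deterministic EDM sampler: Heun steps, with a plain Euler step into sigma = 0."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, sigmas[0]) for _ in range(size)]  # start from pure noise
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        # Slope dx/dsigma of the probability-flow ODE at the current point.
        d = [(xi - di) / s_cur for xi, di in zip(x, denoise(x, s_cur))]
        x_euler = [xi + (s_next - s_cur) * di for xi, di in zip(x, d)]
        if s_next > 0:
            # Second-order correction: average the slopes at both ends of the step.
            d_next = [(xi - di) / s_next
                      for xi, di in zip(x_euler, denoise(x_euler, s_next))]
            x = [xi + (s_next - s_cur) * 0.5 * (a + b)
                 for xi, a, b in zip(x, d, d_next)]
        else:
            x = x_euler
    return x
```

With N noise levels this sampler calls the denoiser 2N − 1 times. For example, with a toy denoiser that always returns 0.25 (the ideal denoiser when every training frame equals 0.25), `heun_sample(toy, 4, karras_sigmas(5))` makes 9 denoiser calls and returns values within floating-point error of 0.25.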

Model scheme

Installation

Install required dependencies:

pip install -r requirements.txt

Training

First, obtain the training dataset using one of these methods:

  1. Download the prepared dataset:
bash scripts/download-dataset.sh
  2. Or generate it manually:
python src/generate_dataset.py --model agent.pth --dataset training_data --record
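The README does not document the dataset format. Diffusion world models of this kind generally learn to predict the next frame from a short history of frames and actions, so one way to turn a recorded episode into training samples is a sliding window. This is a hypothetical sketch: `Sample`, `make_samples`, and the context length of 4 are illustrative and are not taken from src/generate_dataset.py.

```python
from dataclasses import dataclass
from typing import Any, List, Sequence


@dataclass
class Sample:
    context_frames: List[Any]   # the `context` most recent frames
    context_actions: List[Any]  # the action taken on each of those frames
    target_frame: Any           # the frame the model should generate next


def make_samples(frames: Sequence[Any], actions: Sequence[Any],
                 context: int = 4) -> List[Sample]:
    """Slide a window over one recorded episode.

    actions[t] is the action taken while frames[t] was on screen,
    so it produced frames[t + 1].
    """
    if len(actions) != len(frames) - 1:
        raise ValueError("expected exactly one action per transition")
    return [
        Sample(
            context_frames=list(frames[t - context:t]),
            context_actions=list(actions[t - context:t]),
            target_frame=frames[t],
        )
        for t in range(context, len(frames))
    ]
```

An episode of F frames yields F − context samples, and the last action in each window is the one that produced the target frame.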

Then start the training:

python src/train.py --model-type edm --output-prefix models/model --dataset training_data --gen-val-images
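The README does not spell out the training objective. If the model follows the standard EDM recipe, each training step draws a noise level from a log-normal distribution and wraps the raw network F in the preconditioning below. This minimal sketch uses the EDM paper's default constants (σ_data = 0.5, P_mean = −1.2, P_std = 1.2) on scalar values; those constants are assumptions, not values read from src/train.py.

```python
import math
import random
from typing import Tuple

# Default constants from the EDM paper; assumed here, not read from src/train.py.
SIGMA_DATA = 0.5            # expected std of the (normalized) training frames
P_MEAN, P_STD = -1.2, 1.2   # ln(sigma) ~ N(P_MEAN, P_STD^2) during training


def sample_training_sigma(rng: random.Random) -> float:
    """Draw one training noise level from the log-normal distribution."""
    return math.exp(rng.gauss(P_MEAN, P_STD))


def preconditioning(sigma: float,
                    sigma_data: float = SIGMA_DATA) -> Tuple[float, float, float, float]:
    """Coefficients so that D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise)."""
    total = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / total
    c_out = sigma * sigma_data / math.sqrt(total)
    c_in = 1.0 / math.sqrt(total)
    c_noise = math.log(sigma) / 4.0
    return c_skip, c_out, c_in, c_noise


def loss_weight(sigma: float, sigma_data: float = SIGMA_DATA) -> float:
    """lambda(sigma); chosen so that lambda * c_out**2 == 1 for every sigma."""
    return (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2


def network_target(x_clean: float, x_noisy: float, sigma: float,
                   sigma_data: float = SIGMA_DATA) -> float:
    """Value the raw network F is regressed onto, so that D(x_noisy) recovers x_clean."""
    c_skip, c_out, _, _ = preconditioning(sigma, sigma_data)
    return (x_clean - c_skip * x_noisy) / c_out
```

Because λ(σ)·c_out(σ)² = 1, the weighted loss reduces to a plain mean-squared error between F's output and `network_target` at every noise level, which keeps the loss scale balanced across σ.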

The model was trained on runpod.io for 32 epochs, taking approximately 27 hours at a cost of $10.

Inference

Download the pre-trained model:

git clone https://huggingface.co/juramoshkov/snake-diffusion models

To play the game, either:

  1. Run Play.ipynb locally to play Snake at about 1 FPS (the frame rate depends on your GPU) 🤓
  2. Use runpod.io:
    • Deploy a Pod (RTX 4090 recommended for best performance)
    • Copy and run the contents of scripts/runpod.sh
    • Open Play.ipynb
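The game loop itself lives in Play.ipynb and is not reproduced here. Conceptually, a diffusion game engine generates the game one frame at a time: each key press joins a rolling window of recent actions, the model samples the next frame conditioned on the recent frames and actions, and that frame is fed back in as context. Every displayed frame costs a full run of the sampler (for example, 2N − 1 network calls for an N-step Heun sampler), which is why the frame rate depends so heavily on the GPU. A minimal sketch, where `play`, `sample_next_frame`, and the context length of 4 are hypothetical:

```python
from collections import deque
from typing import Any, Callable, List, Sequence

# Hypothetical model interface: (recent frames, recent actions) -> next frame.
SampleFn = Callable[[List[Any], List[Any]], Any]


def play(sample_next_frame: SampleFn, seed_frames: Sequence[Any],
         seed_actions: Sequence[Any], key_presses: Sequence[Any],
         context: int = 4) -> List[Any]:
    """Autoregressively generate one frame per key press.

    seed_frames holds `context` starting frames; seed_actions holds the
    context - 1 actions that led from each seed frame to the next.
    """
    if len(seed_frames) != context or len(seed_actions) != context - 1:
        raise ValueError("seed lengths do not match the context size")
    frames = deque(seed_frames, maxlen=context)    # oldest frame drops out automatically
    actions = deque(seed_actions, maxlen=context)
    shown = []
    for action in key_presses:
        actions.append(action)                      # the player's latest input
        next_frame = sample_next_frame(list(frames), list(actions))
        frames.append(next_frame)                   # generated frame becomes context
        shown.append(next_frame)
    return shown
```

Feeding generated frames back as context is what makes the game interactive, but it also means sampling errors can accumulate over long play sessions.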

About

An implementation of the Snake game based on a diffusion model.
