--The simplified implementation of World Model based on PyTorch--

zZhiG/World-Model-PyTorch

World-Model-PyTorch

1. Data Generation

Run generate_CarRacing_dataset.py to randomly generate data.

We generated 200 trajectories of 30 steps each, for a total of 6000 samples. Each trajectory is saved as a separate .npz file.

Each observed image is rescaled to a uniform size of 64 $\times$ 64.
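The storage layout described above can be sketched roughly as follows. This is not the code in generate_CarRacing_dataset.py: the helper names, the `obs` and `action` keys, the file naming scheme, and the nearest-neighbour resize are all assumptions made for illustration, and random arrays stand in for the 96 $\times$ 96 RGB frames that CarRacing returns.

```python
import os

import numpy as np

NUM_TRAJECTORIES = 200      # from the text above
STEPS_PER_TRAJECTORY = 30   # from the text above
IMAGE_SIZE = 64             # from the text above


def resize_nearest(frame, size=IMAGE_SIZE):
    """Nearest-neighbour resize of an (H, W, C) frame to (size, size, C)."""
    h, w = frame.shape[:2]
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return frame[rows[:, None], cols[None, :]]


def save_trajectory(out_dir, index, frames, actions):
    """Write one trajectory to its own .npz file (key names are hypothetical)."""
    path = os.path.join(out_dir, f"trajectory_{index:03d}.npz")
    np.savez_compressed(path, obs=np.stack(frames), action=np.stack(actions))
    return path


def generate_dummy_dataset(out_dir, num_trajectories=NUM_TRAJECTORIES,
                           steps=STEPS_PER_TRAJECTORY, seed=0):
    """Create placeholder trajectories with the same shapes as the real data."""
    rng = np.random.default_rng(seed)
    paths = []
    for i in range(num_trajectories):
        frames, actions = [], []
        for _ in range(steps):
            # Stand-in for an environment frame (CarRacing renders 96x96 RGB).
            raw = rng.integers(0, 256, size=(96, 96, 3), dtype=np.uint8)
            frames.append(resize_nearest(raw))
            # Stand-in for a 3-dimensional action vector.
            actions.append(rng.uniform(-1.0, 1.0, size=3).astype(np.float32))
        paths.append(save_trajectory(out_dir, i, frames, actions))
    return paths
```

With the default settings this writes 200 files, each holding an `obs` array of shape (30, 64, 64, 3), which matches the 6000 samples mentioned above.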

In the early steps of each trajectory, we apply an additional speed to the car so that it moves as much as possible and the collected data is richer.
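One possible way to get this early boost is to force the gas on for the first few steps and sample random actions afterwards. The sketch below is only an interpretation, since the text does not say how the additional speed is applied: `sample_action` and `BOOST_STEPS` are hypothetical names, and the value 10 is an assumption. It uses the CarRacing action layout of steering, gas, and brake.

```python
import random

BOOST_STEPS = 10  # assumption: the text does not say how long the boost lasts


def sample_action(step, rng, boost_steps=BOOST_STEPS):
    """Return a [steering, gas, brake] action for the given step index."""
    steering = rng.uniform(-1.0, 1.0)
    if step < boost_steps:
        # Early in the trajectory: full gas, no brake, random steering.
        return [steering, 1.0, 0.0]
    # Afterwards: fully random gas and brake.
    return [steering, rng.uniform(0.0, 1.0), rng.uniform(0.0, 1.0)]


def sample_trajectory_actions(steps=30, seed=0):
    """Sample the actions for one trajectory of the given length."""
    rng = random.Random(seed)
    return [sample_action(t, rng) for t in range(steps)]
```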

2. Train VAE

First, run vae_trainer.py to train the VAE network. Its latent feature has 32 channels.

The network was trained for 1000 epochs with a batch size of 128.
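For orientation, here is a minimal sketch of a convolutional VAE for 64 $\times$ 64 RGB inputs. It assumes the latent is a 32-dimensional vector, and it borrows its layer sizes from the original World Models paper rather than from vae_trainer.py. The class name `ConvVAE`, the functions `vae_loss` and `train_step`, and the mean squared error reconstruction term are assumptions, not the repository's code.

```python
import torch
from torch import nn
import torch.nn.functional as F

LATENT_DIM = 32   # from the text above
BATCH_SIZE = 128  # from the text above
EPOCHS = 1000     # from the text above


class ConvVAE(nn.Module):
    """Convolutional VAE for inputs of shape (N, 3, 64, 64) scaled to [0, 1]."""

    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        # Spatial size: 64 -> 31 -> 14 -> 6 -> 2, so 256 * 2 * 2 = 1024 features.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2), nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2), nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(1024, latent_dim)
        self.fc_logvar = nn.Linear(1024, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, 1024)
        # Spatial size: 1 -> 5 -> 13 -> 30 -> 64.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1024, 128, 5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 6, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 6, stride=2), nn.Sigmoid(),
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        # Reparameterization trick: z = mu + sigma * eps.
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        recon = self.decoder(self.fc_decode(z).view(-1, 1024, 1, 1))
        return recon, mu, logvar


def vae_loss(recon, x, mu, logvar):
    """Per-sample reconstruction error plus KL divergence to N(0, I)."""
    recon_loss = F.mse_loss(recon, x, reduction="sum") / x.size(0)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + kl


def train_step(model, optimizer, batch):
    """Run one optimization step on a batch and return the scalar loss."""
    optimizer.zero_grad()
    recon, mu, logvar = model(batch)
    loss = vae_loss(recon, batch, mu, logvar)
    loss.backward()
    optimizer.step()
    return loss.item()
```

A training loop would call `train_step` once per batch of 128 images for each of the 1000 epochs, for example with `torch.optim.Adam(model.parameters())`. The choice of optimizer is also an assumption.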

We have also released the loss curve recorded during training and the final weights.

Here are some visual examples. On the left is the original image, and on the right is the reconstructed image.

3. Train MDN-RNN
