Author: arnaudvl

Description:
PyTorch World Model implementation with PPO.
Language: Python
Repository: git://github.com/arnaudvl/world-models-ppo.git
Created: 2019-12-22T12:17:18Z
Project page: https://github.com/arnaudvl/world-models-ppo

License:


world-models-ppo

World Model implementation with PPO in PyTorch. This repository builds on world-models for the VAE and MDN-RNN implementations and firedup for the PPO optimization of the Controller network. Check the firedup setup file for requirements.

First, save a number of rollouts from the CarRacing-v0 Gym environment to the data_dir folder. These rollouts are used as the train and test sets:

    python env/carracing.py --data_dir './env/data' --n_fold_train 20 --n_fold_test 1

Then train the Variational Autoencoder (VAE) using the stored rollouts:

    from vae.train import run
    run(data_dir='./env/data', vae_dir='./vae/model', epochs=5)
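For readers unfamiliar with what a VAE optimizes, the following is a minimal sketch of the usual VAE objective (reconstruction error plus a KL term) and the reparameterization trick. It is written in plain Python so the arithmetic is easy to check; it is not the repository's code, which is in PyTorch and may use a different reconstruction loss or weighting. The function names (`reparameterize`, `kl_to_standard_normal`, `vae_loss`) are hypothetical and chosen for illustration only.

```python
import math
import random


def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps, with eps ~ N(0, 1), per latent dimension."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv)
                      for m, lv in zip(mu, logvar))


def vae_loss(x, x_recon, mu, logvar):
    """Squared-error reconstruction term plus the KL term (assumed equal weights)."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_recon))
    return recon + kl_to_standard_normal(mu, logvar)


if __name__ == "__main__":
    rng = random.Random(0)
    mu, logvar = [0.5, -0.2], [0.0, -1.0]
    z = reparameterize(mu, logvar, rng)
    print("sampled latent:", z)
    print("KL term:", kl_to_standard_normal(mu, logvar))
```

The KL term is zero exactly when the encoder outputs a standard normal (`mu = 0`, `logvar = 0`), which is why it acts as a regularizer on the latent space.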

Using the pretrained VAE, we train the Recurrent Mixture Density Network (MDN-RNN) model to predict the future latent state:

    from mdnrnn.train import run
    run(data_dir='./env/data', vae_dir='./vae/model', mdnrnn_dir='./mdnrnn/model', epochs=5)
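To illustrate what a mixture density output is trained on, here is a minimal sketch of the negative log-likelihood of a scalar target under a mixture of Gaussians. It is a plain-Python illustration under simplifying assumptions (one latent dimension, mixture weights already normalized in log space), not the repository's PyTorch implementation; the helper names are hypothetical.

```python
import math


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def gaussian_log_pdf(x, mu, sigma):
    """Log density of N(mu, sigma^2) evaluated at x."""
    return (-0.5 * math.log(2.0 * math.pi)
            - math.log(sigma)
            - 0.5 * ((x - mu) / sigma) ** 2)


def mdn_nll(x, log_pi, mu, sigma):
    """Negative log-likelihood of x under a Gaussian mixture.

    log_pi: log mixture weights (assumed to sum to 1 after exponentiation)
    mu, sigma: per-component means and standard deviations
    """
    terms = [lp + gaussian_log_pdf(x, m, s)
             for lp, m, s in zip(log_pi, mu, sigma)]
    return -log_sum_exp(terms)


if __name__ == "__main__":
    log_pi = [math.log(0.3), math.log(0.7)]
    print("NLL:", mdn_nll(0.1, log_pi, mu=[0.0, 1.0], sigma=[0.5, 1.0]))
```

In an MDN-RNN, the recurrent network would produce `log_pi`, `mu` and `sigma` for each latent dimension at every time step, and this loss would be averaged over dimensions and steps.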

Finally, train the Controller network, which steers the car, using PPO:

    from rl.algos.ppo.ppo import run
    run(exp_name='carracing_ppo', epochs=100)
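As background on the optimization step, the sketch below computes the clipped surrogate objective that PPO maximizes, in plain Python. It is illustrative only: the function name is hypothetical, the `clip_ratio=0.2` default is a commonly used value assumed here rather than taken from this repository, and the actual firedup-based implementation may differ in details such as advantage normalization or early stopping.

```python
import math


def ppo_clip_objective(logp_new, logp_old, advantages, clip_ratio=0.2):
    """Mean clipped surrogate objective over a batch of transitions.

    logp_new / logp_old: log-probabilities of the taken actions under the
    current and the data-collecting policy. clip_ratio is an assumed default.
    """
    total = 0.0
    for ln, lo, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(ln - lo)  # pi_new(a|s) / pi_old(a|s)
        clipped = max(1.0 - clip_ratio, min(ratio, 1.0 + clip_ratio))
        # Take the more pessimistic of the unclipped and clipped terms.
        total += min(ratio * adv, clipped * adv)
    return total / len(advantages)


if __name__ == "__main__":
    value = ppo_clip_objective(
        logp_new=[-0.9, -1.2],
        logp_old=[-1.0, -1.0],
        advantages=[1.0, -0.5],
    )
    print("clipped objective:", value)
```

The clipping removes the incentive to move the probability ratio outside `[1 - clip_ratio, 1 + clip_ratio]` when doing so would increase the objective, which keeps each policy update close to the policy that collected the data.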