Project author: e-hulten

Project description:
PyTorch implementation of Real NVP for density estimation
Primary language: HTML
Project URL: git://github.com/e-hulten/real-nvp-2d.git
Created: 2019-10-17T12:51:35Z
Project community: https://github.com/e-hulten/real-nvp-2d



Real NVP

PyTorch implementation of the Real NVP paper by Dinh et al. (1). This is an implementation of Real NVP for density estimation rather than generative modelling. The model supports sample generation (a backward pass through the flow) at the same computational cost as density evaluation, but the code is not (yet) adapted to images. However, visualising the forward and inverse passes on two-dimensional densities is feasible, and I have recreated Figure 1 from (1) as a gif.
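The reason sampling costs the same as density evaluation is the affine coupling layer: one half of the input passes through unchanged and conditions a scale and shift of the other half, so both directions only ever evaluate s and t, never invert them. The sketch below illustrates this for a single 2D layer. It is not the code in model.py: the functions `s`, `t`, `coupling_forward` and `coupling_inverse` are hypothetical, and `s` and `t` are fixed toy functions standing in for the neural networks used in Real NVP.

```python
import math

# Hypothetical stand-ins for the scale network s(.) and translation network t(.).
# In Real NVP these are learned neural networks; fixed toy functions keep the
# sketch runnable with the standard library alone.
def s(x1):
    return 0.5 * math.tanh(x1)


def t(x1):
    return 0.3 * x1


def coupling_forward(x1, x2):
    """Density direction (data -> latent): x1 is copied, x2 is scaled and shifted."""
    y1 = x1
    y2 = x2 * math.exp(s(x1)) + t(x1)
    # The Jacobian is triangular with diagonal (1, exp(s(x1))), so log|det J| = s(x1).
    log_det = s(x1)
    return y1, y2, log_det


def coupling_inverse(y1, y2):
    """Sampling direction (latent -> data). s and t are only evaluated on y1 == x1,
    so inverting costs the same as the forward pass and never inverts s or t."""
    x1 = y1
    x2 = (y2 - t(y1)) * math.exp(-s(y1))
    return x1, x2


y1, y2, log_det = coupling_forward(0.7, -1.2)
x1, x2 = coupling_inverse(y1, y2)
print(x1, x2)  # the original point, up to floating-point rounding
```

Because the log-determinant is just s(x1), the density term is as cheap as the transformation itself, which is what makes the same model usable for both density estimation and sampling.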



All the interesting functionality is found in model.py.

Change the relevant parameters in train.py and run the script. For example, the settings

    # ------------ parameters ------------
    continue_training = False
    gif = True  # visualise the training as a gif (only for 2D densities)
    density = "moons"  # "moons" selects the two moons dataset
    n_c_layers = 10  # number of coupling layers
    epochs = 200  # number of training epochs
    batch_size = 100  # batch size
    lr = 5e-4  # learning rate of the Adam optimiser
    plot_interval = 1  # plot at the end of every epoch (for the gif)
    path = r"/Users/edvardhulten/real_nvp_2d/"  # change to your own path (unless your name is Edvard Hultén too)
    distr_name = "two_moons"
    duration = 0.1
    # ------------------------------------

should yield good results on the two moons dataset in a very reasonable amount of time.
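As a rough illustration of what `n_c_layers` controls, the sketch below stacks that many 2D coupling layers, alternating which coordinate is left unchanged, and evaluates log p(x) with the change-of-variables formula and a standard normal base distribution. The conditioner `scale_shift` and all function names are hypothetical; the repository's model.py uses trained networks and may organise the layers differently.

```python
import math


def scale_shift(u, layer):
    """Hypothetical conditioner returning (s, t) for one layer; a trained model
    would use a separate neural network per coupling layer."""
    a = 0.1 * (layer + 1)
    return a * math.tanh(u), a * u


def flow_forward(x, n_layers):
    """Map a 2D point x = (x1, x2) to latent z, summing log|det J| over layers."""
    x1, x2 = x
    total_log_det = 0.0
    for k in range(n_layers):
        if k % 2 == 0:  # even layers: keep x1, transform x2
            s, t = scale_shift(x1, k)
            x2 = x2 * math.exp(s) + t
        else:  # odd layers: keep x2, transform x1
            s, t = scale_shift(x2, k)
            x1 = x1 * math.exp(s) + t
        total_log_det += s
    return (x1, x2), total_log_det


def log_prob(x, n_layers=10):
    """Change of variables with a standard normal base:
    log p(x) = log N(z; 0, I) + sum of per-layer log|det J|."""
    (z1, z2), log_det = flow_forward(x, n_layers)
    log_base = -0.5 * (z1 ** 2 + z2 ** 2) - math.log(2 * math.pi)
    return log_base + log_det


print(log_prob((0.5, -0.25)))
```

Alternating the unchanged coordinate matters: a single coupling layer never modifies x1, so without the swap the first coordinate would only ever be modelled by the base distribution. Training then amounts to maximising the mean of `log_prob` over batches of data.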

I have also added the batch norm layer presented in Appendix B of (2) by Papamakarios et al.
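Batch normalisation can be used inside a flow because, with its statistics held fixed, it is an elementwise affine map with a cheap log-determinant. The sketch below assumes the parameterisation y = exp(γ) ⊙ (x − m) / √(v + ε) + β, with batch statistics m, v during training and running averages at evaluation and sampling time; check Appendix B of (2) for the exact form. The class name, the pure-Python list representation and the `momentum` update rule are illustrative assumptions, not the code in model.py.

```python
import math


class BatchNormFlow:
    """Hypothetical invertible batch-norm layer on lists of d-dimensional points.

    Forward:  y = exp(log_gamma) * (x - mean) / sqrt(var + eps) + beta
    log|det|: sum_i (log_gamma_i - 0.5 * log(var_i + eps)), identical for every point.
    """

    def __init__(self, dim=2, eps=1e-5, momentum=0.1):
        self.dim, self.eps, self.momentum = dim, eps, momentum
        self.log_gamma = [0.0] * dim  # learnable parameters in a real model
        self.beta = [0.0] * dim
        self.running_mean = [0.0] * dim  # used at evaluation / sampling time
        self.running_var = [1.0] * dim

    def _batch_stats(self, batch):
        n = len(batch)
        mean = [sum(x[i] for x in batch) / n for i in range(self.dim)]
        var = [sum((x[i] - mean[i]) ** 2 for x in batch) / n for i in range(self.dim)]
        # Exponential moving averages of the batch statistics (assumed update rule).
        m = self.momentum
        self.running_mean = [(1 - m) * r + m * b for r, b in zip(self.running_mean, mean)]
        self.running_var = [(1 - m) * r + m * b for r, b in zip(self.running_var, var)]
        return mean, var

    def _scale(self, var):
        return [math.exp(g) / math.sqrt(v + self.eps) for g, v in zip(self.log_gamma, var)]

    def forward(self, batch, training=True):
        """Density direction: returns (transformed batch, per-point log|det J|)."""
        if training:
            mean, var = self._batch_stats(batch)
        else:
            mean, var = self.running_mean, self.running_var
        scale = self._scale(var)
        out = [[(x[i] - mean[i]) * scale[i] + self.beta[i] for i in range(self.dim)] for x in batch]
        # log(exp(g) / sqrt(v + eps)) = g - 0.5 * log(v + eps), summed over dimensions.
        log_det = sum(math.log(c) for c in scale)
        return out, log_det

    def inverse(self, batch):
        """Sampling direction, using the running statistics."""
        mean, scale = self.running_mean, self._scale(self.running_var)
        return [[(y[i] - self.beta[i]) / scale[i] + mean[i] for i in range(self.dim)] for y in batch]


bn = BatchNormFlow()
data = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5], [2.0, 4.0]]
y, log_det = bn.forward(data)  # training mode: normalises with batch statistics
```

Since the log-determinant does not depend on the individual point, it adds the same constant to every sample's log-likelihood in a batch, so the layer costs almost nothing on top of the coupling layers.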

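(1): L. Dinh, J. Sohl-Dickstein, S. Bengio. Density estimation using Real NVP. https://arxiv.org/abs/1605.08803

(2): G. Papamakarios, T. Pavlakou, I. Murray. Masked Autoregressive Flow for Density Estimation. https://arxiv.org/abs/1705.07057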