TensorFlow implementations of Wasserstein GAN with Gradient Penalty (WGAN-GP), Least Squares GAN (LSGAN), and GANs with the hinge loss.
These are my TensorFlow implementations of Wasserstein GAN with Gradient Penalty (WGAN-GP), proposed in Improved Training of Wasserstein GANs, Least Squares GAN (LSGAN), and GANs with the hinge loss.
The key insight of WGAN-GP is as follows. To enforce the Lipschitz constraint in Wasserstein GAN, the original paper proposes to clip the weights of the discriminator (critic), which can lead to undesired behavior such as exploding and vanishing gradients. Instead of weight clipping, this paper employs a gradient penalty term that constrains the gradient norm of the critic's output with respect to its input, resulting in the following learning objective:
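L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})] - \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}[(\| \nabla_{\hat{x}} D(\hat{x}) \|_2 - 1)^2]

where P_r is the data distribution, P_g is the generator distribution, and x̂ is sampled uniformly along straight lines between pairs of real and generated samples.

A minimal sketch of how this penalty term can be computed in TensorFlow is shown below (TF1-style graph code; the discriminator, real, and fake names are illustrative, and the actual implementation in this repository may differ in details):

import tensorflow as tf

def gradient_penalty(discriminator, real, fake, eps=1e-12):
    # Sample points uniformly along straight lines between real and generated samples.
    batch_size = tf.shape(real)[0]
    alpha = tf.random_uniform([batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
    interp = alpha * real + (1.0 - alpha) * fake
    # Gradient of the critic output with respect to the interpolated input.
    grads = tf.gradients(discriminator(interp), [interp])[0]
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + eps)
    # Penalize deviation of the gradient norm from 1 (the two-sided penalty).
    return tf.reduce_mean(tf.square(norm - 1.0))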
This enables stable training of a variety of GAN models on a wide range of datasets. This implementation is tested on several datasets including LSUN bedroom, CelebA, CityScape (leftImg8bit_sequence_trainvaltest), ImageNet, CIFAR100, CIFAR10, Street View House Number (SVHN), MNIST, and Fashion_MNIST. Randomly sampled results are as follows.
*This code is still being developed and subject to change.
python download.py --dataset bedroom celeba CIFAR10 CIFAR100 SVHN MNIST Fashion_MNIST
python trainer.py --dataset [bedroom / celeba / CityScape / ImageNet / CIFAR10 / CIFAR100 / SVHN / MNIST / Fashion_MNIST] --batch_size 36 --num_dis_conv 6 --gan_type wgan-gp
(Check out config.py for more details.) Supported datasets: bedroom, celeba, ImageNet, CIFAR10, CIFAR100, SVHN, MNIST, and Fashion_MNIST. You can also add your own datasets.
During training, the script prints logs such as:
[train step 10] D loss: 1.26449 G loss: 46.01093 (0.057 sec/batch, 558.933 instances/sec)
Select the objective with the --gan_type argument: wgan-gp, lsgan, or hinge (an illustrative sketch of the three losses follows below).
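For reference, here is an illustrative sketch of the discriminator and generator losses behind the three --gan_type options (d_real and d_fake denote discriminator outputs on real and generated batches, and gradient_penalty is the penalty term described above; the exact coefficients and reductions used in this repository may differ):

import tensorflow as tf

def wgan_gp_losses(d_real, d_fake, gradient_penalty, lam=10.0):
    # The critic maximizes D(real) - D(fake); the penalty keeps gradient norms near 1.
    d_loss = tf.reduce_mean(d_fake) - tf.reduce_mean(d_real) + lam * gradient_penalty
    g_loss = -tf.reduce_mean(d_fake)
    return d_loss, g_loss

def lsgan_losses(d_real, d_fake):
    # Least-squares targets: 1 for real samples, 0 for generated samples.
    d_loss = 0.5 * tf.reduce_mean(tf.square(d_real - 1.0)) + 0.5 * tf.reduce_mean(tf.square(d_fake))
    g_loss = 0.5 * tf.reduce_mean(tf.square(d_fake - 1.0))
    return d_loss, g_loss

def hinge_losses(d_real, d_fake):
    # Hinge loss for the discriminator; the generator simply maximizes D(fake).
    d_loss = tf.reduce_mean(tf.nn.relu(1.0 - d_real)) + tf.reduce_mean(tf.nn.relu(1.0 + d_fake))
    g_loss = -tf.reduce_mean(d_fake)
    return d_loss, g_loss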
python evaler.py --dataset [DATASET] [--train_dir /path/to/the/training/dir/ OR --checkpoint /path/to/the/trained/model] --write_summary_image True --output_file output.hdf5
(Check out config.py for more details.)
Launch TensorBoard and go to the specified port; you can see the different losses in the scalars tab and the plotted images in the images tab. The images can be interpreted as follows:
fake_image: a batch of generated images in the current batch
img: a batch of real images in the current batch
D(real image): the discriminator's output given real images, reflecting how the discriminator thinks about these images. White: real; black: fake.
D(generated image): the discriminator's output given generated images, reflecting how the discriminator thinks about these images. White: real; black: fake.
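If you wrote evaluation results to output.hdf5 with the command above, the file can be inspected with h5py. The stored keys and layouts are implementation-specific, so a quick listing is a safe first step; this is only a minimal sketch:

import h5py

# Print every group/dataset name in the evaluation output, plus dataset shapes.
# The key names depend on the implementation, so inspect them before assuming a
# particular layout.
with h5py.File('output.hdf5', 'r') as f:
    f.visititems(lambda name, obj: print(name, getattr(obj, 'shape', '')))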
$ mkdir datasets/YOUR_DATASET
Step 1: organize your data
With the HDF5 loader:
With the image loader:
Put your images in datasets/YOUR_DATASET. Supported image extensions: .jpg, .jpeg, .JPEG, .webp, and .png.
Step 2: train and test
$ python trainer.py --dataset YOUR_DATASET --dataset_path datasets/YOUR_DATASET
$ python evaler.py --dataset YOUR_DATASET --dataset_path datasets/YOUR_DATASET --train_dir dir
As part of the implementation series of the Cognitive Learning for Vision and Robotics Lab at the University of Southern California, our motivation is to accelerate (or sometimes delay) research in the AI community by promoting open-source projects. To this end, we implement state-of-the-art research papers and publicly share them with concise reports. Please visit our group GitHub site.
This project is implemented by Shao-Hua Sun and reviewed by Youngwoon Lee.
Shao-Hua Sun / @shaohua0116