Author: jayparks

Description:
A PyTorch Implementation of "Quasi-Recurrent Neural Networks"
Language: Python
Repository: git://github.com/jayparks/quasi-rnn.git
Created: 2017-08-24T09:48:06Z
Project page: https://github.com/jayparks/quasi-rnn


Neural Machine Translation using Quasi-RNN

PyTorch implementation of neural machine translation using “Quasi-Recurrent Neural Networks” (ICLR 2017).

Requirements

  • NumPy >= 1.11.1
  • PyTorch >= 0.2.0

Usage Instructions

Code

  • layer.py: Implementation of the quasi-recurrent layer
  • model.py: Implementation of the encoder-decoder model using the QRNN layer
  • train.py: Code to train an NMT model
  • decode.py: Code to translate a source file using a trained model
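
As a rough illustration of what a quasi-recurrent layer computes, the sketch below implements the fo-pooling recurrence from the QRNN paper in plain NumPy: c_t = f_t * c_{t-1} + (1 - f_t) * z_t and h_t = o_t * c_t. It is not the code in layer.py. The function name `fo_pool` and its arguments are hypothetical, and the gate activations z, f and o are assumed to have been produced already by the layer's convolution and nonlinearities.

```python
import numpy as np


def fo_pool(z, f, o, c0=None):
    """Sketch of QRNN fo-pooling over one sequence (hypothetical helper).

    z, f, o: arrays of shape (T, H) holding the candidate, forget-gate and
             output-gate activations for T time steps and H hidden units.
    c0:      optional initial cell state of shape (H,); defaults to zeros.
    Returns (h, c_last): hidden states of shape (T, H) and the final cell state.
    """
    T, H = z.shape
    c = np.zeros(H) if c0 is None else np.asarray(c0, dtype=float)
    h = np.empty((T, H))
    for t in range(T):
        # The only sequential part: an element-wise gated running average.
        c = f[t] * c + (1.0 - f[t]) * z[t]
        h[t] = o[t] * c
    return h, c


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    T, H = 5, 4
    z = np.tanh(rng.standard_normal((T, H)))
    f = 1.0 / (1.0 + np.exp(-rng.standard_normal((T, H))))
    o = 1.0 / (1.0 + np.exp(-rng.standard_normal((T, H))))
    h, c_last = fo_pool(z, f, o)
    print(h.shape, c_last.shape)
```

Because the recurrence is element-wise, the expensive part of the layer (the convolution that produces z, f and o) can run in parallel over all time steps, leaving only this light loop sequential.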

Training

To train a Quasi-RNN NMT model:

    $ python train.py --kernel_size 3 \
                      --hidden_size 640 \
                      --emb_size 500 \
                      --num_enc_symbols 30000 \
                      --num_dec_symbols 30000 ...

Decoding

To run the trained model for translation,

    $ python decode.py --model_path $path_to_model \
                       --decode_input $path_to_source \
                       --decode_output $path_to_output \
                       --max_decode_step 300 \
                       --batch_size 30 ...

For simplicity, decoding is greedy: the highest-scoring token is chosen at each time step, and beam search is not used.
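
Greedy decoding can be sketched in a few lines of plain Python. This is a minimal illustration, not the repository's decoding code: the `step` callback, `bos_id`, `eos_id` and the toy scoring function are all hypothetical, and only `max_decode_step` mirrors the command-line flag above.

```python
def greedy_decode(step, init_state, bos_id, eos_id, max_decode_step=300):
    """Greedy decoding sketch (hypothetical interface).

    step(prev_token, state) must return (scores, new_state), where scores is
    a sequence with one score per target-vocabulary entry.
    """
    tokens = []
    prev, state = bos_id, init_state
    for _ in range(max_decode_step):
        scores, state = step(prev, state)
        # Greedy choice: keep only the single best-scoring token.
        prev = max(range(len(scores)), key=scores.__getitem__)
        if prev == eos_id:
            break
        tokens.append(prev)
    return tokens


def toy_step(prev, state):
    """Toy stand-in for a decoder: follows a fixed script, ignoring `prev`."""
    script = [3, 1, 2]  # token 2 plays the role of end-of-sentence
    scores = [0.0, 0.0, 0.0, 0.0]
    scores[script[min(state, len(script) - 1)]] = 1.0
    return scores, state + 1


if __name__ == "__main__":
    print(greedy_decode(toy_step, 0, bos_id=0, eos_id=2))
```

Beam search would instead keep several partial hypotheses per step and rank them by cumulative score. Greedy decoding trades that extra search for a simpler, faster loop.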

Notes

For a more in-depth exploration, a QRNN API for PyTorch is available at https://github.com/salesforce/pytorch-qrnn

For comments or feedback, please email me at pjh0308@gmail.com or open an issue here.