This repository contains a proof-of-concept implementation of Quasi-Recurrent Neural Networks (QRNNs) [1].
The goal was to implement the recurrent pooling without a Python loop and without a custom CUDA kernel.
To this end, the recurrence is transformed into suitable `cumprod` and `cumsum` operations; see below for the derivation.
Please note that this implementation is still about 2x slower than the PyTorch LSTM implementation (on a single GPU). Also, this implementation does not support bidirectional sequence processing; however, that could be achieved by appropriate masking and flipping of the gates.
The recurrent pooling stated in [1] is:

$$h_t = f_t \odot h_{t-1} + (1 - f_t) \odot z_t$$

In order to calculate $h_t$ without a loop over timesteps, the recurrence is unrolled into a sum:

$$h_t = \sum_{s=1}^{t} \left( \prod_{k=s+1}^{t} f_k \right) \odot (1 - f_s) \odot z_s$$

Note that the initial hidden state is initialised to all zeros and therefore does not appear in the sum.

We cannot calculate the product of gates $\prod_{k=s+1}^{t} f_k$ with a single `cumprod`, because its lower limit depends on $s$ while its upper limit depends on $t$. To circumvent this problem, we first multiply the sum by $\prod_{k=1}^{t} f_k$ and divide each summand by $\prod_{k=1}^{s} f_k$, which leaves every term unchanged:

$$h_t = \left( \prod_{k=1}^{t} f_k \right) \odot \sum_{s=1}^{t} \frac{(1 - f_s) \odot z_s}{\prod_{k=1}^{s} f_k}$$

Now, the products over the gates can be computed with a single `cumprod` operation. Accordingly, accumulating the numerators divided by their cumulative products is implemented by a `cumsum` operation.

Note, however, that the division prohibits any zeros in the recurrent gates. This is naturally enforced by using the `sigmoid` activation function to calculate recurrent gate values.
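Below is a minimal sketch of this idea (not the repository's actual code), assuming gate and candidate tensors of shape `(batch, seq_len, hidden_size)` and a zero initial hidden state:

```python
import torch

def f_pooling_cumulative(f, z):
    """Parallel f-pooling via cumprod/cumsum.

    f: sigmoid gate values in (0, 1), shape (batch, seq_len, hidden_size)
    z: candidate values, same shape
    Computes h_t = f_t * h_{t-1} + (1 - f_t) * z_t with h_0 = 0.
    """
    # P_t = prod_{k<=t} f_k, computed along the time dimension
    P = torch.cumprod(f, dim=1)
    # Each summand (1 - f_s) * z_s is divided by P_s, then accumulated over s
    S = torch.cumsum((1 - f) * z / P, dim=1)
    # Multiplying back by P_t recovers prod_{k=s+1..t} f_k for every summand
    return P * S

def f_pooling_loop(f, z):
    """Reference loop implementation for checking correctness."""
    h, outputs = torch.zeros_like(z[:, 0]), []
    for t in range(f.shape[1]):
        h = f[:, t] * h + (1 - f[:, t]) * z[:, t]
        outputs.append(h)
    return torch.stack(outputs, dim=1)

f = torch.sigmoid(torch.randn(2, 5, 3))
z = torch.tanh(torch.randn(2, 5, 3))
assert torch.allclose(f_pooling_cumulative(f, z), f_pooling_loop(f, z), atol=1e-5)
```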
For numerical stability, however, we perform all computations in log space. Accordingly, `cumprod` is replaced by `cumsum` and `cumsum` is replaced by `logcumsumexp`. Note that this way, the activation function used to calculate the candidate values must produce strictly positive outputs, e.g. `exp` or `softplus`.
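The log-space variant can be sketched as follows, under the same shape assumptions; the helper name and the use of `softplus` for the candidates are illustrative choices, not necessarily what `qrnn_layer.py` does:

```python
import torch
import torch.nn.functional as F

def f_pooling_logspace(pre_f, pre_z):
    """Log-space f-pooling sketch.

    pre_f, pre_z: pre-activation gate / candidate values,
    shape (batch, seq_len, hidden_size).
    """
    log_f = F.logsigmoid(pre_f)              # log f_t with f_t = sigmoid(pre_f)
    log_one_minus_f = F.logsigmoid(-pre_f)   # log(1 - f_t)
    log_z = torch.log(F.softplus(pre_z))     # candidates must be positive
    # cumprod of the gates becomes a cumsum of their logs
    log_P = torch.cumsum(log_f, dim=1)
    # the cumsum of the scaled summands becomes a logcumsumexp
    log_S = torch.logcumsumexp(log_one_minus_f + log_z - log_P, dim=1)
    return torch.exp(log_P + log_S)
```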
The file `qrnn_layer.py` contains a class implementing a QRNN layer as described in [1]. It takes the following arguments (see the usage sketch after the list):
- `input_size`: Number of input features / channels
- `hidden_size`: Number of output features / channels
- `kernel_size`: Convolutional kernel width = number of previous timesteps that influence gate values for a given timestep
- `mode`: Can be `"f"`, `"fo"`, or `"ifo"`. These correspond to the QRNN variants with the same names as described in [1]
- `zoneout`: Probability of randomly setting a recurrent gate to 1 (a type of recurrent dropout), i.e. copying the previous hidden state to the current timestep without modification
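A hypothetical usage example follows; the class name `QRNNLayer`, the batch-first input layout, and the single return value are assumptions, so check `qrnn_layer.py` for the actual interface:

```python
import torch
from qrnn_layer import QRNNLayer  # class name assumed

# 128 input channels, 256 hidden channels, gates conditioned on the
# 2 preceding timesteps, "fo" pooling, 10% zoneout.
layer = QRNNLayer(input_size=128, hidden_size=256, kernel_size=2,
                  mode="fo", zoneout=0.1)

x = torch.randn(4, 50, 128)  # (batch, seq_len, input_size), layout assumed
h = layer(x)                 # per-timestep output features, hidden_size channels
```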
In order to check whether the implementation is correct, a QRNN language model is compared to an LSTM language model on the Wikitext-2v1 dataset. However, this comparison uses a truncated vocabulary (the 10K most frequent tokens) and treats each sentence as a separate input sequence. See the accompanying Jupyter notebook for the data preprocessing and `experiment.py` for the exact hyperparameters. The only purpose of this comparison is to show that the implementation works as expected, not to reach good performance.
Results are in the following table:
| Model | Perplexity |
|---|---|
| QRNN | 109.18 |
| LSTM | 98.48 |
The LSTM's perplexity is lower, but of the same order of magnitude, which shows that this implementation works in principle. It may simply be that the LSTM is the superior model here; this comparison cannot confirm the superiority of QRNNs claimed in [1].
[1]: Bradbury, James, et al. "Quasi-recurrent neural networks." arXiv preprint arXiv:1611.01576 (2016).