Unofficial UR-LSTM implementation in PyTorch
UR-LSTM
Description
This repository revolves around the paper: Improving the Gating Mechanism of Recurrent Neural Networks by Albert Gu, Caglar Gulcehre, Tom Paine, Matt Hoffman and Razvan Pascanu.
In it, the authors introduce the UR-LSTM, a variant of the LSTM architecture that robustly improves the performance of the recurrent model, particularly when long-term dependencies are involved.
Unfortunately, to my knowledge the authors did not release any code, either for the model or for the experiments, although they did provide pseudo-code for the model. Since I thought it was a really cool read, I decided to reimplement the model, as well as some of the experiments, with the PyTorch framework.
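As a rough illustration of the paper's idea (a sketch based on its pseudo-code, not this package's actual internals): the UR-LSTM pairs a "uniform" initialization of the forget-gate bias with an extra refine gate r that modulates the raw forget gate f into an effective gate g = r·(1 − (1 − f)²) + (1 − r)·f². A minimal scalar sketch:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# The "U" part: forget-gate biases are initialized so the initial gate
# activations are spread roughly uniformly in (0, 1), e.g. b = log(u / (1 - u))
# for u drawn uniformly (illustrative, not the package's exact init).

# The "R" part: the refine gate r pushes the forget gate f up (r > 0.5),
# down (r < 0.5), or leaves it unchanged (r = 0.5).
def effective_gate(f, r):
    return r * (1.0 - (1.0 - f) ** 2) + (1.0 - r) * f ** 2

f = 0.9                          # raw forget gate
print(effective_gate(f, 1.0))    # refine pushes the gate above f
print(effective_gate(f, 0.0))    # refine pulls the gate below f
print(effective_gate(f, 0.5))    # r = 0.5 recovers f exactly
```

Note that g is a convex combination of 1 − (1 − f)² (an upper band) and f² (a lower band), which is what lets the refine gate adjust the forget gate in either direction.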
I've separated the code for the UR-LSTM, which is packaged and downloadable as a standalone module, from the code for the experiments. If you want to see how to run them, check out this page.
Installation
With Python 3.6 or higher:
pip install ur-lstm-torch
I haven't checked whether the model is compatible with older versions of PyTorch, but it should work with everything from version 1.0 onward.
Usage
The model can be used in the same way as the native LSTM
implementation (doc is here), although I didn't implement the bidirectional variant and removed the bias
keyword argument:
import torch
from ur_lstm import URLSTM
input_size = 10
hidden_size = 20
num_layers = 2
batch_first = False
dropout = .5
model = URLSTM(
    input_size, hidden_size, num_layers=num_layers, batch_first=batch_first, dropout=dropout
)
batch_size = 3
seq_length = 5
x = torch.randn(seq_length, batch_size, input_size)
out, state = model(x)
print(out.shape) # (seq_length, batch_size, hidden_size)
print(len(state)) # 2, first is hidden state, second is cell state
print(state[0].shape) # (num_layers, batch_size, hidden_size)
print(state[1].shape) # (num_layers, batch_size, hidden_size)
If you want to implement a custom model, you can also import and use the URLSTMCell
module in the same way you would the regular LSTMCell
(doc is here), although again I removed the bias
keyword argument:
import torch
from ur_lstm import URLSTMCell
input_size = 10
hidden_size = 20
cell = URLSTMCell(input_size, hidden_size)
batch_size = 2
x = torch.randn(batch_size, input_size)
state = torch.randn(batch_size, hidden_size), torch.randn(batch_size, hidden_size)
out, state = cell(x, state)
print(out.shape) # (batch_size, hidden_size)
print(len(state)) # 2, first is hidden state, second is cell state
print(state[0].shape) # (batch_size, hidden_size)
print(state[1].shape) # (batch_size, hidden_size)
License