Unofficial UR-LSTM implementation in PyTorch
Project description
UR-LSTM
Description
This repository revolves around the paper: Improving the Gating Mechanism of Recurrent Neural Networks by Albert Gu, Caglar Gulcehre, Tom Paine, Matt Hoffman and Razvan Pascanu.
In it, the authors introduce the UR-LSTM, a variant of the LSTM architecture that robustly improves the performance of the recurrent model, particularly when long-term dependencies are involved.
Unfortunately, to my knowledge the authors did not release any code, either for the model or the experiments, although they did provide pseudo-code for the model. Since I thought it was a really cool read, I decided to reimplement the model, as well as some of the experiments, with the PyTorch framework.
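For context, here is a minimal sketch of the core mechanism as I understand it from the paper (the function and variable names below are mine, not this package's API). The "U" stands for a uniform initialization of the forget-gate bias, and the "R" for an extra refine gate that modulates the forget gate:
import torch

def uniform_forget_bias(hidden_size, t_max=100):
    # "U": draw initial forget-gate activations roughly uniformly from
    # (1/t_max, 1 - 1/t_max), then invert the sigmoid to get the bias.
    u = torch.rand(hidden_size) * (1 - 2 / t_max) + 1 / t_max
    return torch.log(u / (1 - u))

def effective_forget_gate(f_raw, r_raw, forget_bias):
    # "R": the refine gate r interpolates between bounds around the forget
    # gate f, letting the effective gate g reach values closer to 0 or 1
    # than sigmoid saturation normally allows.
    f = torch.sigmoid(f_raw + forget_bias)
    r = torch.sigmoid(r_raw - forget_bias)
    return r * (1 - (1 - f) ** 2) + (1 - r) * f ** 2
The effective gate g then replaces the forget gate in the usual cell update, c = g * c_prev + (1 - g) * candidate, with the rest of the LSTM left unchanged.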
I've separated the code for the UR-LSTM, which is packaged and downloadable as a standalone module, from the code for the experiments. If you want to check out how to run them, see this page.
Installation
With Python 3.6 or higher:
pip install ur-lstm-torch
I haven't checked whether the model is compatible with older versions of PyTorch, but it should be fine for everything past version 1.0.
Usage
The model can be used in the same way as the native LSTM implementation (see the PyTorch documentation), although I didn't implement the bidirectional variant and removed the bias keyword argument:
import torch
from ur_lstm import URLSTM
input_size = 10
hidden_size = 20
num_layers = 2
batch_first = False  # inputs are (seq_length, batch_size, input_size)
dropout = 0.5
model = URLSTM(
input_size, hidden_size, num_layers=num_layers, batch_first=batch_first, dropout=dropout
)
batch_size = 3
seq_length = 5
x = torch.randn(seq_length, batch_size, input_size)
out, state = model(x)
print(out.shape) # (seq_length, batch_size, hidden_size)
print(len(state)) # 2, first is hidden state, second is cell state
print(state[0].shape) # (num_layers, batch_size, hidden_size)
print(state[1].shape) # (num_layers, batch_size, hidden_size)
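Assuming the forward signature mirrors that of the native LSTM (the example above presumably lets the state default to zeros, as torch.nn.LSTM does), an initial (hidden, cell) state can also be passed explicitly:
h0 = torch.zeros(num_layers, batch_size, hidden_size)
c0 = torch.zeros(num_layers, batch_size, hidden_size)
out, state = model(x, (h0, c0))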
If you want to implement a custom model, you can also import and use the URLSTMCell module in the same way you would the regular LSTMCell (see the PyTorch documentation), although again I removed the bias keyword argument:
import torch
from ur_lstm import URLSTMCell
input_size = 10
hidden_size = 20
cell = URLSTMCell(input_size, hidden_size)
batch_size = 2
x = torch.randn(batch_size, input_size)
state = (torch.randn(batch_size, hidden_size), torch.randn(batch_size, hidden_size))  # initial (hidden, cell) state
out, state = cell(x, state)
print(out.shape) # (batch_size, hidden_size)
print(len(state)) # 2, first is hidden state, second is cell state
print(state[0].shape) # (batch_size, hidden_size)
print(state[1].shape) # (batch_size, hidden_size)
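As with the regular LSTMCell, the cell is meant to be unrolled manually over a sequence, carrying the state from step to step; continuing from the snippet above:
seq_length = 5
inputs = torch.randn(seq_length, batch_size, input_size)
outputs = []
for x_t in inputs:  # iterate over the time dimension
    out, state = cell(x_t, state)
    outputs.append(out)
print(torch.stack(outputs).shape)  # (seq_length, batch_size, hidden_size)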
Project details
Release history
Download files
Download the file for your platform.
Source Distribution
Built Distribution
File details
Details for the file ur-lstm-torch-0.0.2.tar.gz
File metadata
- Download URL: ur-lstm-torch-0.0.2.tar.gz
- Upload date:
- Size: 4.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.57.0 CPython/3.7.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1c775fefae3626a4c61b6c6b58e68a5575b5e7f8e3410e99f9136f93b7635b94
MD5 | 451c09e3245603e5db53d1c060117ebb
BLAKE2b-256 | 893ef71b28268cd6aea9ed5ad6fe8ed11027135a8619c7d1794610b462b1f642
File details
Details for the file ur_lstm_torch-0.0.2-py3-none-any.whl
File metadata
- Download URL: ur_lstm_torch-0.0.2-py3-none-any.whl
- Upload date:
- Size: 5.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.57.0 CPython/3.7.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 54cd76da7a0bac77dc9a5741b1b30b21f9fe6b902f94ea7a08ff99dfa06c4428
MD5 | 747d1d3cf7ae979fd3e0f1fb0a943546
BLAKE2b-256 | ba888192aabbd4decc593bb779ea4276bd530ddc31112e4369b89199d9324af8