
Unofficial UR-LSTM implementation in PyTorch

Project description

UR-LSTM

Description

This repository revolves around the paper: Improving the Gating Mechanism of Recurrent Neural Networks by Albert Gu, Caglar Gulcehre, Tom Paine, Matt Hoffman and Razvan Pascanu.

In it, the authors introduce the UR-LSTM, a variant of the LSTM architecture which robustly improves the performance of the recurrent model, particularly when long-term dependencies are involved.
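For reference, this is how I read the UR-LSTM update from the paper. The "U" is a uniform gate initialization (UGI) of the forget-gate bias, and the "R" is a refine gate $r_t$ that adjusts the forget gate $f_t$. Here $L_\ast$ denotes a learned linear map of $x_t$ and $h_{t-1}$, $d$ is the hidden size and $p$ is the UGI sample. The notation is mine rather than the paper's, so check it against the original.

```latex
\begin{aligned}
p &\sim \mathcal{U}\left[\tfrac{1}{d},\; 1 - \tfrac{1}{d}\right], \qquad b_f = \sigma^{-1}(p) = \log\frac{p}{1 - p} \\
f_t &= \sigma\big(L_f(x_t, h_{t-1}) + b_f\big) \\
r_t &= \sigma\big(L_r(x_t, h_{t-1}) - b_f\big) \\
g_t &= r_t\big(1 - (1 - f_t)^2\big) + (1 - r_t)\, f_t^2 \\
u_t &= \tanh\big(L_u(x_t, h_{t-1})\big), \qquad o_t = \sigma\big(L_o(x_t, h_{t-1})\big) \\
c_t &= g_t\, c_{t-1} + (1 - g_t)\, u_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

With $r_t = 1/2$ the effective gate reduces to $g_t = f_t$. With $r_t \to 1$ it becomes $g_t = 2f_t - f_t^2 \ge f_t$, and with $r_t \to 0$ it becomes $g_t = f_t^2 \le f_t$. This lets $g_t$ get closer to 0 or 1 than $f_t$ alone without saturating the sigmoid.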

Unfortunately, to my knowledge the authors did not release any code, either for the model or for the experiments, although they did provide pseudo-code for the model. Since I thought it was a really cool read, I decided to reimplement the model, as well as some of the experiments, with the PyTorch framework.

I've separated the code for the UR-LSTM, which is packaged and downloadable as a standalone module, from the code for the experiments. If you want to see how to run the experiments, check out this page.

Installation

With Python 3.6 or higher:

pip install ur-lstm-torch

I haven't checked whether the model is compatible with older versions of PyTorch, but it should work with any version from 1.0 onward.
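If you are unsure which PyTorch version you have, a small helper can compare `torch.__version__` against the 1.0 baseline mentioned above. `version_at_least` is a hypothetical helper written for this README and is not part of `ur_lstm`. It only compares the major and minor numbers and ignores local build tags such as `+cu117`.

```python
import re


def version_at_least(version, minimum=(1, 0)):
    """Return True if a version string such as '1.13.1+cu117' is >= `minimum`."""
    release = version.split("+")[0]  # drop local build tags like '+cu117'
    parts = []
    for piece in release.split(".")[: len(minimum)]:
        match = re.match(r"\d+", piece)  # keep leading digits, e.g. '0a0' -> 0
        parts.append(int(match.group()) if match else 0)
    parts += [0] * (len(minimum) - len(parts))  # pad short versions like '2'
    return tuple(parts) >= tuple(minimum)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        print(torch.__version__, version_at_least(torch.__version__))
```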

Usage

The model can be used in the same way as the native LSTM implementation (doc is here), although I didn't implement the bidirectional variant and removed the bias keyword argument:

import torch
from ur_lstm import URLSTM

input_size = 10
hidden_size = 20
num_layers = 2
batch_first = False
dropout = .5

model = URLSTM(
    input_size, hidden_size, num_layers=num_layers, batch_first=batch_first, dropout=dropout
)

batch_size = 3
seq_length = 5

x = torch.randn(seq_length, batch_size, input_size)
out, state = model(x)

print(out.shape) # (seq_length, batch_size, hidden_size)
print(len(state)) # 2, first is hidden state, second is cell state
print(state[0].shape) # (num_layers, batch_size, hidden_size)
print(state[1].shape) # (num_layers, batch_size, hidden_size)
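The bidirectional variant isn't implemented, but you can emulate it. Run one model over the sequence and a second one over the time-reversed sequence, then concatenate their outputs, which is what a bidirectional LSTM does. `Bidirectional` below is a hypothetical wrapper and is not part of `ur_lstm`. It assumes time-major input (`batch_first=False`), unpadded sequences, and wrapped modules that return `(out, state)` like `nn.LSTM`. It uses `nn.LSTM` as a stand-in so that it runs without the package.

```python
import torch
import torch.nn as nn


class Bidirectional(nn.Module):
    """Hypothetical wrapper that emulates a bidirectional recurrent layer with two
    unidirectional modules following the nn.LSTM calling convention."""

    def __init__(self, forward_rnn, backward_rnn):
        super().__init__()
        self.forward_rnn = forward_rnn
        self.backward_rnn = backward_rnn

    def forward(self, x):
        # x: (seq_length, batch_size, input_size), i.e. batch_first=False,
        # with every sequence in the batch having the full length (no padding).
        out_fwd, _ = self.forward_rnn(x)
        out_bwd, _ = self.backward_rnn(torch.flip(x, dims=[0]))
        # Undo the time reversal so that step t of both halves lines up.
        out_bwd = torch.flip(out_bwd, dims=[0])
        return torch.cat([out_fwd, out_bwd], dim=-1)


# nn.LSTM stands in here; URLSTM(10, 20) should slot in the same way,
# assuming it returns (out, state) as in the example above.
model = Bidirectional(nn.LSTM(10, 20), nn.LSTM(10, 20))
x = torch.randn(5, 3, 10)
print(model(x).shape)  # torch.Size([5, 3, 40])
```

Each direction keeps its own final state, which this sketch discards. Return both if you need them.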

If you want to implement a custom model, you can also import and use the URLSTMCell module in the same way you would the regular LSTMCell (doc is here), although again I removed the bias keyword argument:

import torch
from ur_lstm import URLSTMCell

input_size = 10
hidden_size = 20

cell = URLSTMCell(input_size, hidden_size)

batch_size = 2

x = torch.randn(batch_size, input_size)
state = torch.randn(batch_size, hidden_size), torch.randn(batch_size, hidden_size)
out, state = cell(x, state)

print(out.shape) # (batch_size, hidden_size)
print(len(state)) # 2, first is hidden state, second is cell state
print(state[0].shape) # (batch_size, hidden_size)
print(state[1].shape) # (batch_size, hidden_size)
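The following is a minimal, self-contained sketch of one UR-LSTM step, written from the paper's equations as I read them. It uses a uniform gate initialization of the forget bias, a refine gate that reuses the negated forget bias, and an input gate tied to `1 - g`. `MinimalURLSTMCell` is a hypothetical name, and this is not the code shipped in `ur_lstm`. As a simplification, only the forget gate gets a dedicated bias here.

```python
import torch
import torch.nn as nn


class MinimalURLSTMCell(nn.Module):
    """Hypothetical UR-LSTM cell sketched from the paper's equations (not ur_lstm's code)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # One map from [x, h] to the pre-activations of the forget (f), refine (r),
        # candidate (u) and output (o) gates. For brevity it has no bias of its own.
        self.linear = nn.Linear(input_size + hidden_size, 4 * hidden_size, bias=False)
        # Uniform gate initialization: sample p ~ U(1/d, 1 - 1/d) and store its logit,
        # so that sigmoid(forget_bias) starts out uniform over that interval.
        p = torch.empty(hidden_size).uniform_(1.0 / hidden_size, 1.0 - 1.0 / hidden_size)
        self.forget_bias = nn.Parameter(torch.log(p / (1.0 - p)))

    def forward(self, x, state):
        h, c = state
        pre_f, pre_r, pre_u, pre_o = self.linear(torch.cat([x, h], dim=-1)).chunk(4, dim=-1)
        f = torch.sigmoid(pre_f + self.forget_bias)
        r = torch.sigmoid(pre_r - self.forget_bias)  # refine gate reuses -b_f
        g = r * (1 - (1 - f) ** 2) + (1 - r) * f ** 2  # effective forget gate
        c = g * c + (1 - g) * torch.tanh(pre_u)  # input gate tied to 1 - g
        h = torch.sigmoid(pre_o) * torch.tanh(c)
        return h, (h, c)


cell = MinimalURLSTMCell(10, 20)
x = torch.randn(2, 10)
state = torch.randn(2, 20), torch.randn(2, 20)
out, state = cell(x, state)
print(out.shape)  # torch.Size([2, 20])
```

Fusing the four gate projections into one `nn.Linear` keeps each step to a single matrix multiply. Most LSTM implementations use the same trick.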

License

MIT


Download files


Source Distribution

ur-lstm-torch-0.0.2.tar.gz (4.1 kB)

Built Distribution

ur_lstm_torch-0.0.2-py3-none-any.whl (5.0 kB)

File details

Details for the file ur-lstm-torch-0.0.2.tar.gz.

File metadata

  • Download URL: ur-lstm-torch-0.0.2.tar.gz
  • Size: 4.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.57.0 CPython/3.7.6

File hashes

Hashes for ur-lstm-torch-0.0.2.tar.gz
Algorithm Hash digest
SHA256 1c775fefae3626a4c61b6c6b58e68a5575b5e7f8e3410e99f9136f93b7635b94
MD5 451c09e3245603e5db53d1c060117ebb
BLAKE2b-256 893ef71b28268cd6aea9ed5ad6fe8ed11027135a8619c7d1794610b462b1f642


File details

Details for the file ur_lstm_torch-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: ur_lstm_torch-0.0.2-py3-none-any.whl
  • Size: 5.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.57.0 CPython/3.7.6

File hashes

Hashes for ur_lstm_torch-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 54cd76da7a0bac77dc9a5741b1b30b21f9fe6b902f94ea7a08ff99dfa06c4428
MD5 747d1d3cf7ae979fd3e0f1fb0a943546
BLAKE2b-256 ba888192aabbd4decc593bb779ea4276bd530ddc31112e4369b89199d9324af8
