
Unofficial UR-LSTM implementation in PyTorch

Project description

UR-LSTM

Description

This repository is based on the paper Improving the Gating Mechanism of Recurrent Neural Networks by Albert Gu, Caglar Gulcehre, Tom Paine, Matt Hoffman, and Razvan Pascanu.

In it, the authors introduce the UR-LSTM, a variant of the LSTM architecture that robustly improves the performance of the recurrent model, particularly on tasks involving long-term dependencies.

Unfortunately, to my knowledge the authors did not release any code, either for the model or for the experiments, although they did provide pseudo-code for the model. Since I thought it was a really cool read, I decided to reimplement the model, as well as some of the experiments, with the PyTorch framework.
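
To give a flavor of what the model does, here is a minimal sketch of a single cell update as I understand the paper's pseudo-code: a refine gate r modulates the forget gate f into an effective gate g, and the forget-gate bias is initialized from a uniform distribution so the initial gates cover a wide range of timescales. The function name, gate ordering, and fused projection below are my own illustration, not necessarily how this package lays things out:

import torch
import torch.nn as nn

# Sketch of one UR-LSTM cell update, following my reading of the paper's
# pseudo-code; naming conventions here are mine, not the package's.
def ur_lstm_step(x, h, c, gates, forget_bias):
    # Fused projection of [x; h] into the four gate pre-activations.
    f_, r_, u_, o_ = gates(torch.cat([x, h], dim=-1)).chunk(4, dim=-1)
    f = torch.sigmoid(f_ + forget_bias)  # forget gate ("U": uniform init via bias)
    r = torch.sigmoid(r_ - forget_bias)  # refine gate ("R")
    g = r * (1 - (1 - f) ** 2) + (1 - r) * f ** 2  # refined effective forget gate
    c = g * c + (1 - g) * torch.tanh(u_)  # input gate coupled as 1 - g
    h = torch.sigmoid(o_) * torch.tanh(c)
    return h, c

# Uniform gate initialization: set the forget bias to logit(u) with u drawn
# uniformly (the range below avoids saturating the sigmoid; see the paper
# for the exact scheme).
hidden_size = 20
u = torch.rand(hidden_size) * (1 - 2 / hidden_size) + 1 / hidden_size
forget_bias = torch.log(u / (1 - u))

# Quick smoke test of the sketch on random data.
gates = nn.Linear(10 + hidden_size, 4 * hidden_size)
x = torch.randn(2, 10)
h = torch.zeros(2, hidden_size)
c = torch.zeros(2, hidden_size)
h, c = ur_lstm_step(x, h, c, gates, forget_bias)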

I've separated the code for the UR-LSTM, which is packaged and downloadable as a standalone module, from the code for the experiments. If you want to check out how to run them, see this page.

Installation

With Python 3.6 or higher:

pip install ur-lstm-torch

I haven't checked whether the model is compatible with older versions of PyTorch, but it should work with everything from version 1.0 onwards.
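
As a quick sanity check after installing, you can confirm that the package imports and see which PyTorch version you're on (torch.__version__ is standard PyTorch, nothing specific to this package):

import torch
import ur_lstm  # the import alone verifies the package installed correctly

print(torch.__version__)  # should be 1.0 or later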

Usage

The model can be used in the same way as the native LSTM implementation (documented over here), although I didn't implement the bidirectional variant and removed the bias keyword argument:

import torch
from ur_lstm import URLSTM

input_size = 10
hidden_size = 20
num_layers = 2
batch_first = False
dropout = 0.5

model = URLSTM(input_size, hidden_size, num_layers=num_layers, batch_first=batch_first, dropout=dropout)

batch_size = 2
seq_length = 5

x = torch.randn(seq_length, batch_size, input_size)
out, state = model(x)

print(out.shape) # (seq_length, batch_size, hidden_size)
print(len(state)) # 2, first is hidden state, second is cell state
print(state[0].shape) # (num_layers, batch_size, hidden_size)
print(state[1].shape) # (num_layers, batch_size, hidden_size)
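
To make the usage concrete, here is a hypothetical training step built only on the call signature shown above; it assumes URLSTM is a regular nn.Module (so .parameters() works), and the random target is a placeholder rather than a real task:

import torch
import torch.nn as nn
from ur_lstm import URLSTM

model = URLSTM(input_size=10, hidden_size=20, num_layers=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

x = torch.randn(5, 2, 10)          # (seq_length, batch_size, input_size)
target = torch.randn(2, 20)        # placeholder target for the last time step

out, _ = model(x)                  # out: (seq_length, batch_size, hidden_size)
loss = criterion(out[-1], target)  # supervise only the final time step

optimizer.zero_grad()
loss.backward()
optimizer.step()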

If you want to implement a custom model, you can also import and use the URLSTMCell module in the same way you would the regular LSTMCell (documented over here), although again I removed the bias keyword argument:

import torch
from ur_lstm import URLSTMCell

input_size = 10
hidden_size = 20

cell = URLSTMCell(input_size, hidden_size)

batch_size = 2

x = torch.randn(batch_size, input_size)
state = (torch.randn(batch_size, hidden_size), torch.randn(batch_size, hidden_size))
out, state = cell(x, state)

print(out.shape) # (batch_size, hidden_size)
print(len(state)) # 2, first is hidden state, second is cell state
print(state[0].shape) # (batch_size, hidden_size)
print(state[1].shape) # (batch_size, hidden_size)
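
For example, here is a sketch of a custom model that unrolls the cell over a sequence by hand, using only the cell(x, state) signature shown above; starting from a zero state is my assumption, mirroring common LSTMCell usage:

import torch
from ur_lstm import URLSTMCell

input_size, hidden_size = 10, 20
batch_size, seq_length = 2, 5

cell = URLSTMCell(input_size, hidden_size)

x = torch.randn(seq_length, batch_size, input_size)
state = (torch.zeros(batch_size, hidden_size), torch.zeros(batch_size, hidden_size))

outputs = []
for t in range(seq_length):       # step the cell through the sequence
    out, state = cell(x[t], state)
    outputs.append(out)

outputs = torch.stack(outputs)    # (seq_length, batch_size, hidden_size)
print(outputs.shape)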

License

MIT

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

ur-lstm-torch-0.0.1.tar.gz (4.1 kB)


Built Distribution

ur_lstm_torch-0.0.1-py3-none-any.whl (5.0 kB)


File details

Details for the file ur-lstm-torch-0.0.1.tar.gz.

File metadata

  • Download URL: ur-lstm-torch-0.0.1.tar.gz
  • Upload date:
  • Size: 4.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.57.0 CPython/3.7.6

File hashes

Hashes for ur-lstm-torch-0.0.1.tar.gz

  • SHA256: 0b7e8a0321492d3b683a32b306e2f515cab941d40faad48e9aaa8fd386977a6e
  • MD5: b20e3a06caaaacf1bbebe3c1713119ad
  • BLAKE2b-256: 8f9bc38cc5fb0b36b9d7876efb75b44dc5dec39b3acb52e878850b9bd837a5fb

See more details on using hashes here.

File details

Details for the file ur_lstm_torch-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: ur_lstm_torch-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 5.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.57.0 CPython/3.7.6

File hashes

Hashes for ur_lstm_torch-0.0.1-py3-none-any.whl

  • SHA256: a91a6ed42104fabf4ddc59237e11a5b54ca9acd276d3e185463b18a35ad1cfc9
  • MD5: ab2e6ddefa6c4f0842afd78616403eb9
  • BLAKE2b-256: d1b838591600a0e13971a953c04585f442833c9c1dc463f30c60d160f28262b6

See more details on using hashes here.
