minGRU

Implementation of the proposed minGRU in PyTorch. Only the numerically stable log-space version is included.
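To give a feel for what "log-space" means here, the following is a minimal pure-Python sketch, not this package's code. It assumes the recurrence from the paper, h_t = (1 - z_t) * h_{t-1} + z_t * h̃_t, with gates z_t in (0, 1) and candidates h̃_t kept positive by a map `g` (x + 0.5 for x >= 0, sigmoid(x) otherwise). The function names are hypothetical, the gate and candidate values are hard-coded instead of coming from learned projections, and the log-space version is written as a plain loop for readability rather than as a true parallel scan.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def g(v):
    # Positivity map (assumed from the paper): v + 0.5 for v >= 0, sigmoid(v) otherwise.
    return v + 0.5 if v >= 0 else sigmoid(v)

def logsumexp(values):
    # Stable log(sum(exp(v))) by factoring out the maximum.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def sequential_min_gru(gates, candidates, h0):
    # Direct recurrence: h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde_t
    h, out = h0, []
    for z_t, c_t in zip(gates, candidates):
        h = (1.0 - z_t) * h + z_t * c_t
        out.append(h)
    return out

def log_space_min_gru(gates, candidates, h0):
    # Same recurrence, computed from logs. With a_t = 1 - z_t and b_t = z_t * h_tilde_t:
    #   log h_t = log(a_1 * ... * a_t) + logsumexp(log h0, log b_k - log(a_1 * ... * a_k) for k <= t)
    # Requires h0 > 0 and every candidate > 0.
    log_a_star = 0.0            # running sum of log(1 - z_k)
    terms = [math.log(h0)]      # log h0, then log b_k - log a*_k for each step
    out = []
    for z_t, c_t in zip(gates, candidates):
        log_a_star += math.log(1.0 - z_t)
        terms.append(math.log(z_t) + math.log(c_t) - log_a_star)
        out.append(math.exp(log_a_star + logsumexp(terms)))
    return out

# Illustrative values only; in a real layer these would be produced from the input.
gates = [sigmoid(v) for v in [0.5, -1.0, 2.0, 0.0]]
candidates = [g(v) for v in [0.2, -0.7, 1.5, -2.0]]

seq = sequential_min_gru(gates, candidates, h0=0.5)
par = log_space_min_gru(gates, candidates, h0=0.5)
assert all(math.isclose(s, p, rel_tol=1e-9) for s, p in zip(seq, par))
```

Because the log-space form depends only on cumulative sums and a cumulative logsumexp, it can be evaluated over a whole sequence at once, which is what makes a parallel training path possible.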

Yannic's paper review

Install

$ pip install minGRU-pytorch

Usage

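import torch
from minGRU_pytorch import minGRU

# minGRU layer with model dimension 512
min_gru = minGRU(512)

# input of shape (batch, sequence length, dimension)
x = torch.randn(2, 1024, 512)

out = min_gru(x)

# the layer preserves the input shape
assert x.shape == out.shape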

Sanity check

import torch
from minGRU_pytorch import minGRU

min_gru = minGRU(dim = 512, expansion_factor = 1.5)

x = torch.randn(1, 2048, 512)

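# parallel: run the full sequence at once and keep only the last time step

parallel_out = min_gru(x)[:, -1:]

# sequential: feed one token at a time, carrying the hidden state forward

prev_hidden = None
for token in x.unbind(dim = 1):
    sequential_out, prev_hidden = min_gru(token[:, None, :], prev_hidden, return_next_prev_hidden = True)

# both paths should agree on the final output up to numerical tolerance

assert torch.allclose(parallel_out, sequential_out, atol = 1e-4)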
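The final assertion tolerates small floating-point differences between the two paths. As a rough illustration of the rule `torch.allclose` applies, |a - b| <= atol + rtol * |b| for every element, here is a pure-Python sketch on plain lists. The helper names `is_close` and `all_close` are hypothetical, and the default tolerances are assumed to match PyTorch's documented defaults.

```python
def is_close(a, b, rtol=1e-5, atol=1e-8):
    # Elementwise rule: |a - b| <= atol + rtol * |b|
    return abs(a - b) <= atol + rtol * abs(b)

def all_close(xs, ys, rtol=1e-5, atol=1e-8):
    # True only if the lists have equal length and every pair passes the rule.
    return len(xs) == len(ys) and all(
        is_close(a, b, rtol, atol) for a, b in zip(xs, ys)
    )

# A difference of 5e-5 passes with atol = 1e-4; a difference of 1e-3 does not.
small_gap = all_close([1.0, 2.0], [1.00005, 2.0], atol=1e-4)
large_gap = all_close([1.0, 2.0], [1.001, 2.0], atol=1e-4)
```

With `atol = 1e-4`, as in the sanity check, the absolute term dominates for outputs of ordinary magnitude, so the check is effectively "equal to about four decimal places".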

Test

enwik8

$ python train.py

Citations

@inproceedings{Feng2024WereRA,
    title   = {Were RNNs All We Needed?},
    author  = {Leo Feng and Frederick Tung and Mohamed Osama Ahmed and Yoshua Bengio and Hossein Hajimirsadeghi},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:273025630}
}
