
minGRU

Implementation of minGRU, proposed in "Were RNNs All We Needed?" (Feng et al., 2024), in PyTorch. Only the numerically stable log-space version is implemented.
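
In the log-space version, the gate terms log z_t and log(1 - z_t) are computed with softplus instead of taking the log of a sigmoid, which avoids log(0) once the sigmoid saturates. The sketch below assumes the paper's formulation as we understand it: z_t = sigmoid(k_t), and a strictly positive candidate activation g(x) = x + 0.5 for x >= 0, sigmoid(x) otherwise. The helper names (log_one_minus_sigmoid, log_sigmoid, log_g) are hypothetical and are not part of the minGRU_pytorch API.

```python
import torch
import torch.nn.functional as F


def log_one_minus_sigmoid(k: torch.Tensor) -> torch.Tensor:
    # log(1 - sigmoid(k)) = log(sigmoid(-k)) = -softplus(k); no "1 - x" near 1
    return -F.softplus(k)


def log_sigmoid(k: torch.Tensor) -> torch.Tensor:
    # log(sigmoid(k)) = -softplus(-k)
    return -F.softplus(-k)


def log_g(x: torch.Tensor) -> torch.Tensor:
    # Assumed candidate activation: g(x) = x + 0.5 if x >= 0 else sigmoid(x).
    # g(x) > 0 everywhere, so log g(x) is finite. torch.where evaluates both
    # branches, so relu keeps the log argument >= 0.5 for negative x as well.
    return torch.where(x >= 0, torch.log(F.relu(x) + 0.5), -F.softplus(-x))


k = torch.tensor([40.0])
naive = torch.log(1 - torch.sigmoid(k))  # sigmoid(40) rounds to 1.0 in float32, so log(0)
stable = log_one_minus_sigmoid(k)        # stays finite: -40
print(naive.item(), stable.item())       # -inf -40.0
```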

Install

$ pip install minGRU-pytorch

Usage

import torch
from minGRU_pytorch import minGRU

min_gru = minGRU(512)

x = torch.randn(2, 1024, 512)

out = min_gru(x)

assert x.shape == out.shape
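
For intuition, here is a minimal sketch of the recurrence a minGRU layer computes, evaluated one step at a time in ordinary (not log) space. It assumes the paper's formulation, in which the update gate z_t and the candidate state h~_t depend only on the current input x_t. TinyMinGRU and its attributes are hypothetical and are not the classes inside minGRU_pytorch.

```python
from typing import Optional

import torch
import torch.nn as nn


class TinyMinGRU(nn.Module):
    """Hypothetical minimal minGRU layer, evaluated sequentially (illustration only)."""

    def __init__(self, dim: int):
        super().__init__()
        self.to_gate = nn.Linear(dim, dim)       # k_t, with z_t = sigmoid(k_t)
        self.to_candidate = nn.Linear(dim, dim)  # pre-activation of h~_t

    @staticmethod
    def g(x: torch.Tensor) -> torch.Tensor:
        # Assumed positive activation: x + 0.5 if x >= 0, sigmoid(x) otherwise
        return torch.where(x >= 0, x + 0.5, torch.sigmoid(x))

    def forward(self, x: torch.Tensor, h: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x: (batch, seq_len, dim); h: optional previous hidden state (batch, dim)
        batch, seq_len, dim = x.shape
        if h is None:
            h = x.new_zeros(batch, dim)
        z = torch.sigmoid(self.to_gate(x))      # gates for all steps at once
        h_tilde = self.g(self.to_candidate(x))  # candidates for all steps at once
        outs = []
        for t in range(seq_len):
            # h_t = (1 - z_t) * h_{t-1} + z_t * h~_t
            h = (1 - z[:, t]) * h + z[:, t] * h_tilde[:, t]
            outs.append(h)
        return torch.stack(outs, dim=1)


layer = TinyMinGRU(64)
x = torch.randn(2, 16, 64)
out = layer(x)
assert out.shape == x.shape
```

Because z_t and h~_t do not depend on h_{t-1}, they can be computed for the whole sequence in one pass; only the linear recurrence in the loop remains, and that is the part a parallel scan replaces.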

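Sanity check

Running the layer over the whole sequence in parallel should give the same final output as feeding it one token at a time while carrying prev_hidden forward: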

import torch
from minGRU_pytorch import minGRU

min_gru = minGRU(dim = 512, expansion_factor = 1.5)

x = torch.randn(1, 2048, 512)

# parallel

parallel_out = min_gru(x)[:, -1:]

# sequential

prev_hidden = None
for token in x.unbind(dim = 1):
    sequential_out, prev_hidden = min_gru(token[:, None, :], prev_hidden, return_next_prev_hidden = True)

assert torch.allclose(parallel_out, sequential_out, atol = 1e-4)
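
The check above works because the parallel and sequential paths compute the same linear recurrence h_t = a_t * h_{t-1} + b_t, with a_t = 1 - z_t and b_t = z_t * h~_t. The plain-PyTorch sketch below assumes a log-space cumulative-sum (logcumsumexp) scan; we have not confirmed that this matches the package's internals. A previous hidden state is prepended as an extra term with coefficient 1 (log 1 = 0), which is what lets one chunk resume from another. Gate and candidate pre-activations are passed in directly, and log_g, min_gru_parallel, and min_gru_sequential are hypothetical names.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def log_g(x: torch.Tensor) -> torch.Tensor:
    # log of the assumed positive activation: x + 0.5 if x >= 0 else sigmoid(x)
    return torch.where(x >= 0, torch.log(F.relu(x) + 0.5), -F.softplus(-x))


def min_gru_parallel(k, pre_h, h0: Optional[torch.Tensor] = None):
    # k, pre_h: (batch, seq_len, dim); h0: optional positive (batch, 1, dim)
    seq_len = k.shape[1]
    log_a = -F.softplus(k)                  # log(1 - z_t)
    log_b = -F.softplus(-k) + log_g(pre_h)  # log(z_t * h~_t)
    if h0 is not None:
        log_b = torch.cat((h0.log(), log_b), dim=1)
        log_a = F.pad(log_a, (0, 0, 1, 0))  # coefficient 1 for the h0 slot
    a_star = log_a.cumsum(dim=1)            # log of prod_{j <= t} a_j
    log_h = a_star + (log_b - a_star).logcumsumexp(dim=1)
    return log_h.exp()[:, -seq_len:]        # drop the h0 slot if present


def min_gru_sequential(k, pre_h, h0: Optional[torch.Tensor] = None):
    # Reference loop in ordinary space
    z = torch.sigmoid(k)
    h_tilde = torch.where(pre_h >= 0, pre_h + 0.5, torch.sigmoid(pre_h))
    h = torch.zeros_like(k[:, 0]) if h0 is None else h0[:, 0]
    outs = []
    for t in range(k.shape[1]):
        h = (1 - z[:, t]) * h + z[:, t] * h_tilde[:, t]
        outs.append(h)
    return torch.stack(outs, dim=1)


torch.manual_seed(0)
k, pre_h = torch.randn(2, 64, 8), torch.randn(2, 64, 8)
full = min_gru_parallel(k, pre_h)
seq = min_gru_sequential(k, pre_h)

# Process two chunks, carrying the last hidden state of the first into the second
first = min_gru_parallel(k[:, :32], pre_h[:, :32])
second = min_gru_parallel(k[:, 32:], pre_h[:, 32:], h0=first[:, -1:])
chunked = torch.cat((first, second), dim=1)
print(torch.allclose(full, seq, atol=1e-4), torch.allclose(full, chunked, atol=1e-4))
```

Working in log space keeps every term positive and lets logcumsumexp absorb the very small products of (1 - z_t) that a direct cumulative product would underflow on for long sequences.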

Test

To train on enwik8, run:

$ python train.py

Citations

@inproceedings{Feng2024WereRA,
    title   = {Were RNNs All We Needed?},
    author  = {Leo Feng and Frederick Tung and Mohamed Osama Ahmed and Yoshua Bengio and Hossein Hajimirsadegh},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:273025630}
}
