minGRU

PyTorch implementation of minGRU, proposed in "Were RNNs All We Needed?" (Feng et al., 2024). Only the numerically stable log-space version is implemented.
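Why log space helps: the update gate is z = sigmoid(k) for a pre-activation k, and the recurrence needs log(1 - z). Computing 1 - sigmoid(k) directly rounds to 0 once k is large. The identity log(1 - sigmoid(k)) = -softplus(k) stays finite instead. The plain-Python sketch below illustrates this. The helper names are hypothetical and are not the package's API. `g` and `log_g` assume the continuous activation described in the paper: x + 0.5 for x >= 0, and sigmoid(x) otherwise.

```python
import math

def softplus(x: float) -> float:
    # log(1 + exp(x)) without overflow: exp() only ever sees a non-positive argument
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def sigmoid(x: float) -> float:
    # Branch on the sign so exp() never overflows
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def g(x: float) -> float:
    # Assumed continuous activation: keeps the candidate hidden state strictly positive
    return x + 0.5 if x >= 0 else sigmoid(x)

def log_g(x: float) -> float:
    # log(g(x)); the negative branch is log(sigmoid(x)) = -softplus(-x)
    return math.log(x + 0.5) if x >= 0 else -softplus(-x)

def log_z(k: float) -> float:
    # log(sigmoid(k)) = -softplus(-k)
    return -softplus(-k)

def log_one_minus_z(k: float) -> float:
    # log(1 - sigmoid(k)) = log(sigmoid(-k)) = -softplus(k)
    return -softplus(k)

# A strongly saturated gate breaks the naive computation:
k = 1000.0
naive = 1.0 - sigmoid(k)     # rounds to exactly 0.0, so math.log(naive) would raise ValueError
stable = log_one_minus_z(k)  # -1000.0, still a usable log-coefficient
```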

Yannic's paper review

Appreciation

Install

$ pip install minGRU-pytorch

Usage

import torch
from minGRU_pytorch import minGRU

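min_gru = minGRU(512)              # model dimension of 512

x = torch.randn(2, 1024, 512)      # (batch, sequence length, dimension)

out = min_gru(x)                   # the whole sequence is processed in one call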

assert x.shape == out.shape
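For each token, the minGRU update works element-wise across features. Both the gate z_t and the candidate state h̃_t depend only on x_t, and the update is h_t = (1 - z_t) * h_{t-1} + z_t * h̃_t. The plain-Python sketch below shows one sequential step. It assumes the hidden and gate pre-activations were already produced by a linear projection of the input. `min_gru_step`, `hidden_pre` and `gate_pre` are hypothetical names, not the package's API. A missing previous hidden state is treated as zero, so h_1 = z_1 * h̃_1.

```python
import math
from typing import List, Optional

def sigmoid(x: float) -> float:
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def g(x: float) -> float:
    # Assumed continuous activation from the paper: x + 0.5 for x >= 0, sigmoid(x) otherwise
    return x + 0.5 if x >= 0 else sigmoid(x)

def min_gru_step(hidden_pre: List[float], gate_pre: List[float],
                 prev_hidden: Optional[List[float]] = None) -> List[float]:
    """One minGRU update for a single token, applied independently to each feature.

    hidden_pre / gate_pre stand in for the two halves of a linear projection of x_t.
    """
    out = []
    for i, (h_pre, k) in enumerate(zip(hidden_pre, gate_pre)):
        z = sigmoid(k)        # update gate: a function of the current input only
        h_tilde = g(h_pre)    # candidate state: also input-only, no h_{t-1} inside
        h_prev = 0.0 if prev_hidden is None else prev_hidden[i]
        # h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde_t
        out.append((1.0 - z) * h_prev + z * h_tilde)
    return out

h1 = min_gru_step([1.0, -2.0, 0.0], [0.0, 10.0, -10.0])
h2 = min_gru_step([0.5, 0.5, 0.5], [0.0, 0.0, 0.0], prev_hidden=h1)
```

Neither z_t nor h̃_t reads h_{t-1}, so the gates for an entire sequence can be computed up front. This is what makes a parallel scan over the whole sequence possible.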

Sanity check

import torch
from minGRU_pytorch import minGRU

min_gru = minGRU(dim = 512, expansion_factor = 1.5)

x = torch.randn(1, 2048, 512)

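# parallel: process the whole sequence in one call, keep only the last time step

parallel_out = min_gru(x)[:, -1:]

# sequential: feed one token of shape (batch, 1, dim) at a time, carrying the hidden state forward

prev_hidden = None
for token in x.unbind(dim = 1):
    sequential_out, prev_hidden = min_gru(token[:, None, :], prev_hidden, return_next_prev_hidden = True)

# both modes should give the same final output, up to floating point error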
assert torch.allclose(parallel_out, sequential_out, atol = 1e-4)
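The parallel and sequential paths agree because of how the recurrence unrolls. Write it as h_t = a_t * h_{t-1} + b_t, with a_t = 1 - z_t and b_t = z_t * h̃_t. It unrolls to h_t = A_t * (h_0 + sum_{k<=t} b_k / A_k), where A_t = a_1 * ... * a_t. In log space, the product becomes a cumulative sum and the sum becomes a log-cumsum-exp, so the whole sequence can be handled with scans.

The plain-Python sketch below checks this for a single feature. Some caveats:

- The function names are hypothetical.
- The running loop stands in for the vectorized cumulative operations a tensor implementation would use, for example `torch.cumsum` and `torch.logcumsumexp`.
- `g` is assumed to be the paper's continuous activation.

```python
import math
import random
from typing import List, Optional

def softplus(x: float) -> float:
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def sigmoid(x: float) -> float:
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def g(x: float) -> float:
    return x + 0.5 if x >= 0 else sigmoid(x)

def log_g(x: float) -> float:
    return math.log(x + 0.5) if x >= 0 else -softplus(-x)

def logaddexp(a: float, b: float) -> float:
    # log(exp(a) + exp(b)) without overflow
    if a == -math.inf:
        return b
    m = max(a, b)
    return m + math.log1p(math.exp(-abs(a - b)))

def sequential(hidden_pre: List[float], gate_pre: List[float],
               h0: Optional[float] = None) -> List[float]:
    # Reference: step through h_t = (1 - z_t) * h_{t-1} + z_t * g(hidden_pre_t)
    h = 0.0 if h0 is None else h0
    outs = []
    for x, k in zip(hidden_pre, gate_pre):
        z = sigmoid(k)
        h = (1.0 - z) * h + z * g(x)
        outs.append(h)
    return outs

def parallel_log_space(hidden_pre: List[float], gate_pre: List[float],
                       h0: Optional[float] = None) -> List[float]:
    log_a = [-softplus(k) for k in gate_pre]                                  # log(1 - z_t)
    log_b = [-softplus(-k) + log_g(x) for x, k in zip(hidden_pre, gate_pre)]  # log(z_t * h_tilde_t)
    if h0 is not None:
        # Prepend h0 with coefficient 1 (log 1 = 0); h0 must be positive, as minGRU states are
        log_a = [0.0] + log_a
        log_b = [math.log(h0)] + log_b
    outs, a_star, acc = [], 0.0, -math.inf
    for la, lb in zip(log_a, log_b):
        a_star += la                          # cumulative sum: log A_t
        acc = logaddexp(acc, lb - a_star)     # log-cumsum-exp of (log b_k - log A_k)
        outs.append(math.exp(a_star + acc))   # h_t = A_t * sum_k b_k / A_k
    return outs[-len(hidden_pre):]            # drop the prepended h0 entry, if any

rng = random.Random(0)
xs = [rng.uniform(-3.0, 3.0) for _ in range(64)]
ks = [rng.uniform(-3.0, 3.0) for _ in range(64)]

seq_out = sequential(xs, ks)
par_out = parallel_log_space(xs, ks)

# Continue from a carried hidden state, processing the sequence in two chunks
first = parallel_log_space(xs[:32], ks[:32])
rest = parallel_log_space(xs[32:], ks[32:], h0=first[-1])

assert all(abs(s - p) < 1e-9 for s, p in zip(seq_out, par_out))
assert all(abs(s - p) < 1e-9 for s, p in zip(seq_out, first + rest))
```

In this sketch, carrying `h0` between chunks plays a role similar to `prev_hidden` in the sanity check. The log-space scan treats the carried state as an extra first value with coefficient 1.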

Test

To train on enwik8:

$ python train.py

Citations

@inproceedings{Feng2024WereRA,
    title   = {Were RNNs All We Needed?},
    author  = {Leo Feng and Frederick Tung and Mohamed Osama Ahmed and Yoshua Bengio and Hossein Hajimirsadegh},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:273025630}
}
@inproceedings{anonymous2024hymba,
    title   = {Hymba: A Hybrid-head Architecture for Small Language Models},
    author  = {Anonymous},
    booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
    year    = {2024},
    url     = {https://openreview.net/forum?id=A1ztozypga},
    note    = {under review}
}
