
A PyTorch implementation of the Reformer network (https://openreview.net/forum?id=rkgNKkHtvB)

Reformer

A PyTorch implementation of the Reformer network (https://openreview.net/pdf?id=rkgNKkHtvB)

Much of this code base is loosely translated from Google's JAX implementation, found here: https://github.com/google/trax/blob/master/trax/models/research/reformer.py

How to use

All of the hard work has been taken care of; all you need to do is instantiate the model!

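import torch

from reformer_lm.reformer_lm import ReformerLM

# Dummy input: a random float tensor of shape (4, 4, 64).
test = torch.rand((4, 4, 64))

# The model dimensions are read off the input tensor's shape.
model = ReformerLM(
    vocab_size=300000,
    d_in=test.shape[-2],    # second-to-last input dimension (4)
    d_out=test.shape[-1],   # last input dimension (64)
    n_layers=6,
    n_heads=1,
    attn_k=test.shape[-1],
    attn_v=test.shape[-1],
)

# Run a forward pass and print the result.
output = model(test)
print(output)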
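
For readers adapting the example to their own data, it may help to see which concrete values the shape-based arguments resolve to. The following is a minimal sketch that uses only torch and the same dummy tensor as above; it does not call ReformerLM, and it makes no claim about what each argument means inside the model.

```python
import torch

# Same dummy input as in the example above.
test = torch.rand((4, 4, 64))

# Values that the shape-based constructor arguments resolve to.
d_in = test.shape[-2]             # second-to-last dimension
d_out = test.shape[-1]            # last dimension
attn_k = attn_v = test.shape[-1]  # both reuse the last dimension

print(f"d_in={d_in}, d_out={d_out}, attn_k={attn_k}, attn_v={attn_v}")
# prints: d_in=4, d_out=64, attn_k=64, attn_v=64
```

If your input has a different shape, these four arguments change with it, while vocab_size, n_layers, and n_heads stay whatever you pass explicitly.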

This model is still in testing and will therefore continue to see updates. PRs are welcome! Feel free to take advantage of the Docker container for development. I have been working in notebooks to test code against the original paper, and then refactoring that code back into the package.

Files for reformer-lm, version 1.0.1
Filename (size), file type, Python version
reformer_lm-1.0.1-py3-none-any.whl (7.6 kB), Wheel, py3
reformer_lm-1.0.1.tar.gz (5.6 kB), Source, None