A PyTorch implementation of the Reformer network (https://openreview.net/forum?id=rkgNKkHtvB)
Reformer
A PyTorch implementation of the Reformer network (https://openreview.net/pdf?id=rkgNKkHtvB).
Much of this codebase is loosely translated from Google's JAX implementation in Trax: https://github.com/google/trax/blob/master/trax/models/research/reformer.py
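The main idea of the Reformer paper is locality-sensitive hashing (LSH) attention. Vectors are hashed into buckets so that similar vectors tend to share a bucket, and attention is then computed only within buckets. The paper's angular hash is h(x) = argmax([xR; -xR]), where R is a random matrix of shape (d, n_buckets/2). The sketch below is a standalone illustration of that hash. It is not taken from this package. The function name `lsh_buckets`, its `seed` parameter, and the use of a Gaussian-initialized R are assumptions made for illustration.

```python
import torch


def lsh_buckets(x: torch.Tensor, n_buckets: int, seed: int = 0) -> torch.Tensor:
    """Hypothetical helper: assign each vector in x (shape [..., seq_len, d])
    to one of n_buckets using the angular LSH hash from the Reformer paper,
    h(x) = argmax([xR; -xR]).
    """
    if n_buckets % 2 != 0:
        raise ValueError("n_buckets must be even")
    # Fixed seed so repeated calls reuse the same random projection R
    gen = torch.Generator().manual_seed(seed)
    d = x.shape[-1]
    # Assumption: R has i.i.d. Gaussian entries, generated on the CPU
    # and then moved to x's device
    r = torch.randn(d, n_buckets // 2, generator=gen, dtype=x.dtype).to(x.device)
    rotated = x @ r  # [..., seq_len, n_buckets // 2]
    # Concatenating xR and -xR gives n_buckets scores; the largest one picks the bucket
    return torch.cat([rotated, -rotated], dim=-1).argmax(dim=-1)


q = torch.randn(2, 8, 64)
print(lsh_buckets(q, n_buckets=4).shape)  # torch.Size([2, 8])
```

Because the hash depends only on direction, scaling a vector by a positive constant leaves its bucket unchanged, and negating it moves it to the "opposite" bucket. In the paper, tokens are then sorted by bucket id and attention is computed within chunks of the sorted sequence. That step is omitted here.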
How to use
All of the hard work has been taken care of; all you need to do is instantiate the model!
```python
from reformer_lm.reformer_lm import ReformerLM
import torch

# Dummy input tensor of shape (4, 4, 64)
test = torch.rand((4, 4, 64))

model = ReformerLM(
    vocab_size=300000,
    d_in=test.shape[-2],    # 4
    d_out=test.shape[-1],   # 64
    n_layers=6,
    n_heads=1,
    attn_k=test.shape[-1],  # 64
    attn_v=test.shape[-1],  # 64
)

output = model(test)
print(output)
```
This model is still being tested and will continue to receive updates. PRs are welcome! Feel free to use the Docker container for development. I have been testing code against the original paper in notebooks and then refactoring it back into the package.
Hashes for reformer_lm-1.0.0-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 594ce4ad277ca104f78e218f1dd878e917e7a5903e025290aef5b667f4c3b26f |
| MD5 | a2de8dfee6fec2114d8c2d08beacc455 |
| BLAKE2b-256 | c7e8d54b1a9bc1b1a432edabd7ba631d75bf98fbd4ba43ec8775bba6cdbf0939 |
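To check a downloaded wheel against the SHA256 digest listed for reformer_lm-1.0.0-py3-none-any.whl, you can use Python's standard `hashlib` module. This is a minimal sketch. The helper names `sha256_of_file` and `verify_wheel` are hypothetical, and the code assumes the wheel has already been downloaded to a local path.

```python
import hashlib
from pathlib import Path

# SHA256 digest published for reformer_lm-1.0.0-py3-none-any.whl
EXPECTED_SHA256 = "594ce4ad277ca104f78e218f1dd878e917e7a5903e025290aef5b667f4c3b26f"


def sha256_of_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hypothetical helper: hash the file in chunks so it never has to fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_wheel(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """Hypothetical helper: True if the file's SHA256 matches the expected hex digest."""
    return sha256_of_file(path) == expected.lower()


# Usage (assumes the wheel sits in the current directory):
# verify_wheel(Path("reformer_lm-1.0.0-py3-none-any.whl"))
```

Alternatively, pip's hash-checking mode can do this at install time. Add a line such as `reformer_lm==1.0.0 --hash=sha256:<digest>` to a requirements file, then install with `pip install --require-hashes -r requirements.txt`.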