Unofficial codebase for the "Retentive Network: A Successor to Transformer for Large Language Models" paper [https://arxiv.org/pdf/2307.08621.pdf]
RetentiveNetwork
The official codebase for RetNet should be made available on or around August 1st, 2023, according to Microsoft here:
Getting Started
This library can be installed using pip.
pip install retentive-network
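To confirm from Python that the package is installed, a small standard-library check can look up the installed distribution's version. This is a sketch using importlib.metadata (Python 3.8+); the helper installed_version is hypothetical and not part of this package:

```python
from importlib import metadata


def installed_version(distribution="retentive-network"):
    """Return the installed version string of `distribution`, or None if it is missing.

    Hypothetical helper for illustration; it only queries package metadata.
    """
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


# Prints the installed version string, or None if the package is not installed.
print(installed_version())
```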
Example Training
The paper provides three forward passes, all of which can be used to train this model. However, forward() and forward_chunkwise() are recommended for regular sample data and for sample data with long sequences, respectively. The forward_recurrent() method can also be used for training, but the authors suggest using it for faster inference instead.
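The three passes appear to correspond to the paper's parallel, chunkwise recurrent, and recurrent representations of retention, which compute the same output in different ways. The following is a minimal single-head sketch in pure Python, assuming the paper's basic formulation and omitting the xPos rotation, scaling, group normalization, and gating that a full RetNet layer adds. The function names are hypothetical and are not this package's API:

```python
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def retention_parallel(q, k, v, gamma):
    """Parallel form: O = (Q K^T * D) V, with D[n][m] = gamma**(n - m) if n >= m else 0."""
    n = len(q)
    scores = matmul(q, transpose(k))
    masked = [[scores[i][j] * gamma ** (i - j) if i >= j else 0.0 for j in range(n)]
              for i in range(n)]
    return matmul(masked, v)


def retention_recurrent(q, k, v, gamma):
    """Recurrent form: S_n = gamma * S_{n-1} + k_n^T v_n and o_n = q_n S_n."""
    dk, dv = len(k[0]), len(v[0])
    state = [[0.0] * dv for _ in range(dk)]
    outputs = []
    for qn, kn, vn in zip(q, k, v):
        state = [[gamma * state[a][c] + kn[a] * vn[c] for c in range(dv)] for a in range(dk)]
        outputs.append([sum(qn[a] * state[a][c] for a in range(dk)) for c in range(dv)])
    return outputs


def retention_chunkwise(q, k, v, gamma, chunk_size):
    """Chunkwise form: parallel inside each chunk, a recurrent state R across chunks."""
    dk, dv = len(k[0]), len(v[0])
    state = [[0.0] * dv for _ in range(dk)]  # R: decayed summary of all earlier chunks
    outputs = []
    for start in range(0, len(q), chunk_size):
        qc = q[start:start + chunk_size]
        kc = k[start:start + chunk_size]
        vc = v[start:start + chunk_size]
        b = len(qc)  # the last chunk may be shorter than chunk_size
        inner = retention_parallel(qc, kc, vc, gamma)  # contribution from inside the chunk
        cross = matmul(qc, state)  # contribution from earlier chunks, before decay
        for j in range(b):
            # position j is j + 1 steps after the last position of the previous chunk
            outputs.append([inner[j][c] + gamma ** (j + 1) * cross[j][c] for c in range(dv)])
        # R_i = gamma**b * R_{i-1} + sum_j gamma**(b - 1 - j) * k_j^T v_j
        state = [[gamma ** b * state[a][c]
                  + sum(gamma ** (b - 1 - j) * kc[j][a] * vc[j][c] for j in range(b))
                  for c in range(dv)]
                 for a in range(dk)]
    return outputs


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


rng = random.Random(0)
seq_len, d_k, d_v, gamma, chunk_size = 5, 3, 2, 0.9, 2  # mirrors sequence_length=5, chunk_size=2


def random_matrix(rows, cols):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


q, k, v = random_matrix(seq_len, d_k), random_matrix(seq_len, d_k), random_matrix(seq_len, d_v)
parallel = retention_parallel(q, k, v, gamma)
recurrent = retention_recurrent(q, k, v, gamma)
chunkwise = retention_chunkwise(q, k, v, gamma, chunk_size)
print(max_abs_diff(parallel, recurrent), max_abs_diff(parallel, chunkwise))
```

Both differences should be at floating-point rounding level. The trade-off is in cost: the parallel form builds an N x N decay mask, the chunkwise form only builds chunk_size x chunk_size masks plus a fixed d_k x d_v state, and the recurrent form updates that state one token at a time, which is why it suits inference.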
import torch
from retentive_network.models.clm import RetentiveNetworkCLM

# Model and data hyperparameters
batch_size = 8
sequence_length = 5
hidden_size = 32
number_of_heads = 4
number_of_layers = 4
feed_forward_size = 20
chunk_size = 2
samples = 100
vocab_size = 100

# Random token IDs: `samples` batches, each of shape (batch_size, sequence_length)
sample_data = torch.randint(0, vocab_size, (samples, batch_size, sequence_length))
# Random integer targets, one per sequence in each batch
labels = torch.randint(0, sequence_length, (samples, batch_size))

model = RetentiveNetworkCLM(
    number_of_layers=number_of_layers,
    hidden_size=hidden_size,
    number_of_heads=number_of_heads,
    feed_forward_size=feed_forward_size,
    vocab_size=vocab_size,
    chunk_size=chunk_size,
    softmax=True
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

# Loss on the first batch before training, for later comparison
initial_out = model(sample_data[0])
initial_loss = criterion(initial_out, labels[0])

# One optimization step per batch; calling model(...) runs its forward() pass
for sample, label in zip(sample_data, labels):
    optimizer.zero_grad()
    out = model(sample)
    loss = criterion(out, label)
    loss.backward()
    optimizer.step()
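The example sets number_of_heads = 4. In the paper's multi-scale retention, each head uses its own decay rate, gamma = 1 - 2^(-5 - arange(0, h)), so some heads weight recent tokens heavily while others keep a longer context. Whether this package uses the same schedule is an assumption; the sketch below only evaluates the paper's formula for four heads and estimates how many steps it takes for a token's weight gamma^t to fall below 1%. The helper names are hypothetical:

```python
import math


def head_decays(number_of_heads):
    """Per-head decay rates gamma_h = 1 - 2**(-5 - h), following the paper's
    multi-scale schedule (assumed here, not read from this package)."""
    return [1 - 2 ** (-5 - h) for h in range(number_of_heads)]


def steps_until_below(gamma, threshold=0.01):
    """Smallest integer distance t with gamma**t < threshold."""
    return math.floor(math.log(threshold) / math.log(gamma)) + 1


number_of_heads = 4  # same value as in the training example
for h, gamma in enumerate(head_decays(number_of_heads)):
    print(f"head {h}: gamma={gamma}, weight below 1% after {steps_until_below(gamma)} steps")
```

Each additional head halves the gap 1 - gamma, so its effective span roughly doubles.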
File details
Details for the file retentive_network-0.1.0.tar.gz.
File metadata
- Download URL: retentive_network-0.1.0.tar.gz
- Upload date:
- Size: 11.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.5.1 CPython/3.9.17 Linux/5.15.0-1041-azure
File hashes
Algorithm | Hash digest
---|---
SHA256 | bc7a3e2d961ee3972853f07cf3ef4d92544acbb329d9d1d1b37aa9565c74d2e4
MD5 | bc84355866ba8fa6ea444c0dc5a961d0
BLAKE2b-256 | 5b049e0b623e62d7690a287666040b2b2f3d262dfd9caad2722f047ff5a4c8db
File details
Details for the file retentive_network-0.1.0-py3-none-any.whl.
File metadata
- Download URL: retentive_network-0.1.0-py3-none-any.whl
- Upload date:
- Size: 16.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.5.1 CPython/3.9.17 Linux/5.15.0-1041-azure
File hashes
Algorithm | Hash digest
---|---
SHA256 | e3877ad35293097f518c332add21bf9b73274a3ddb9dd7724ecb3be1cac63e01
MD5 | 45ddbb2ff48a74a76a7553f65ba19c2b
BLAKE2b-256 | 4c193725bacec634c7deee4fb9abe5965d6da53e2ed13dbbaed134a65a134bfd