Project description

RetentiveNetwork

Unofficial codebase for the "Retentive Network: A Successor to Transformer for Large Language Models" paper [https://arxiv.org/pdf/2307.08621.pdf]

According to Microsoft, the official codebase for RetNet should be made available around August 1st, 2023.

Getting Started

This library can be installed using pip.

pip install retentive-network

Example Training

The paper describes three forward passes, all of which can be used to train this model. forward() is recommended for typical inputs and forward_chunkwise() for inputs with long sequences. The forward_recurrent() method can also be used for training, but the authors suggest reserving it for faster inference; a sketch of both alternative passes follows the training script below.

import torch
from retentive_network.models.clm import RetentiveNetworkCLM

batch_size = 8
sequence_length = 5
hidden_size = 32
number_of_heads = 4
number_of_layers = 4
feed_forward_size = 20
chunk_size = 2
samples = 100
vocab_size = 100

# Random token ids in [0, vocab_size) for both inputs and next-token targets.
sample_data = torch.randint(0, vocab_size, (samples, batch_size, sequence_length))
labels = torch.randint(0, vocab_size, (samples, batch_size))

model = RetentiveNetworkCLM(
    number_of_layers=number_of_layers,
    hidden_size=hidden_size,
    number_of_heads=number_of_heads,
    feed_forward_size=feed_forward_size,
    vocab_size=vocab_size,
    chunk_size=chunk_size,
    softmax=True
)


optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

# Loss before any training, for comparison with the post-training loss.
initial_out = model(sample_data[0])
initial_loss = criterion(initial_out, labels[0])
print(f"Initial loss: {initial_loss.item():.4f}")

for sample, label in zip(sample_data, labels):
    optimizer.zero_grad()

    out = model(sample)           # standard parallel forward() pass
    loss = criterion(out, label)  # cross-entropy against target token ids
    loss.backward()
    optimizer.step()

print(f"Final loss: {loss.item():.4f}")
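
The other two passes named above can be exercised in the same way. The sketch below continues from the script and is illustrative only: the exact call signatures of forward_chunkwise() and forward_recurrent(), including how the recurrent state is threaded, are assumptions based on this README rather than a documented API.

# Hedged sketch of the chunkwise and recurrent passes; signatures assumed.
with torch.no_grad():
    sequence = sample_data[0]  # shape: (batch_size, sequence_length)

    # Chunkwise pass: recommended for long sequences. Assumed to take the
    # same input shape as forward(), processing chunk_size tokens at a
    # time internally.
    chunkwise_out = model.forward_chunkwise(sequence)

    # Recurrent pass: suggested by the authors for faster inference.
    # Assumed to consume one token step at a time while carrying a
    # recurrent state between steps.
    state = None
    for t in range(sequence.shape[1]):
        out, state = model.forward_recurrent(sequence[:, t], state)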

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

retentive_network-0.1.0.tar.gz (11.1 kB)

Uploaded Source

Built Distribution

retentive_network-0.1.0-py3-none-any.whl (16.7 kB)

Uploaded Python 3

File details

Details for the file retentive_network-0.1.0.tar.gz.

File metadata

  • Download URL: retentive_network-0.1.0.tar.gz
  • Size: 11.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.1 CPython/3.9.17 Linux/5.15.0-1041-azure

File hashes

Hashes for retentive_network-0.1.0.tar.gz

  • SHA256: bc7a3e2d961ee3972853f07cf3ef4d92544acbb329d9d1d1b37aa9565c74d2e4
  • MD5: bc84355866ba8fa6ea444c0dc5a961d0
  • BLAKE2b-256: 5b049e0b623e62d7690a287666040b2b2f3d262dfd9caad2722f047ff5a4c8db

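To verify a download against the hashes above, Python's standard hashlib module is enough. A minimal sketch, assuming the source distribution has been saved to the current directory:

import hashlib

# Compare the SHA256 digest of the downloaded archive against the value
# published above; a mismatch means the file is corrupted or was tampered
# with in transit.
expected = "bc7a3e2d961ee3972853f07cf3ef4d92544acbb329d9d1d1b37aa9565c74d2e4"

with open("retentive_network-0.1.0.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

assert digest == expected, "SHA256 mismatch"

The same check applies to the wheel, using its SHA256 value below.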

File details

Details for the file retentive_network-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: retentive_network-0.1.0-py3-none-any.whl
  • Size: 16.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.1 CPython/3.9.17 Linux/5.15.0-1041-azure

File hashes

Hashes for retentive_network-0.1.0-py3-none-any.whl

  • SHA256: e3877ad35293097f518c332add21bf9b73274a3ddb9dd7724ecb3be1cac63e01
  • MD5: 45ddbb2ff48a74a76a7553f65ba19c2b
  • BLAKE2b-256: 4c193725bacec634c7deee4fb9abe5965d6da53e2ed13dbbaed134a65a134bfd

