Easy-to-use Retrieval-Enhanced Transformer (Retro) implementation

Project description

RETRO

Retrieval-Enhanced Transformer (WIP)

Easy-to-use Retro implementation in PyTorch.

This code is based on labml.ai and accelerate, for lightweight inference and training on CPUs, GPUs, and TPUs.

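from retro_transformer.bert import BERTForChunkEmbeddings
from retro_transformer.tools.database import build_database, RetroIndex
from retro_transformer.tools.dataset import build_dataset
from retro_transformer.model import RetroModel, NearestNeighborEncoder
from retro_transformer.tools.train import train

# Hyperparameters
chunk_len = 16   # number of tokens per chunk
d_model = 128    # embedding (hidden) size
d_ff = 512       # feed-forward layer size
n_heads = 16     # number of attention heads
d_k = 16         # size of each attention head
n_layers = 16    # number of transformer layers
workspace = './workspace'  # directory passed to the database, dataset and training tools
text_file = 'text.txt'     # plain-text training corpus

# BERT model used to compute chunk embeddings, and the index built on top of it
bert = BERTForChunkEmbeddings('bert-base-uncased', 'cuda')
index = RetroIndex(workspace, chunk_len, bert=bert)

# Build the chunk database from the text, then the training dataset (which uses the index)
build_database(workspace, text_file, bert=bert, chunk_len=chunk_len)
num_tokens = build_dataset(workspace, text_file, chunk_len=chunk_len, index=index)

# Encoder for the retrieved chunks; ca_layers lists its cross-attention layers
nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len=chunk_len, n_layers=n_layers,
                                                  d_model=d_model, d_ff=d_ff, n_heads=n_heads,
                                                  d_k=d_k, ca_layers={3})

# Main model; ca_layers lists the layers that attend to the encoded retrieved chunks
model = RetroModel(n_vocab=num_tokens, d_model=d_model, n_layers=n_layers, chunk_len=chunk_len,
                   n_heads=n_heads, d_k=d_k, d_ff=d_ff, encoder=nearest_neighbor_encoder, ca_layers={3, 5})

train(model, workspace, text_file, chunk_len=chunk_len, d_model=d_model)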
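
To give a feel for what chunk_len controls, here is a minimal sketch of splitting a token sequence into fixed-size chunks. It is an illustration only: split_into_chunks is a hypothetical helper, not part of retro_transformer, and the token ids are stand-ins rather than real tokenizer output. How the library itself handles a short final chunk is not stated here, so this sketch simply keeps it.

```python
from typing import List, Sequence


def split_into_chunks(tokens: Sequence[int], chunk_len: int) -> List[List[int]]:
    """Split tokens into consecutive chunks of at most chunk_len items.

    Hypothetical helper for illustration; the last chunk may be shorter.
    """
    if chunk_len <= 0:
        raise ValueError("chunk_len must be positive")
    return [list(tokens[i:i + chunk_len]) for i in range(0, len(tokens), chunk_len)]


tokens = list(range(40))  # stand-in token ids, not real tokenizer output
chunks = split_into_chunks(tokens, chunk_len=16)
print([len(c) for c in chunks])  # [16, 16, 8]
```

With chunk_len = 16 as in the example above, each chunk of 16 tokens would be the unit that is embedded and looked up in the index.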

Download files

Source Distribution

retro_transformer-1.0.3.tar.gz (17.7 kB view details)

Uploaded Source

File details

Details for the file retro_transformer-1.0.3.tar.gz.

File metadata

  • Download URL: retro_transformer-1.0.3.tar.gz
  • Upload date:
  • Size: 17.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.7

File hashes

Hashes for retro_transformer-1.0.3.tar.gz
Algorithm    Hash digest
SHA256       341c21281d4d171b3e41a9f3a991230ffd8e5183e25de4e940da7a5233703a3c
MD5          b7d4c85e856c15764cdd5959cfbc2bfe
BLAKE2b-256  ab36583afbc26c86be6fa9ceba830d532518a380b72aebd048f579d4c8145a2a
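
A downloaded archive can be checked against the SHA256 digest listed above. The following is a minimal sketch using only Python's standard hashlib module; the helper names sha256_of_file and matches_expected are illustrative, and the local file path is an assumption about where the archive was saved.

```python
import hashlib

# SHA256 digest listed for retro_transformer-1.0.3.tar.gz
EXPECTED_SHA256 = "341c21281d4d171b3e41a9f3a991230ffd8e5183e25de4e940da7a5233703a3c"


def sha256_of_file(path: str, block_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, read in blocks to limit memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(block_size), b""):
            digest.update(block)
    return digest.hexdigest()


def matches_expected(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """Compare a file's SHA256 digest with the expected hex string (case-insensitive)."""
    return sha256_of_file(path) == expected.lower()


# Example usage (assumes the archive sits in the current directory):
# print(matches_expected("retro_transformer-1.0.3.tar.gz"))
```

A mismatch means the file differs from the one the digest was computed for and should not be installed.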
