Easy-to-use Retrieval-Enhanced Transformer (Retro) implementation
Retrieval-Enhanced Transformer (WIP)
Easy-to-use Retro implementation in PyTorch.
This code is based on labml.ai and accelerate, for lightweight inference and training on CPUs, GPUs, and TPUs.
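Retro splits text into fixed-length chunks. For each chunk, it retrieves the most similar chunks from a database by comparing their embeddings. The toy sketch below shows that lookup with no dependencies. It is an illustration under stated assumptions, not this package's code. `embed` is a hypothetical bag-of-words stand-in for BERT chunk embeddings, and the brute-force scan stands in for a real nearest-neighbor index. `split_chunks` and `nearest_neighbors` are also hypothetical helper names.

```python
import math
from collections import Counter


def embed(chunk):
    # Hypothetical stand-in for a BERT chunk embedding: bag-of-words counts.
    return Counter(chunk)


def cosine(a, b):
    # Cosine similarity between two sparse count vectors.
    dot = sum(count * b[token] for token, count in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def split_chunks(tokens, chunk_len):
    # Keep only full chunks; a trailing partial chunk is dropped.
    n_chunks = len(tokens) // chunk_len
    return [tokens[i * chunk_len:(i + 1) * chunk_len] for i in range(n_chunks)]


def nearest_neighbors(query, database, k):
    # Brute-force search: score every database chunk, return the top-k indices.
    q = embed(query)
    scored = [(cosine(q, embed(chunk)), i) for i, chunk in enumerate(database)]
    scored.sort(key=lambda pair: (-pair[0], pair[1]))  # best score first, ties by index
    return [i for _, i in scored[:k]]


tokens = "the cat sat on the mat the dog sat on the rug a bird flew over the house".split()
database = split_chunks(tokens, chunk_len=6)
print(nearest_neighbors("the cat sat on a mat".split(), database, k=2))  # → [0, 1]
```

A linear scan is only practical for tiny databases. A real setup would use learned embeddings and an approximate nearest-neighbor index such as faiss.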
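```python
from retro_transformer.bert import BERTForChunkEmbeddings
from retro_transformer.tools.database import build_database, RetroIndex
from retro_transformer.tools.dataset import build_dataset
from retro_transformer.model import RetroModel, NearestNeighborEncoder
from retro_transformer.tools.train import train

# Hyperparameters
chunk_len = 16   # tokens per retrieval chunk
d_model = 128    # model (embedding) dimension
d_ff = 512       # feed-forward hidden dimension
n_heads = 16     # number of attention heads
d_k = 16         # dimension of each attention head
n_layers = 16    # number of transformer layers

workspace = './workspace'  # working directory for generated artifacts
text_file = 'text.txt'     # raw training text

# BERT model used to embed chunks for retrieval; the second argument is the device
# ('cuda' requires a GPU)
bert = BERTForChunkEmbeddings('bert-base-uncased', 'cuda')

# Build the chunk database first, then open the retrieval index from the workspace
build_database(workspace, text_file, bert=bert, chunk_len=chunk_len)
index = RetroIndex(workspace, chunk_len, bert=bert)

# Build the training dataset with retrieved neighbors; the returned value is used
# below as the vocabulary size (n_vocab)
num_tokens = build_dataset(workspace, text_file, chunk_len=chunk_len, index=index)

# Encoder for the retrieved neighbors, with cross-attention in layer 3
nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len=chunk_len, n_layers=n_layers,
                                                  d_model=d_model, d_ff=d_ff, n_heads=n_heads,
                                                  d_k=d_k, ca_layers={3})

# Retro model with chunked cross-attention in layers 3 and 5
model = RetroModel(n_vocab=num_tokens, d_model=d_model, n_layers=n_layers, chunk_len=chunk_len,
                   n_heads=n_heads, d_k=d_k, d_ff=d_ff, encoder=nearest_neighbor_encoder,
                   ca_layers={3, 5})

train(model, workspace, text_file, chunk_len=chunk_len, d_model=d_model)
```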
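The dataset step pairs each training chunk with retrieved neighbors. In the Retro paper, a neighbor is the matched chunk plus its continuation (up to 2 * chunk_len tokens). That creates a leakage risk: the chunk just before the query has the query itself as its continuation. The paper avoids this by filtering out neighbors from the same document as the training sequence. The sketch below uses a simpler overlap filter instead. It assumes all chunks come from one tokenized text. The helpers `neighbor_window`, `overlaps`, and `filter_neighbors` are hypothetical and not part of the retro_transformer API, and this is not the package's `build_dataset` logic.

```python
def neighbor_window(tokens, chunk_idx, chunk_len):
    # Retro-style neighbor: the matched chunk followed by its continuation,
    # i.e. up to 2 * chunk_len tokens (shorter at the end of the text).
    start = chunk_idx * chunk_len
    return tokens[start:start + 2 * chunk_len]


def overlaps(neighbor_idx, query_idx, chunk_len):
    # Neighbor window covers [n*L, n*L + 2L); query chunk covers [q*L, q*L + L).
    n_start = neighbor_idx * chunk_len
    q_start = query_idx * chunk_len
    return n_start < q_start + chunk_len and q_start < n_start + 2 * chunk_len


def filter_neighbors(query_idx, ranked, chunk_len, k):
    # Drop neighbors whose window would reveal the query chunk; keep the top-k rest.
    return [n for n in ranked if not overlaps(n, query_idx, chunk_len)][:k]


tokens = list(range(12))  # 4 chunks of 3 tokens
print(neighbor_window(tokens, 1, chunk_len=3))             # → [3, 4, 5, 6, 7, 8]
print(filter_neighbors(2, [2, 1, 0, 3], chunk_len=3, k=2))  # → [0, 3]
```

For query chunk 2, both chunk 2 (the query itself) and chunk 1 (whose continuation is chunk 2) are removed.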
Download files
Source Distribution
retro_transformer-1.0.3.tar.gz (17.7 kB)
File details
Details for the file retro_transformer-1.0.3.tar.gz.
File metadata
- Download URL: retro_transformer-1.0.3.tar.gz
- Upload date:
- Size: 17.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.9.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `341c21281d4d171b3e41a9f3a991230ffd8e5183e25de4e940da7a5233703a3c` |
| MD5 | `b7d4c85e856c15764cdd5959cfbc2bfe` |
| BLAKE2b-256 | `ab36583afbc26c86be6fa9ceba830d532518a380b72aebd048f579d4c8145a2a` |