
XLab


Transformer Lab - experimental implementation and training of transformer models at scale, using modern techniques.

It supports:

  • Configurable model sizes, architecture modifications, datasets, tokenizer settings, and training procedure
  • Parallel and reproducible training on multiple GPUs and nodes
  • Inference using different algorithms / sampling strategies (e.g. temperature, top-p, top-k, beam search)
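To illustrate how these sampling strategies operate, here is a sketch of logit filtering for temperature, top-k, and top-p (nucleus) sampling. This is a generic illustration, not the project's actual `inference` module, whose names and details may differ:

```python
import torch

def filter_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    # illustrative logit filtering; the project's `inference` module
    # may differ in names and details
    logits = logits / temperature
    if top_k > 0:
        # keep only the k highest-scoring tokens
        kth_best = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, float('-inf'))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        cumulative = torch.cumsum(probs, dim=-1)
        # drop tokens outside the smallest set with cumulative probability > top_p
        drop = cumulative - probs > top_p
        sorted_logits = sorted_logits.masked_fill(drop, float('-inf'))
        logits = sorted_logits.gather(-1, sorted_idx.argsort(-1))
    return logits

# multinomial sampling from the filtered distribution
probs = torch.softmax(filter_logits(torch.randn(32000), top_k=50), dim=-1)
next_token = torch.multinomial(probs, 1)
```

Greedy decoding and beam search instead keep the single best token (or the k best partial sequences) at each step, rather than sampling.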

Getting started

Setup

First, clone this repo and create a virtual environment. Install dependencies by running:

pip install -r requirements.txt

Alternatively, if you only intend to import the model (e.g. for inference, or to use your own training script), you can simply install the PyPI package:

pip install xlabml

Download weights

If you want to run inference using pre-trained weights, download the .pt file(s) from this project's releases.

Train the model

You can skip training if you only want to run inference using pre-trained weights. To train the model, run the following command:

./xlab.py fit -c conf/xlab.yaml

Specifying -c conf/xlab.yaml tells the training script to use a larger dataset and model (the defaults are intended for quick experiments). The first run also downloads and pre-processes the dataset and trains the tokenizer, which takes about 2 hours. Training itself takes about 10 hours per epoch on an A100 GPU. On GPUs with less memory, you may need to decrease the context size and/or batch size in the configuration. By default, the learning curves are saved in TensorBoard format, and you can monitor them by running:

tensorboard --logdir .

Keep this running, point your browser to http://localhost:6006/, and click the Scalars tab.

Validate (optional)

./xlab.py validate -c conf/xlab.yaml --ckpt_path PATH

where PATH points to a checkpoint, either downloaded from this project's releases or saved during training. To evaluate on the test set, replace validate with test in the above command.

Inference

For basic inference using multinomial sampling, run:

./infer.py [OPTIONS] CHECKPOINT_PATH "PROMPT"

To see other inference options, run ./infer.py --help.

Exporting model weights and tokenizer vocabulary

Checkpoints created during training contain not only the model weights, but also the optimizer state and other information needed to resume training from a saved checkpoint. This makes the checkpoints 3x larger than the actual model weights. To export a "clean" checkpoint, containing only the weights and vocabulary, run:

./manage.py export-checkpoint CHECKPOINT_PATH
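Conceptually, the export keeps only what inference needs and drops the training state. The sketch below shows the idea with a plain torch checkpoint; the key names ('state_dict', 'hyper_parameters') are assumptions based on typical Lightning checkpoints, and the project's actual export may differ:

```python
import torch

def export_weights(ckpt_path, out_path):
    # Hypothetical key names: Lightning-style checkpoints keep model weights
    # under 'state_dict' alongside optimizer and other training state.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    slim = {
        'state_dict': ckpt['state_dict'],                    # model weights
        'hyper_parameters': ckpt.get('hyper_parameters', {}) # config needed to rebuild the model
    }
    torch.save(slim, out_path)  # optimizer state etc. is dropped
```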

Use in your code

import torch
from xlabml.datamodules import XLabDataModule
from xlabml.models import XLabModel
from xlabml import inference

# adjust these
checkpoint_path = 'logs/version_0/checkpoints/last.ckpt'
prompt = 'april'
device = 'cuda'
limit = 10  # maximum number of tokens to generate

# restore the tokenizer and the model from the same checkpoint
tokenizer = XLabDataModule.load_from_checkpoint(checkpoint_path, map_location=device).tokenizer
model = XLabModel.load_from_checkpoint(checkpoint_path, map_location=device).eval().requires_grad_(False)

# prepend the start-of-sequence token, sample a continuation, and decode it
inputs = torch.tensor([tokenizer[tokenizer.sos_token]] + tokenizer.encode(prompt), device=model.device)
outputs = inference.sample(
    model, inputs, limit,
    block_size=model.hparams['max_len'],
    eos_class=tokenizer[tokenizer.eos_token]
)
output = tokenizer.decode(outputs.tolist())
print(output)

Configuration

All configuration and hyperparameters are exposed in YAML files, passed to the training/validation script. Hyperparameters are saved in checkpoints and automatically restored when loading. The default settings are in conf/defaults.yaml. Additional YAML configuration can be specified with the -c PATH option. See conf/xlab.yaml for the configuration used to train the current release model. Additional options (or overrides of the above configuration) can be specified on the command line. To see the full list, run ./xlab.py --help.

Model

A transformer decoder, corresponding to a causal (unidirectional) encoder in the original architecture (Vaswani et al. 2017), "base" variant, with the following modifications:

  • Normalization is performed before each transformer sublayer ("pre-norm"), and an extra normalization layer is added after the last feedforward sublayer. This improved both performance and training stability.
  • The GELU activation function is used instead of ReLU (performance improvement).
  • Dropout is applied only in the attention and feedforward sublayers (to the product of the queries and keys, and to the hidden activations, respectively), which also improved performance.

Fixed (sinusoidal) positional encodings are used, because learned positional embeddings degraded performance in the current setup.
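The modifications above can be sketched as a PyTorch decoder block. This is an illustration of the pre-norm/GELU/dropout arrangement, not the project's exact code; dimensions and parameter names are assumptions (the "base" variant uses d_model=512, d_ff=2048, 8 heads):

```python
import torch
from torch import nn

class PreNormBlock(nn.Module):
    """Illustrative pre-norm decoder block (a sketch, not the project's exact code)."""

    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        # nn.MultiheadAttention applies dropout to the attention weights
        # (the softmaxed product of queries and keys)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),                # GELU instead of the original ReLU
            nn.Dropout(dropout),      # dropout on the hidden activations
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x, attn_mask=None):
        # normalization *before* each sublayer, residual connection after
        h = self.norm1(x)
        a, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + a
        return x + self.ff(self.norm2(x))
```

Per the first modification, one extra LayerNorm would then be applied to the output of the final block.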

Tokenizer

The tokenizer implementation comes from torchtext; it lowercases the input text, strips punctuation, and yields tokens between whitespace boundaries. The vocabulary consists of the 32K highest-frequency tokens in the training set. To avoid misinterpretation and to keep decoding reversible, input tokens that match special values (e.g. <unk> and <pad>) are escaped.
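The described behavior can be approximated with a few lines of Python. This is a rough stand-in, not the actual torchtext code (which differs in details), and it omits the special-token escaping mentioned above:

```python
import re

def basic_tokenize(text):
    # rough approximation: lowercase, strip punctuation, split on whitespace;
    # the real torchtext tokenizer differs in details, and the project
    # additionally escapes tokens that collide with special values
    return re.sub(r"[^\w\s]", " ", text.lower()).split()
```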

Dataset

wikimedia/wikipedia, 20231101.en, split into:

  • train: 90%
  • val: 5%
  • test: 2.5%
  • predict: 2.5%

Articles (texts) are chunked into sequences with 50% overlap.
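The chunking step can be sketched as follows. Function and parameter names are assumptions for illustration; the project's actual pre-processing may handle edge cases differently:

```python
def chunk_tokens(tokens, size, overlap=0.5):
    # split a token sequence into chunks of `size`, with each chunk
    # sharing `overlap` of its tokens with the previous one (illustrative)
    step = max(1, int(size * (1 - overlap)))
    return [tokens[i:i + size] for i in range(0, max(1, len(tokens) - size + 1), step)]

chunks = chunk_tokens(list(range(10)), size=4)  # 50% overlap -> step of 2
```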

Training

The model was trained on sequences of maximum length 256. To speed up training and reduce memory usage, 16-bit mixed precision is used. To mitigate stability issues, the bfloat16 data type is used, along with gradient clipping by norm 1.0. The AdamW optimizer is used, with learning rate 3e-4 and weight decay 0.1. The training ran on a single A100 GPU, with batch size 256, and was stopped after 4 epochs (440K steps after 2 days).
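The recipe above can be summarized in plain PyTorch. The project itself drives training through its own script and YAML config; this sketch only shows the stated hyperparameters (AdamW with lr 3e-4 and weight decay 0.1, bfloat16 autocast, gradient clipping by norm 1.0) with a toy model:

```python
import torch
from torch import nn

# toy stand-ins; the real model and batches come from the project
model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)

def training_step(batch, targets):
    optimizer.zero_grad()
    # bfloat16 mixed precision (use device_type='cuda' on a GPU)
    with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
        loss = nn.functional.mse_loss(model(batch), targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip by norm 1.0
    optimizer.step()
    return loss.item()
```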

Results

Version  Checkpoint  Loss (test)  Accuracy (test)
0.1      last        3.18         40.0%

Generations

Prompt: april is
april is the second album of the band dead from dead . it was released on february 15 , 2007 to predominantly negative reviews . it features several varied development and music styles . the album peaked at number only in the <unk> area , thus breaking the album ' s passage into a new york times best seller . it was also ranked as the fourth best album of 2007 by music critic roger ebert . track listing spirit of the line ( ' ticket to ride ' ) meant to awake into us quarter-finals phenomenon memory of misery activates fire

Prompt: the world
the world bowling championships were a women ' s national bowling championships organized in walnut street , new york city to open in 1980 . it is the world development and development archive event at the world bowling hall of fame , located in oak park , wisconsin . initially developed as a bowling and tennis track for workers , it was expanded into winter surface courses as a grassroots project to preserve open library needed oral materials to support the expanded bowling programs for disabled adults . medal summary results by round matches us quarter-finals u . s . championship top

Prompt: cats and dogs
cats and dogs ( , , ) is a 1986 indian malayalam-language drama film directed by m . k . raman nair and produced by the film production company <unk> development . it stars <unk> <unk> and <unk> , while <unk> in the lead roles , <unk> <unk> , and muhammed in three members of the comedy team . the film is a remake of the 1989 hindi film <unk> . it was remade in telugu as oral cough . inscription the film was released on 6 april 1986 in kerala . plot a criminal named <unk> visits rani ( <unk> <unk> )

All of these can be reproduced with the included inference script, using random seed 42 and limit 100.

Future work

  • use a BPE tokenizer, e.g. from tiktoken or sentencepiece
  • group sequences of similar lengths in the same training batches
  • add learning rate scheduling (e.g. cosine with warmup, or reduce lr on plateau)
  • rotary positional embeddings
  • RMSNorm
  • SwiGLU activation
  • increase the maximum context length, e.g. via gradient checkpointing
  • train a larger model on a larger and more diverse dataset
  • fine-tune for a downstream task
  • quantization
  • compile the model during training, and for inference
  • cache KV pairs in inference, try multi-query/grouped-query attention
