Single NT/AA resolution biological GPT-2 language modelling
gpt2-prot
Train biological language models at single NT or AA resolution.
This is a simple framework for training DNA or protein language models using an easily modifiable GPT-2 architecture. Training data, model hyperparameters, and training settings are configured through composable YAML config files and can be extended with any pytorch-lightning settings.
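For orientation, the kind of decoder-only transformer block such a model stacks looks roughly like the sketch below. The parameter names mirror config keys used later in this README (embed_d, n_heads, attn_dropout, res_dropout), but the code is an illustrative assumption, not the package's actual implementation.

# Hypothetical GPT-2-style decoder block; NOT gpt2-prot's actual code.
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, embed_d: int, n_heads: int, attn_dropout: float, res_dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_d)
        self.attn = nn.MultiheadAttention(embed_d, n_heads, dropout=attn_dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_d)
        self.mlp = nn.Sequential(
            nn.Linear(embed_d, 4 * embed_d),
            nn.GELU(),
            nn.Linear(4 * embed_d, embed_d),
            nn.Dropout(res_dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Causal mask: each position may only attend to itself and earlier tokens
        t = x.size(1)
        mask = torch.triu(torch.ones(t, t, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
        x = x + attn_out
        return x + self.mlp(self.ln2(x))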
Features
- Simple, extendable data handling and GPT-2 implementation built on pytorch-lightning
- Supports protein and DNA modelling out of the box
- The underlying torch dataset can download datasets (e.g. from UniProt or NCBI) and caches the encoded data in a memory-mapped array so that large numbers of sequences can be handled (see the sketch below)
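The caching idea behind the last bullet can be sketched as follows: encode every sequence once, persist the token IDs to a memory-mapped array on disk, and slice training windows out of it on demand. The class name and file layout here are assumptions, not the package's internals.

# Hypothetical memory-mapped caching dataset; not gpt2-prot's actual code.
import numpy as np
import torch
from torch.utils.data import Dataset

class MemmapSeqDataset(Dataset):
    def __init__(self, token_ids: np.ndarray, cache_path: str, window_size: int):
        # One-time encode: persist token IDs to disk (uint8 assumes vocab < 256)
        mm = np.memmap(cache_path, dtype=np.uint8, mode="w+", shape=token_ids.shape)
        mm[:] = token_ids
        mm.flush()
        # Reopen read-only; data stays on disk and is paged in on demand
        self.data = np.memmap(cache_path, dtype=np.uint8, mode="r", shape=token_ids.shape)
        self.window_size = window_size

    def __len__(self) -> int:
        return len(self.data) - self.window_size

    def __getitem__(self, i: int):
        chunk = torch.from_numpy(self.data[i : i + self.window_size + 1].astype(np.int64))
        return chunk[:-1], chunk[1:]  # (input tokens, next-token targets)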
Recipes
Note: these recipes have not been fully tested, and the final model and data parameters are subject to change.
Installation
pip install gpt2_prot
Usage
From the CLI
gpt2-prot -h # Show the CLI help
# Launch tensorboard to view loss, perplexity and model generations during training:
tensorboard --logdir lightning_logs/ &
# Run the demo config for cas9 protein language modelling:
# Since this uses Lightning CLI, you can override parameters from the config on the command line
gpt2-prot fit --config recipes/cas9_analogues.yml --max_epochs 10
# Generate new sequences and configure the prompt:
gpt2-prot predict --config cas9_analog_generator.yml --data.prompt MATT --data.n_samples 50
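Both modes operate at single-residue resolution: each amino acid or nucleotide is one token. Below is a minimal character-level encoder in that spirit; the exact alphabets and special tokens are assumptions chosen to line up with the vocab_size comment in the config that follows, not the package's actual vocabulary.

# Hypothetical single-residue tokeniser; the real alphabets/specials may differ.
AA = "ACDEFGHIKLMNPQRSTVWY"  # 20 standard amino acids
NT = "ACGT"                  # 4 nucleotides

def make_vocab(alphabet: str) -> dict:
    # Reserve extra IDs for special tokens (pure assumption; chosen so that
    # aa -> 24 and nt -> 5 match the vocab_size comment in the config below)
    specials = ["<pad>", "<unk>", "<bos>", "<eos>"] if len(alphabet) == 20 else ["<pad>"]
    return {tok: i for i, tok in enumerate(specials + list(alphabet))}

def encode(seq: str, vocab: dict) -> list:
    unk = vocab.get("<unk>", vocab["<pad>"])
    return [vocab.get(ch, unk) for ch in seq.upper()]

aa_vocab = make_vocab(AA)  # len == 24
nt_vocab = make_vocab(NT)  # len == 5
print(encode("MATT", aa_vocab))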
YAML config (tiny Cas9 protein language model demo)
seed_everything: 0
ckpt_path: last # Loads the most recent checkpoint in `checkpoints/`

trainer:
  max_epochs: 1000
  log_every_n_steps: 25
  fast_dev_run: false
  enable_checkpointing: true
  # Preconfigured TensorBoard logger:
  logger:
    - class_path: lightning.pytorch.loggers.TensorBoardLogger
      init_args:
        save_dir: "."
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        dirpath: "checkpoints/" # Needs to be set for ckpt_path to correctly load `last`
        save_last: true
    # Configurable monitoring of model generations during training:
    - class_path: PreviewCallback
      init_args:
        mode: "aa"
        prompt: "M"
        length: 75
    # Inference mode config:
    - class_path: FastaInferenceWriter
      init_args:
        mode: "aa"
        output_file: "predictions.fasta"
        max_tokens: 100
        t: 1.0
        sample: true
        top_k: 5

# Model and optimiser hyperparameters:
model:
  config:
    vocab_size: 24 # mode dependent: aa -> 24, nt -> 5
    window_size: 16
    n_layers: 2
    n_heads: 2
    embed_d: 128
    emb_dropout: 0.1
    attn_dropout: 0.1
    res_dropout: 0.1
    adam_lr: 0.0003
    adam_weight_decay: 0.1
    adam_betas: [0.90, 0.95]

# Lightning datamodule parameters:
data:
  mode: "aa"
  directory: "seqs/"
  batch_size: 1
  max_seq_length: 100
  n_seq_limit: 500
  loader_num_workers: 2
  # The datamodule can also handle downloading datasets:
  downloads: [
    ["https://rest.uniprot.org/uniprotkb/stream?compressed=true&format=fasta&query=%28gene%3Acas9%29", "uniprot_cas9.fasta.gz"]
  ]
  # Optionally set the inference prompt:
  prompt: "M"
  n_samples: 100
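For reference, the FastaInferenceWriter options t, top_k, sample, and max_tokens correspond to standard temperature and top-k autoregressive decoding. A generic sketch of that loop, where model is a hypothetical callable mapping a (1, seq_len) tensor of token IDs to (1, seq_len, vocab_size) logits:

# Generic temperature + top-k sampling loop; illustrates what the t, top_k,
# sample and max_tokens options control. Not gpt2-prot's actual implementation.
from typing import Optional
import torch

@torch.no_grad()
def generate(model, ids: torch.Tensor, max_tokens: int, t: float = 1.0,
             top_k: Optional[int] = None, sample: bool = True) -> torch.Tensor:
    for _ in range(max_tokens):
        logits = model(ids)[:, -1, :] / t        # last position, scaled by temperature
        if top_k is not None:
            kth = torch.topk(logits, top_k).values[:, -1, None]
            logits[logits < kth] = float("-inf")  # keep only the k most likely tokens
        probs = torch.softmax(logits, dim=-1)
        if sample:
            nxt = torch.multinomial(probs, num_samples=1)  # stochastic sampling
        else:
            nxt = probs.argmax(dim=-1, keepdim=True)       # greedy decoding
        ids = torch.cat([ids, nxt], dim=1)
    return ids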
Development
Installation from source
micromamba create -f environment.yml # or conda etc.
micromamba activate gpt2-prot
pip install . # Basic install
pip install -e ".[dev]" # Install in editable mode with dev dependencies
pip install ".[test]" # Install the package and all test dependencies
Running pre-commit hooks and tests
# Install the hooks:
pre-commit install
# Run all the hooks:
pre-commit run --all-files
# Run unit tests:
pytest