Single NT/AA resolution biological GPT-2 language modelling
Project description
gpt2-prot
Train biological language models at single NT or AA resolution.
This is a simple framework for training DNA or protein language models using an easily modifiable GPT-2 architecture. Training data, model hyperparameters and training settings can be configured with composable YAML config files and extended with any PyTorch Lightning settings.
Features
- Simple and extendable data handling and GPT-2 implementation using PyTorch Lightning
- Supports protein and DNA modelling out of the box
- The underlying torch Dataset can download datasets (e.g. from UniProt or NCBI) and caches encoded data in a memory-mapped array, so large numbers of sequences can be handled without loading everything into RAM
Recipes
Note: these recipes are not yet fully tested, and the final model and data parameters may change.
Installation
pip install gpt2_prot
Usage
From the CLI
gpt2-prot -h # Show the CLI help
# Launch tensorboard to view loss, perplexity and model generations during training:
tensorboard --logdir lightning_logs/ &
# Run the demo config for cas9 protein language modelling:
# Since this uses Lightning, you can override parameters from the config on the command line
gpt2-prot fit --config recipes/cas9_analogues.yml --max_epochs 10
# Generate new sequences and configure the prompt:
gpt2-prot predict --config cas9_analog_generator.yml --data.prompt MATT --data.n_samples 50
YAML config (tiny Cas9 protein language model demo)
seed_everything: 0
ckpt_path: last # Loads the most recent checkpoint in `checkpoints/`
trainer:
  max_epochs: 1000
  log_every_n_steps: 25
  fast_dev_run: false
  enable_checkpointing: true
  # Preconfigured TensorBoard logger
  logger:
    - class_path: lightning.pytorch.loggers.TensorBoardLogger
      init_args:
        save_dir: "."
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        dirpath: "checkpoints/" # Needs to be set for ckpt_path to correctly load `last`
        save_last: true
    # Configurable monitoring of model generations during training:
    - class_path: PreviewCallback
      init_args:
        mode: "aa"
        prompt: "M"
        length: 75
    # Inference mode config:
    - class_path: FastaInferenceWriter
      init_args:
        mode: "aa"
        output_file: "predictions.fasta"
        max_tokens: 100
        t: 1.0
        sample: true
        top_k: 5
# Model and optimiser hyperparameters:
model:
  config:
    vocab_size: 24 # mode dependent: aa -> 24, nt -> 5
    window_size: 16
    n_layers: 2
    n_heads: 2
    embed_d: 128
    emb_dropout: 0.1
    attn_dropout: 0.1
    res_dropout: 0.1
    adam_lr: 0.0003
    adam_weight_decay: 0.1
    adam_betas: [0.90, 0.95]
# Lightning datamodule parameters:
data:
  mode: "aa"
  directory: "seqs/"
  batch_size: 1
  max_seq_length: 100
  n_seq_limit: 500
  loader_num_workers: 2
  # The datamodule can also handle downloading datasets:
  downloads: [
    ["https://rest.uniprot.org/uniprotkb/stream?compressed=true&format=fasta&query=%28gene%3Acas9%29", "uniprot_cas9.fasta.gz"]
  ]
  # Optionally set the inference prompt:
  prompt: "M"
  n_samples: 100
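The inference parameters in the config (`t`, `top_k`, `sample`) control standard temperature and top-k sampling over the model's next-token logits. A minimal NumPy sketch of one decoding step, for illustration only; the package's actual implementation may differ:

```python
import numpy as np

def sample_next_token(logits, t=1.0, top_k=5, sample=True, rng=None):
    """One decoding step: temperature-scale the logits, keep only the
    top_k candidates, then sample (or take the argmax if sample=False)."""
    rng = rng or np.random.default_rng()
    scaled = logits / t
    # Mask everything outside the top_k highest logits.
    kth = np.sort(scaled)[-top_k]
    scaled = np.where(scaled >= kth, scaled, -np.inf)
    if not sample:
        return int(np.argmax(scaled))
    # Softmax over the surviving candidates, then draw one token id.
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))
```

Lower `t` and smaller `top_k` make generations more conservative; `sample: false` reduces decoding to greedy argmax.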
Development
Installation from source
micromamba create -f environment.yml # or conda etc.
micromamba activate gpt2-prot
pip install . # Basic install
pip install -e ".[dev]" # Install in editable mode with dev dependencies
pip install ".[test]" # Install the package and all test dependencies
Running pre-commit hooks
# Install the hooks:
pre-commit install
# Run all the hooks:
pre-commit run --all-files
# Run unit tests:
pytest
Project details
File details
Details for the file gpt2_prot-0.3.tar.gz.
File metadata
- Download URL: gpt2_prot-0.3.tar.gz
- Upload date:
- Size: 15.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.0 CPython/3.12.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 34754dbcf55894770af3b6b0487f762070dacd9f97f8c050f7deef9eb4c60b5b |
| MD5 | e2df618337ef52fe032b934c466b1a08 |
| BLAKE2b-256 | a079e181f9710583de025f1d8267d5eb14d5d982c9ab723ab3a610a0684654d4 |
File details
Details for the file gpt2_prot-0.3-py3-none-any.whl.
File metadata
- Download URL: gpt2_prot-0.3-py3-none-any.whl
- Upload date:
- Size: 15.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.0 CPython/3.12.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | dd1dd8d7576adc13ea82b0f61c32094c0c9430bcd5291694112d52a3594196f8 |
| MD5 | f4ae859c431d21f1298155bc44901ddf |
| BLAKE2b-256 | 08a4ed956db0b8909546f9425f0abad74844aae280352184419f44d7eb42276b |