Single NT/AA resolution biological GPT-2 language modelling

gpt2-prot

Train biological language models at single NT or AA resolution.

This is a simple framework for training DNA or protein language models using an easily modifiable GPT-2 architecture. Training data, model hyperparameters, and training settings are configured through composable YAML config files and can be extended with any PyTorch Lightning settings.
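The fit/predict subcommands and the ability to override any config key from the command line (shown under Usage below) suggest the CLI follows Lightning's standard LightningCLI pattern. A minimal sketch of that pattern, where DemoModel and DemoData are placeholders rather than gpt2-prot's actual class names:

import lightning.pytorch as pl
from lightning.pytorch.cli import LightningCLI


class DemoModel(pl.LightningModule):
    """Placeholder for the package's GPT-2 LightningModule."""


class DemoData(pl.LightningDataModule):
    """Placeholder for the package's sequence datamodule."""


def main() -> None:
    # LightningCLI provides `fit`/`predict` subcommands, YAML config
    # loading, and `--key value` overrides out of the box.
    LightningCLI(DemoModel, DemoData)


if __name__ == "__main__":
    main()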

Features

  • Simple and extendable data handling and GPT-2 implementation using PyTorch Lightning
  • Supports protein and DNA modelling out of the box
  • The underlying torch dataset can download remote datasets (e.g. from UniProt or NCBI) and cache encoded data in a memory-mapped array, so large numbers of sequences can be handled (see the sketch below)
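The memory-mapped caching idea can be sketched as follows. This is illustrative of the approach only, not gpt2-prot's actual implementation; the toy alphabet and file name are assumptions:

import numpy as np

# Illustrative sketch of memory-mapped sequence caching: sequences are
# encoded once, written to a disk-backed memmap, and training reads
# slices without loading everything into RAM.
ALPHABET = {c: i for i, c in enumerate("ACGT")}  # toy NT vocabulary


def cache_sequences(seqs: list[str], path: str) -> np.memmap:
    total = sum(len(s) for s in seqs)
    arr = np.memmap(path, dtype=np.uint8, mode="w+", shape=(total,))
    offset = 0
    for s in seqs:
        arr[offset:offset + len(s)] = [ALPHABET[c] for c in s]
        offset += len(s)
    arr.flush()  # persist the encoded tokens to disk
    return arr


cached = cache_sequences(["ACGT", "GATTACA"], "seqs.bin")
print(cached[:11])  # -> [0 1 2 3 2 0 3 3 0 1 0]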

Recipes

  1. Cas9 analogue generator
  2. Human genome foundation model
  3. Uniref50 protein foundation model

Note: these recipes still need to be fully tested, and the final model and data parameters are subject to change.

Installation

pip install gpt2_prot

Usage

From the CLI

gpt2-prot -h  # Show the CLI help

# Launch TensorBoard to view loss, perplexity and model generations during training:
tensorboard --logdir lightning_logs/ &

# Run the demo config for Cas9 protein language modelling:
# Since this uses Lightning, you can override parameters from the config on the command line
gpt2-prot fit --config recipes/cas9_analogues.yml --max_epochs 10

# Generate new sequences and configure the prompt:
gpt2-prot predict --config cas9_analog_generator.yml --data.prompt MATT --data.n_samples 50
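The predict subcommand writes generated sequences to the FASTA file configured under FastaInferenceWriter (predictions.fasta in the demo config below). A minimal sketch for loading them back, assuming the output is plain single- or multi-line FASTA:

def read_fasta(path: str) -> dict[str, str]:
    """Parse a FASTA file into {header: sequence}."""
    records: dict[str, str] = {}
    header = None
    with open(path) as fh:
        for line in fh:
            line = line.strip()
            if line.startswith(">"):
                header = line[1:]
                records[header] = ""
            elif header is not None:
                records[header] += line
    return records


generated = read_fasta("predictions.fasta")
print(f"{len(generated)} generated sequences")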

YAML config (tiny Cas9 protein language model demo)

seed_everything: 0
ckpt_path: last  # Loads the most recent checkpoint in `checkpoints/`

trainer:
  max_epochs: 1000
  log_every_n_steps: 25
  fast_dev_run: false
  enable_checkpointing: true
  
  # Preconfigured TensorBoard logger
  logger:
    - class_path: lightning.pytorch.loggers.TensorBoardLogger
      init_args:
        save_dir: "."

  callbacks: 
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        dirpath: "checkpoints/"  # Needs to be set for ckpt_path to correctly load `last`
        save_last: true
    
    # Configurable monitoring of model generations during training:
    - class_path: PreviewCallback
      init_args:
        mode: "aa"
        prompt: "M"
        length: 75
    
    # Inference mode config:
    - class_path: FastaInferenceWriter
      init_args:
        mode: "aa"
        output_file: "predictions.fasta"
        max_tokens: 100
        t: 1.0
        sample: true
        top_k: 5

# Model and optimiser hyperparameters:
model:
  config:
    vocab_size: 24  # mode dependent: aa -> 24, nt -> 5
    window_size: 16
    n_layers: 2
    n_heads: 2
    embed_d: 128
    emb_dropout: 0.1
    attn_dropout: 0.1
    res_dropout: 0.1
    adam_lr: 0.0003
    adam_weight_decay: 0.1
    adam_betas: [0.90, 0.95]

# Lightning datamodule parameters:
data:
  mode: "aa"
  directory: "seqs/"
  batch_size: 1
  max_seq_length: 100
  n_seq_limit: 500
  loader_num_workers: 2

  # The datamodule can also handle downloading datasets: 
  downloads: [
    ["https://rest.uniprot.org/uniprotkb/stream?compressed=true&format=fasta&query=%28gene%3Acas9%29", "uniprot_cas9.fasta.gz"]
  ]
  
  # Optionally set the inference prompt: 
  prompt: "M"
  n_samples: 100
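The mode-dependent vocab_size values in the model config correspond to single-character tokenization: the 20 standard amino acids plus a few extra symbols in aa mode, and the four nucleotides plus one extra symbol in nt mode. The exact token sets below are assumptions for illustration; the package's real vocabularies may include different extras or ordering:

# Illustrative single-character vocabularies; the exact extra tokens
# (ambiguity codes, padding, etc.) used by gpt2-prot are assumptions here.
AA_TOKENS = list("ACDEFGHIKLMNPQRSTVWY") + ["X", "B", "Z", "*"]  # 24 tokens
NT_TOKENS = list("ACGT") + ["N"]                                 # 5 tokens

assert len(AA_TOKENS) == 24  # matches vocab_size for mode "aa"
assert len(NT_TOKENS) == 5   # matches vocab_size for mode "nt"


def encode(seq: str, tokens: list[str]) -> list[int]:
    """Map a sequence to integer token ids at single-residue resolution."""
    lookup = {t: i for i, t in enumerate(tokens)}
    return [lookup[c] for c in seq]


print(encode("MATT", AA_TOKENS))  # the demo prompt, as token ids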

Development

Installation from source

micromamba create -f environment.yml  # or conda etc.
micromamba activate gpt2-prot

pip install .  # Basic install
pip install -e ".[dev]"  # Install in editable mode with dev dependencies
pip install ".[test]"  # Install the package and all test dependencies

Running pre-commit hooks

# Install the hooks:
pre-commit install

# Run all the hooks:
pre-commit run --all-files

# Run unit tests:
pytest

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

gpt2_prot-0.3.tar.gz (15.8 kB)

Built Distribution

gpt2_prot-0.3-py3-none-any.whl (15.9 kB)

File details

Details for the file gpt2_prot-0.3.tar.gz.

File metadata

  • Download URL: gpt2_prot-0.3.tar.gz
  • Size: 15.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.0 CPython/3.12.4

File hashes

Hashes for gpt2_prot-0.3.tar.gz

  • SHA256: 34754dbcf55894770af3b6b0487f762070dacd9f97f8c050f7deef9eb4c60b5b
  • MD5: e2df618337ef52fe032b934c466b1a08
  • BLAKE2b-256: a079e181f9710583de025f1d8267d5eb14d5d982c9ab723ab3a610a0684654d4


File details

Details for the file gpt2_prot-0.3-py3-none-any.whl.

File metadata

  • Download URL: gpt2_prot-0.3-py3-none-any.whl
  • Size: 15.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.0 CPython/3.12.4

File hashes

Hashes for gpt2_prot-0.3-py3-none-any.whl

  • SHA256: dd1dd8d7576adc13ea82b0f61c32094c0c9430bcd5291694112d52a3594196f8
  • MD5: f4ae859c431d21f1298155bc44901ddf
  • BLAKE2b-256: 08a4ed956db0b8909546f9425f0abad74844aae280352184419f44d7eb42276b

