D3: DNA Discrete Diffusion -- train and sample generative models for DNA sequences

d3-dna

Discrete diffusion of DNA sequences. d3-dna is a standalone, pip-installable library for training and sampling SEDD-style discrete diffusion models on nucleotide sequences, packaged with self-contained reproducibility examples for the K562, HepG2, DeepSTARR, and FANTOM5 promoter benchmarks. Models are config-driven: the same D3Trainer / D3Sampler / D3Evaluator API works for any dataset whose sequences fit in a fixed window, with global per-sample labels, per-position labels, or no conditioning at all (the conditioning mode is determined by the shape of y, not by a flag). Architectures are swappable: a 12-block diffusion transformer (DDiT) and a 256-channel dilated convolutional model ship out of the box, and either can be swapped via the cfg.model.architecture switch. Dataset-specific logic (oracles, masking, strand averaging, real-data layout) lives in the per-example directories, never in the core library, so adopting d3-dna on a new dataset comes down to writing a Dataset class and an oracle.

The fully-populated K562 example under examples/k562/ is the best place to start: it reproduces the published transformer and convolutional configurations end-to-end (training, sampling, evaluation) against the pretrained checkpoints on Zenodo.

Installation

pip install d3-dna

Extras: [flash] adds flash-attention for faster training on long sequences (otherwise the transformer falls back to PyTorch SDPA — identical model quality, slower at long sequence lengths); [logging] adds Weights & Biases; [all] installs both. Flash-attention compiles from source and imports torch during its build, so install it on a machine with a CUDA toolchain and disable build isolation so the existing torch install is visible:

pip install d3-dna
pip install flash-attn --no-build-isolation

Cold install on a fresh Python env takes about 60–90 s on a 1 Gbit link (dominated by the ~3 GB PyTorch + CUDA-runtime wheel download), under 10 s when the wheels are already cached. The [flash] source build adds another 5–15 min on a single GPU host.

Demo

After installation, clone this repo and run a self-contained K562 sampling demo:

git clone https://github.com/anirbansarkar-cs/d3-dna.git
cd d3-dna/examples/k562
python sample.py --random-labels --num-samples 100 --steps 20 --replicates 1

The first run downloads the pretrained K562 transformer checkpoint from Zenodo (~1.4 GB, one-time, cached in examples/k562/cache/). After that, sampling itself takes about 5 s on an NVIDIA H100 NVL. Output (in examples/k562/generated/): sample_0.npz of shape (100, 230, 4) and sample_0.fasta.

Usage

The four public classes are D3Trainer, D3Sampler, D3Evaluator, and BaseSPMSEValidationCallback, all re-exported from d3_dna. Each operates on standard PyTorch Dataset objects and an OmegaConf cfg, so adopting d3-dna in an existing pipeline is mostly a matter of writing the Dataset, picking a config, and instantiating one of these classes.

Define a Dataset

Each item is (X, y), where X is a LongTensor of token indices and y is the conditioning label.

import torch
from torch.utils.data import Dataset
import h5py

class MyDNADataset(Dataset):
    def __init__(self, h5_path, split='train'):
        with h5py.File(h5_path, 'r') as f:
            # one-hot (N, L, 4) -> argmax -> (N, L) token indices
            self.X = torch.from_numpy(f[f'onehot_{split}'][:]).argmax(dim=-1)
            self.y = torch.tensor(f[f'y_{split}'][:], dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
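
To make the token representation concrete, the following is a minimal sketch of the round trip between a DNA string, token indices, and the one-hot layout read by the Dataset above. The A, C, G, T channel order and the helper names encode, to_one_hot, and decode are assumptions for illustration; the source does not state the channel order, so check it against your own data.

```python
import numpy as np

# Assumed channel order; the source does not say which index maps to which base.
ALPHABET = "ACGT"

def encode(seq: str) -> np.ndarray:
    """Map a DNA string to integer token indices of shape (L,)."""
    lookup = {base: i for i, base in enumerate(ALPHABET)}
    return np.array([lookup[b] for b in seq.upper()], dtype=np.int64)

def to_one_hot(tokens: np.ndarray) -> np.ndarray:
    """Expand token indices (..., L) to one-hot (..., L, 4)."""
    return np.eye(len(ALPHABET), dtype=np.float32)[tokens]

def decode(one_hot: np.ndarray) -> str:
    """Collapse a one-hot (L, 4) array back to a DNA string via argmax."""
    return "".join(ALPHABET[i] for i in one_hot.argmax(axis=-1))

tokens = encode("ACGTTG")      # token indices, shape (6,)
one_hot = to_one_hot(tokens)   # one-hot, shape (6, 4)
```

The argmax in decode mirrors the argmax used in MyDNADataset, so a one-hot array stored on disk and the token tensor seen by the model stay consistent.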

The shape of y decides the conditioning mode automatically (no flag): (N, signal_dim) broadcasts a global label across positions; (N, sequence_length, signal_dim) adds a per-position label element-wise. K562 / HepG2 / DeepSTARR use the global form, FANTOM5 promoter uses the per-position form. See d3_dna/models/transformer.py:EmbeddingLayer.forward for the dispatch.
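
The shape-based dispatch can be illustrated with a small NumPy sketch. This is not the code in EmbeddingLayer.forward: the function name add_label_conditioning and the linear projection w from signal_dim to the embedding width are hypothetical, chosen only to show how a 2-D label broadcasts over positions while a 3-D label is added position by position.

```python
import numpy as np

def add_label_conditioning(x_emb: np.ndarray, y: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Add a label embedding to token embeddings.

    x_emb: (N, L, D) token embeddings.
    y:     (N, signal_dim) global label, or (N, L, signal_dim) per-position label.
    w:     (signal_dim, D) hypothetical projection of the label into embedding space.
    """
    if y.ndim == 2:
        # Global label: project to (N, D), then broadcast across all L positions.
        y_emb = (y @ w)[:, None, :]
    elif y.ndim == 3:
        # Per-position label: project to (N, L, D) and add element-wise.
        y_emb = y @ w
    else:
        raise ValueError(f"unsupported label shape {y.shape}")
    return x_emb + y_emb

x_emb = np.zeros((2, 5, 3))
w = np.ones((1, 3))
out_global = add_label_conditioning(x_emb, np.array([[1.0], [2.0]]), w)
out_positional = add_label_conditioning(x_emb, np.arange(10.0).reshape(2, 5, 1), w)
```

In the global case every position of a sequence receives the same offset; in the per-position case each position receives its own.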

Train

D3Trainer wraps a PyTorch Lightning trainer around D3LightningModule. Pass the config (a path or an OmegaConf object), then .fit(train_dataset, val_dataset).

from d3_dna import D3Trainer

trainer = D3Trainer('config.yaml')
trainer.fit(
    train_dataset=MyDNADataset('data.h5', 'train'),
    val_dataset=MyDNADataset('data.h5', 'valid'),
)

Resume from a checkpoint with resume_from='path/to/last.ckpt'; attach user callbacks with callbacks=[...]. The training config (including dataset metadata) is embedded in the saved checkpoint via save_hyperparameters, so a checkpoint is self-describing for inference.

Sample

D3Sampler loads a checkpoint plus a user-instantiated model, generates one-hot sequences via a predictor-corrector (PC) sampler, and writes NPZ + FASTA to disk.

from d3_dna import D3Sampler
from d3_dna.models import TransformerModel
from omegaconf import OmegaConf

cfg = OmegaConf.load('config.yaml')
model = TransformerModel(cfg)
sampler = D3Sampler(cfg)

sequences = sampler.generate(
    checkpoint='outputs/last.ckpt',
    model=model,
    num_samples=1000,
)
sampler.save(sequences, 'generated.fasta')

For long runs, use .generate_batched(...) and tune cfg.sampling.batch_size against your GPU memory.

Evaluate

D3Evaluator is a dataset-agnostic dispatcher over four metrics: paired oracle MSE (mse), per-feature KS on oracle predictions (ks), k-mer Jensen–Shannon distance (js), and discriminator AUROC (auroc). The caller supplies pre-loaded samples, real data, and an oracle exposing a .predict(x) method.

from d3_dna import D3Evaluator

ev = D3Evaluator(tests=['mse', 'ks', 'js', 'auroc'], device='cuda')
results = ev.evaluate(
    samples=generated_one_hot,        # (N, L, 4) ndarray or torch tensor
    real_data=real_test_one_hot,      # (N, L, 4)
    oracle=my_oracle,                 # must implement .predict
    kmer_ks=[6],
)
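
For intuition about the js metric, here is a minimal pure-Python sketch of a k-mer Jensen-Shannon distance between two sets of sequences. It is not D3Evaluator's implementation: the function names, the pooling of k-mer counts across all sequences in a set, and the base-2 logarithm are assumptions made for illustration.

```python
import math
from collections import Counter

def kmer_distribution(seqs, k):
    """Pooled k-mer frequencies over a collection of DNA strings."""
    counts = Counter()
    for s in seqs:
        for i in range(len(s) - k + 1):
            counts[s[i:i + k]] += 1
    total = sum(counts.values())
    if total == 0:
        raise ValueError("no k-mers found; are the sequences shorter than k?")
    return {kmer: c / total for kmer, c in counts.items()}

def js_distance(p, q):
    """Jensen-Shannon distance: square root of the base-2 JS divergence."""
    keys = set(p) | set(q)
    # Mixture distribution m = (p + q) / 2 over the union of observed k-mers.
    m = {key: 0.5 * (p.get(key, 0.0) + q.get(key, 0.0)) for key in keys}

    def kl(a):
        # KL(a || m); terms with zero probability in a contribute nothing.
        return sum(pa * math.log2(pa / m[key]) for key, pa in a.items() if pa > 0)

    divergence = 0.5 * kl(p) + 0.5 * kl(q)
    return math.sqrt(max(divergence, 0.0))

same = js_distance(kmer_distribution(["ACGTAC"], 2), kmer_distribution(["ACGTAC"], 2))
disjoint = js_distance(kmer_distribution(["AAAA"], 2), kmer_distribution(["CCCC"], 2))
```

With base-2 logarithms the distance lies in [0, 1]: identical k-mer spectra give 0 and spectra with no shared k-mers give 1.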

Dataset-specific oracle loading, masking, and strand averaging live in examples/<name>/, never in the core library.
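
As a sketch of the oracle contract, the toy class below exposes the .predict(x) method the evaluator expects. GCContentOracle is hypothetical and stands in for a trained regressor; the A, C, G, T channel order, the (N, 1) return shape, and the paired_oracle_mse helper (one plausible reading of the mse metric) are assumptions, not the library's code.

```python
import numpy as np

class GCContentOracle:
    """Hypothetical toy oracle: scores each sequence by its GC fraction."""

    def predict(self, x):
        x = np.asarray(x, dtype=np.float32)    # one-hot, shape (N, L, 4)
        gc = x[..., 1] + x[..., 2]             # assumed channels: A, C, G, T
        return gc.mean(axis=1, keepdims=True)  # predictions, shape (N, 1)

def paired_oracle_mse(oracle, samples, real_data):
    """MSE between oracle predictions on generated and real sequences, paired by index."""
    pred_samples = oracle.predict(samples)
    pred_real = oracle.predict(real_data)
    return float(np.mean((pred_samples - pred_real) ** 2))

all_g = np.zeros((2, 4, 4), dtype=np.float32)
all_g[..., 2] = 1.0                            # every position is G
all_a = np.zeros((2, 4, 4), dtype=np.float32)
all_a[..., 0] = 1.0                            # every position is A

mse_different = paired_oracle_mse(GCContentOracle(), all_g, all_a)
mse_identical = paired_oracle_mse(GCContentOracle(), all_g, all_g)
```

A real oracle would typically wrap a trained model and handle device placement and batching inside predict.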

Train-time SP-MSE

BaseSPMSEValidationCallback is the abstract callback for periodic SP-MSE validation against an oracle during training. Subclass it once per dataset and override get_default_sampling_steps() and get_oracle_predictions(samples). Each example directory ships a concrete subclass — e.g. K562MSECallback in examples/k562/callbacks.py.

Examples

Each example reproduces a published D3 configuration end-to-end. Data, oracle weights, and pretrained transformer + convolutional checkpoints auto-download from Zenodo on first run.

Example    Sequences         Conditioning                     Zenodo
k562       230 bp MPRA       global activity (N, 1)           19774653
hepg2      230 bp MPRA       global activity (N, 1)           19774653
deepstarr  249 bp enhancers  dual-head activity (N, 2)        19774653
promoter   1024 bp FANTOM5   per-position CAGE (N, 1024, 1)   19738941
minimal    scaffold for a new dataset

Each examples/<name>/ is a self-contained train.py / sample.py / evaluate.py flow plus a Dataset, oracle, validation callback, and config YAMLs for both architectures. Per-example READMEs carry the reference numbers and reproduction recipe.

System requirements

  • Operating system: Linux (POSIX); verified on RHEL 8 (kernel 4.18). Not tested on Windows or macOS.
  • Python: ≥3.9, verified on 3.11 and 3.12.
  • GPU: any CUDA-capable GPU works for training and sampling; benchmarks here use one NVIDIA H100 NVL (driver 580, CUDA 13). The [flash] extra additionally requires an Ampere-or-newer GPU and a CUDA toolchain (nvcc, CUDA_HOME) at install time.
  • Python dependencies (auto-installed by pip): torch≥2.0 (verified on 2.12), pytorch-lightning≥2.0 (verified on 2.6), omegaconf≥2.3, numpy≥1.23, scipy≥1.10, h5py≥3.7, tqdm≥4.64, einops≥0.6; flash-attn≥2.0 and wandb≥0.16 for the optional extras.

Sampling performance

Single-GPU ballpark from the K562 example (230 bp, 20-step Euler predictor, transformer, bf16-mixed, no flash-attn) on one H100 NVL: ~6 step/s at batch 512, ~130 seq/s end-to-end, full 39,340-sequence test set in ~5 min. Sampling is compute-bound — the per-step rate is roughly batch-size-invariant until GPU memory becomes the limit, so doubling the batch ~doubles end-to-end throughput. Wall time scales linearly with num_samples × steps, and roughly linearly with sequence length (sub-linearly when flash-attn is enabled).
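
The scaling above can be turned into a rough planning estimate. The helper below is a back-of-the-envelope sketch, not part of d3-dna: estimate_sampling_seconds is a hypothetical name, and it assumes batches run sequentially at a fixed per-step rate while ignoring checkpoint loading and file I/O.

```python
import math

def estimate_sampling_seconds(num_samples, steps, batch_size, steps_per_second):
    """Approximate sampling wall time, excluding model loading and disk I/O."""
    num_batches = math.ceil(num_samples / batch_size)  # the last batch may be partial
    return num_batches * steps / steps_per_second

# K562 test set with the figures quoted above: 20 steps, batch 512, about 6 step/s.
seconds = estimate_sampling_seconds(39_340, steps=20, batch_size=512, steps_per_second=6.0)
minutes = seconds / 60
```

For these inputs the estimate is 77 batches and roughly 257 s (about 4.3 min), which is consistent with the quoted ~5 min once per-run overhead is included.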

Citation

If you use d3-dna in your work, please cite the accompanying paper (bioRxiv preprint). A BibTeX entry is included in CITATION.bib:

@article{sarkar2024d3dna,
    title = {Designing {DNA} With Tunable Regulatory Activity Using Discrete Diffusion},
    author = {Sarkar, Anirban and Duran, Alejandra and Yu, Yiyang and Lin, Da-Wei and Kang, Yijie and Somia, Nirali and Mantilla, Pablo and Zhou, Jessica and Nagai, Masayuki and Tang, Ziqi and Hanington, Kaarina and Chang, Kenneth and Koo, Peter K.},
    journal = {bioRxiv},
    year = {2024},
    doi = {10.1101/2024.05.23.595630},
    url = {https://www.biorxiv.org/content/10.1101/2024.05.23.595630v3},
    publisher = {Cold Spring Harbor Laboratory},
    note = {Preprint, version 3}
}

Related

d3-dna extracts the core training, sampling, and evaluation components from the D3-DNA-Discrete-Diffusion research codebase, which holds the full ablations, analysis pipelines, and experiment scripts.
