D3: DNA Discrete Diffusion -- train and sample generative models for DNA sequences
D3-DNA: DNA Discrete Diffusion
d3-dna is a standalone, pip-installable library for training and sampling discrete diffusion models on DNA sequences — the core D3Trainer / D3Sampler / D3Evaluator API plus the transformer and convolutional architectures, with all dataset-specific logic factored out.
This repository ships two things:
- The `d3_dna/` package (published on PyPI as `d3-dna`): model architectures, diffusion math, training loop, sampler, and dataset-agnostic evaluation metrics.
- A small set of minimal reproducibility examples under `examples/`: one self-contained directory per benchmark dataset (K562, HepG2, DeepSTARR, FANTOM5 promoter, plus a `minimal/` scaffold for new datasets). Each directory provides a `Dataset`, an oracle, a validation callback, and `train.py` / `sample.py` / `evaluate.py` scripts. These reproduce the published configuration end-to-end against the data and pretrained checkpoints on Zenodo.
The full research codebase, ablations, and analysis pipelines live separately at D3-DNA-Discrete-Diffusion.
System requirements
Operating system. Linux (POSIX). Verified on RHEL 8 (kernel 4.18). Not tested on Windows or macOS.
Python. ≥3.9. Verified on 3.11 and 3.12.
Python dependencies (auto-installed by pip):
| Package | Required | Verified |
|---|---|---|
| torch | ≥2.0 | 2.12.0 (CUDA 13.0 build) |
| pytorch-lightning | ≥2.0 | 2.6.1 |
| omegaconf | ≥2.3 | 2.3 |
| numpy | ≥1.23 | 2.4 |
| scipy | ≥1.10 | 1.17 |
| h5py | ≥3.7 | 3.16 |
| tqdm | ≥4.64 | 4.67 |
| einops | ≥0.6 | 0.8 |
| flash-attn (extra [flash]) | ≥2.0 | 2.5.8–2.8.x |
| wandb (extra [logging]) | ≥0.16 | 0.27 |
Hardware. Any CUDA-capable GPU; benchmarks in this README use an NVIDIA H100 NVL (driver 580, CUDA 13.0). The transformer falls back to PyTorch SDPA when flash-attn is not installed, so non-Ampere GPUs are supported but slower at long sequence lengths. The [flash] extra additionally requires an Ampere-or-newer GPU and a CUDA toolchain (nvcc, CUDA_HOME) at install time. CPU-only operation works for imports and small-scale sampling but is impractical for training.
Installation
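```bash
# Core package
pip install d3-dna

# With flash attention (faster training on long sequences)
pip install "d3-dna[flash]"

# With Weights & Biases logging
pip install "d3-dna[logging]"

# Everything
pip install "d3-dna[all]"
```

The quotes around the extras keep shells such as zsh from treating the square brackets as glob patterns.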
GPU acceleration: d3-dna[flash] installs flash attention for faster, more memory-efficient training on long sequences. Without it, the package uses PyTorch's built-in scaled dot-product attention (SDPA) — same model quality, just slower for long inputs.
Flash-attention compiles from source and imports torch during its build, so install it on a machine with a CUDA toolchain (CUDA_HOME set, nvcc available) and disable build isolation so the existing torch install is visible:
```bash
pip install d3-dna
pip install flash-attn --no-build-isolation
```
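After installing, you can quickly check whether flash-attn is even importable in the current environment. The sketch below assumes only that the package imports as `flash_attn`. The helper name `attention_backend` is hypothetical. The library's own fallback logic may check more than importability, for example the GPU architecture or whether the compiled kernels load.

```python
import importlib.util


def attention_backend() -> str:
    """Hypothetical helper: report which attention path is likely available.

    Only checks whether the `flash_attn` module can be found on the path.
    It does not verify the GPU architecture or that the CUDA kernels load.
    """
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash-attn"
    return "sdpa"  # PyTorch scaled dot-product attention fallback
```

Run `print(attention_backend())` in the same environment you will train in. A result of `sdpa` means the slower fallback path applies.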
Typical install time. Cold install into a fresh Python env on a 1 Gbit link is dominated by the ~3 GB PyTorch + CUDA-runtime wheel download — about 60–90 s end-to-end on a clean conda env, under 10 s if the wheels are already cached locally. The [flash] extra additionally compiles flash-attention from source, which takes 5–15 min the first time on a single GPU host.
Demo
After installing the package, clone this repo to get the example scripts and run a self-contained K562 sampling demo:
```bash
git clone https://github.com/anirbansarkar-cs/d3-dna.git
cd d3-dna/examples/k562
python sample.py --random-labels --num-samples 100 --steps 20 --replicates 1
```
The first run downloads the pretrained K562 transformer checkpoint from Zenodo (~1.4 GB, one-time, cached in examples/k562/cache/). After that, the sampling itself takes about 5 s on an NVIDIA H100 NVL.
Expected output (in examples/k562/generated/):
- sample_0.npz: one-hot tensor of shape (100, 230, 4)
- sample_0.fasta: the same 100 sequences in FASTA format
python sample.py --help lists every override (config path, predictor, batch size, output dir, etc.).
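To work with the `.npz` output directly instead of the FASTA file, you can decode the one-hot tensor back into strings. This is a minimal NumPy sketch built on two assumptions. First, the channel order is A, C, G, T, so compare a few decoded sequences against `sample_0.fasta` before relying on it. Second, the array key inside the `.npz` is not documented here, so the usage note takes the first key. `onehot_to_strings` is a hypothetical helper, not part of d3-dna.

```python
import numpy as np

# Assumed channel order; confirm against the FASTA file the script also writes.
ALPHABET = np.array(list("ACGT"))


def onehot_to_strings(onehot: np.ndarray) -> list[str]:
    """Decode an (N, L, 4) one-hot array into N DNA strings."""
    indices = onehot.argmax(axis=-1)  # (N, L) integer tokens
    return ["".join(row) for row in ALPHABET[indices]]


# Usage, from examples/k562/ (the key name inside the .npz is an assumption):
#   data = np.load("generated/sample_0.npz")
#   seqs = onehot_to_strings(data[data.files[0]])
```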
Quickstart
1. Define your dataset
```python
import h5py
import torch
from torch.utils.data import Dataset


class MyDNADataset(Dataset):
    def __init__(self, h5_path, split='train'):
        with h5py.File(h5_path, 'r') as f:
            # One-hot sequences -> integer token indices (0..3).
            # argmax(dim=1) assumes an (N, 4, L) layout; use dim=-1 for (N, L, 4).
            self.X = torch.tensor(f[f'X_{split}'][:]).argmax(dim=1)
            # Per-sequence activity labels used as the conditioning signal.
            self.y = torch.tensor(f[f'Y_{split}'][:])

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]
```
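The `argmax(dim=1)` in the dataset only works if the one-hot arrays are stored as (N, 4, L). The demo output above, by contrast, is (N, L, 4). Check which axis holds the four nucleotide channels in your HDF5 file before training. The NumPy sketch below does this check. `to_token_indices` is a hypothetical helper, not part of d3-dna, and it deliberately rejects the ambiguous case L == 4.

```python
import numpy as np


def to_token_indices(onehot: np.ndarray) -> np.ndarray:
    """Convert a one-hot batch to (N, L) integer tokens.

    Hypothetical helper: accepts (N, 4, L) or (N, L, 4) and picks the
    size-4 axis as the alphabet. Raises if the layout is ambiguous.
    """
    if onehot.ndim != 3:
        raise ValueError(f"expected a 3-D array, got shape {onehot.shape}")
    _, a, b = onehot.shape
    if a == 4 and b != 4:
        return onehot.argmax(axis=1)  # (N, 4, L) layout
    if b == 4 and a != 4:
        return onehot.argmax(axis=2)  # (N, L, 4) layout
    raise ValueError(f"cannot tell which axis is the alphabet in shape {onehot.shape}")
```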
2. Write a config
```yaml
dataset:
  name: my_dataset
  sequence_length: 249
  num_classes: 4
  signal_dim: 2
ngpus: 1
tokens: 4
model:
  architecture: transformer
  hidden_size: 256
  cond_dim: 128
  n_blocks: 8
  n_heads: 8
  dropout: 0.1
  class_dropout_prob: 0.1
training:
  batch_size: 128
  accum: 1
  max_epochs: 300
  ema: 0.9999
# ... see examples/minimal/config.yaml for full template
```
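A quick pre-flight check on the config can catch simple mistakes before a long training run. The sketch below works on a plain dict that mirrors the YAML above. `check_config` is hypothetical, and each rule in it is an assumption rather than documented d3-dna behaviour:

- Standard multi-head attention requires `hidden_size` to divide evenly by `n_heads`.
- DNA uses four tokens.
- The effective batch is `batch_size × accum × ngpus`. Check how the library interprets `batch_size` (per GPU or global) before relying on this.

```python
def check_config(cfg: dict) -> int:
    """Hypothetical pre-flight check on a config dict mirroring the YAML above.

    Returns the assumed effective batch size (batch_size x accum x ngpus).
    """
    model, training = cfg["model"], cfg["training"]
    # Standard multi-head attention splits hidden_size evenly across heads
    # (assumed to hold for the transformer backbone).
    if model["architecture"] == "transformer" and model["hidden_size"] % model["n_heads"]:
        raise ValueError("hidden_size must be divisible by n_heads")
    if cfg["tokens"] != 4:
        raise ValueError("DNA uses a 4-letter alphabet (A, C, G, T)")
    # Assumes gradient accumulation and data parallelism multiply the per-step batch.
    return training["batch_size"] * training["accum"] * cfg["ngpus"]


cfg = {
    "ngpus": 1,
    "tokens": 4,
    "model": {"architecture": "transformer", "hidden_size": 256, "n_heads": 8},
    "training": {"batch_size": 128, "accum": 1},
}
print(check_config(cfg))  # 128
```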
3. Train
```python
from d3_dna import D3Trainer

trainer = D3Trainer('config.yaml')
trainer.fit(
    train_dataset=MyDNADataset('data.h5', 'train'),
    val_dataset=MyDNADataset('data.h5', 'valid'),
)
```
4. Sample
```python
from d3_dna import D3Sampler
from d3_dna.models import TransformerModel
from omegaconf import OmegaConf

cfg = OmegaConf.load('config.yaml')
model = TransformerModel(cfg)
sampler = D3Sampler(cfg)

sequences = sampler.generate(
    checkpoint='experiments/checkpoints/last.ckpt',
    model=model,
    num_samples=1000,
)
sampler.save(sequences, 'generated.fasta')
```
API Reference
| Class | Purpose |
|---|---|
| D3Trainer | Train a D3 model: trainer.fit(train_ds, val_ds) |
| D3Sampler | Generate sequences: sampler.generate(ckpt, model, n) |
| D3Evaluator | Evaluate with an oracle: subclass and implement load_oracle_model() |
| TransformerModel | D3-Tran architecture (config-driven) |
| ConvolutionalModel | D3-Conv architecture (config-driven) |
Sampling performance
Rough single-GPU ballpark from the K562 example (230 bp sequences, 20-step Euler predictor, transformer backbone, bf16-mixed, no flash-attn) on one H100 NVL:
| Sequences | Batch | Steps | Wall time | Throughput |
|---|---|---|---|---|
| 39,340 (full K562 test set, 1 replicate) | 512 | 20 | ~5 min | ~130 seq/s |
At these sizes the per-step rate (~6 step/s at batch 512) stays roughly constant as you change batch size, so doubling the batch roughly doubles end-to-end throughput until you run out of GPU memory. At a fixed batch size, wall time scales linearly with num_samples × steps. It also scales roughly linearly with sequence length without flash-attn, and sub-linearly with it.
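Because the per-step rate is roughly constant, you can turn it into a back-of-envelope wall-time estimate. The sketch below is only an approximation. It assumes every batch, including a partial last one, costs `steps / steps_per_second` seconds, and it ignores checkpoint loading and file I/O. `estimate_sampling_seconds` is a hypothetical helper.

```python
import math


def estimate_sampling_seconds(num_samples: int, batch_size: int,
                              steps: int, steps_per_second: float) -> float:
    """Back-of-envelope wall time for batched sampling.

    Assumes every batch (including a partial last one) costs the same
    steps / steps_per_second seconds; ignores model loading and I/O.
    """
    num_batches = math.ceil(num_samples / batch_size)
    return num_batches * steps / steps_per_second


# Full K562 test set at the rate quoted above (~6 step/s at batch 512).
t = estimate_sampling_seconds(39_340, 512, 20, 6.0)
print(f"{t:.0f} s, {39_340 / t:.0f} seq/s")  # 257 s, 153 seq/s
```

The estimate is somewhat faster than the measured ~5 min / ~130 seq/s in the table. Per-run overheads the sketch leaves out presumably account for some of the gap.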
Mixed precision
Both training and sampling default to a per-architecture autocast policy, picked by d3_dna.modules.precision.precision_for_cfg:
| Architecture | Lightning precision | Autocast dtype | GradScaler |
|---|---|---|---|
| transformer | bf16-mixed | torch.bfloat16 | not used (bf16 has fp32 range) |
| convolutional | 16-mixed | torch.float16 | installed automatically by Lightning |
The same dtype flows through get_score_fn so the loss path and the sampling path share a single autocast policy. LayerNorm opts out of autocast and runs in fp32; score.exp() in the predictor path is on PyTorch's fp32-cast list, so the score and all post-model sampler arithmetic land in fp32 regardless of architecture.
To override the default (e.g. when a checkpoint was trained under a different policy), set cfg.training.precision to '16-mixed' or 'bf16-mixed' in the config. The promoter example is a known exception: examples/promoter/config_transformer.yaml overrides the default back to 16-mixed. The public D3_Tran_Promoter.ckpt on Zenodo was trained and validated under fp16 and degrades sharply if sampled in bf16. New transformer training runs in the other examples should keep the bf16-mixed default.
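The sketch below mirrors how the default policy table and the cfg.training.precision override combine. It is a hypothetical illustration on plain dicts and is not the library's precision_for_cfg. Dtypes are written as strings so the example runs without torch.

```python
# Hypothetical mirror of the default policy table; not d3_dna's precision_for_cfg.
DEFAULT_PRECISION = {"transformer": "bf16-mixed", "convolutional": "16-mixed"}
AUTOCAST_DTYPE = {"bf16-mixed": "bfloat16", "16-mixed": "float16"}


def resolve_precision(cfg: dict) -> tuple[str, str]:
    """Return (Lightning precision, autocast dtype name) for a config dict.

    An explicit training.precision wins; otherwise fall back to the
    per-architecture default.
    """
    override = cfg.get("training", {}).get("precision")
    if override is not None:
        if override not in AUTOCAST_DTYPE:
            raise ValueError(f"unsupported precision override: {override!r}")
        precision = override
    else:
        precision = DEFAULT_PRECISION[cfg["model"]["architecture"]]
    return precision, AUTOCAST_DTYPE[precision]


# The promoter transformer config pins fp16 despite the bf16 transformer default.
promoter_cfg = {"model": {"architecture": "transformer"},
                "training": {"precision": "16-mixed"}}
print(resolve_precision(promoter_cfg))  # ('16-mixed', 'float16')
```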
Architecture
```
d3_dna/
├── models/        # TransformerModel, ConvolutionalModel, EMA
├── diffusion.py   # Noise schedules, transition graphs, losses
├── sampling.py    # PC sampler, predictors, D3Sampler
├── trainer.py     # Lightning module, D3Trainer
├── evaluator.py   # SP-MSE callback, D3Evaluator
└── io.py          # Checkpoint loading, data utilities
```
File details
Details for the file d3_dna-0.1.1.tar.gz.
File metadata
- Download URL: d3_dna-0.1.1.tar.gz
- Upload date:
- Size: 40.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | abea95133773265fef2971e93b47c9d306e1f2d49c52b2ad3cea46e8fe36f4c1 |
| MD5 | 100874ba2f3061431251e9571b6f6077 |
| BLAKE2b-256 | 4c8065e1e4bca4e487d23e7043be481d5e7d977148b0fc99c8a403af08a7e003 |
File details
Details for the file d3_dna-0.1.1-py3-none-any.whl.
File metadata
- Download URL: d3_dna-0.1.1-py3-none-any.whl
- Upload date:
- Size: 42.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d200f65a951afd3d44b506e7d0b0940784fe53df73f4233e53de7f8faafd475e |
| MD5 | f02b0ed20a9fce96a6ed451870284f98 |
| BLAKE2b-256 | 1fa95d277a7a5aad8af32db87e4fc6e677b5184e1d40b32faa1f703201717a70 |