D3: DNA Discrete Diffusion -- train and sample generative models for DNA sequences
d3-dna
Discrete diffusion of DNA sequences. d3-dna is a standalone, pip-installable library for training and sampling SEDD-style discrete diffusion models on nucleotide sequences. It is packaged with self-contained reproducibility examples for the K562, HepG2, DeepSTARR, and FANTOM5 promoter benchmarks. Models are config-driven: the same D3Trainer / D3Sampler / D3Evaluator API works for any dataset whose sequences fit in a fixed window, with global per-sample labels, per-position labels, or no conditioning at all (the conditioning mode is determined by the shape of y, not by a flag). Architectures are swappable. A 12-block diffusion transformer (DDiT) and a 256-channel dilated convolutional model ship out of the box, and either can be replaced via the cfg.model.architecture switch. Dataset-specific logic (oracles, masking, strand averaging, real-data layout) lives in the per-example directories, never in the core library, so adopting d3-dna on a new dataset mostly comes down to writing a Dataset class and an oracle.
The fully-populated K562 example under examples/k562/ is the best place to start: it reproduces the published transformer and convolutional configurations end-to-end (training, sampling, evaluation) against the pretrained checkpoints on Zenodo.
Installation
```bash
pip install d3-dna
```
Extras (installed with the usual bracket syntax, e.g. pip install "d3-dna[logging]"): [flash] adds flash-attention for faster training on long sequences (without it, the transformer falls back to PyTorch SDPA, which gives identical model quality but is slower at long sequence lengths); [logging] adds Weights & Biases; [all] installs both. Flash-attention compiles from source and imports torch during its build, so install it on a machine with a CUDA toolchain and disable build isolation so that the build can see the existing torch install:
```bash
pip install d3-dna
pip install flash-attn --no-build-isolation
```
Cold install on a fresh Python env takes about 60–90 s on a 1 Gbit link (dominated by the ~3 GB PyTorch + CUDA-runtime wheel download), under 10 s when the wheels are already cached. The [flash] source build adds another 5–15 min on a single GPU host.
Demo
After installation, clone this repo and run a self-contained K562 sampling demo:
```bash
git clone https://github.com/anirbansarkar-cs/d3-dna.git
cd d3-dna/examples/k562
python sample.py --random-labels --num-samples 100 --steps 20 --replicates 1
```
The first run downloads the pretrained K562 transformer checkpoint from Zenodo (~1.4 GB, one-time, cached in examples/k562/cache/). After that, sampling itself takes about 5 s on an NVIDIA H100 NVL. Output (in examples/k562/generated/): sample_0.npz, holding one-hot sequences of shape (100, 230, 4), and sample_0.fasta.
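To check the demo output without guessing how the arrays inside the NPZ are named, a small numpy sketch can list every stored array with its shape and confirm that each position is one-hot. The helpers `describe_npz` and `is_one_hot` are hypothetical, and the key names they report are whatever the sampler wrote; this README does not document them.

```python
import numpy as np


def describe_npz(path):
    """Return {array_name: shape} for every array stored in an .npz file.

    Hypothetical helper: it makes no assumption about the key names
    d3-dna uses inside sample_0.npz.
    """
    with np.load(path) as npz:
        return {name: npz[name].shape for name in npz.files}


def is_one_hot(x):
    """True if every entry is 0 or 1 and each position has exactly one 1."""
    x = np.asarray(x)
    return bool(np.all((x == 0) | (x == 1)) and np.all(x.sum(axis=-1) == 1))
```

Run from examples/k562/, describe_npz("generated/sample_0.npz") should report one array of shape (100, 230, 4) if the demo completed as described.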
Usage
The four public classes are D3Trainer, D3Sampler, D3Evaluator, and BaseSPMSEValidationCallback, all re-exported from d3_dna. Each operates on standard PyTorch Dataset objects and an OmegaConf cfg, so adopting d3-dna in an existing pipeline is mostly a matter of writing the Dataset, picking a config, and instantiating one of these classes.
Define a Dataset
Each item is (X, y), where X is a LongTensor of token indices and y is the conditioning label.
```python
import torch
from torch.utils.data import Dataset
import h5py

class MyDNADataset(Dataset):
    def __init__(self, h5_path, split='train'):
        with h5py.File(h5_path, 'r') as f:
            # one-hot (N, L, 4) -> argmax -> (N, L) token indices
            self.X = torch.from_numpy(f[f'onehot_{split}'][:]).argmax(dim=-1)
            self.y = torch.tensor(f[f'y_{split}'][:], dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
```
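If your data arrive as plain strings (for example, from FASTA) rather than one-hot arrays, the same (N, L) matrix of token indices can be built directly. This sketch assumes the token order A=0, C=1, G=2, T=3, which matches an argmax over one-hot channels ordered A, C, G, T; the README does not state d3-dna's channel order, so check it against your own one-hot layout. `encode_sequences` is a hypothetical helper.

```python
import numpy as np

# Assumed channel/token order; adjust to match your one-hot layout.
ALPHABET = "ACGT"
_LOOKUP = np.full(256, -1, dtype=np.int64)
for i, base in enumerate(ALPHABET):
    _LOOKUP[ord(base)] = i
    _LOOKUP[ord(base.lower())] = i


def encode_sequences(seqs):
    """Map equal-length DNA strings to an (N, L) int64 array of token indices.

    torch.from_numpy() on the result gives the LongTensor X that the
    Dataset above returns.
    """
    lengths = {len(s) for s in seqs}
    if len(lengths) != 1:
        raise ValueError(f"sequences must share one fixed length, got {sorted(lengths)}")
    raw = np.frombuffer("".join(seqs).encode("ascii"), dtype=np.uint8)
    tokens = _LOOKUP[raw].reshape(len(seqs), -1)
    if (tokens < 0).any():
        raise ValueError("found characters outside A/C/G/T (e.g. N)")
    return tokens
```

Rejecting N and other ambiguity codes up front is a deliberate choice here: a fixed four-letter vocabulary leaves no token to map them to.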
The shape of y decides the conditioning mode automatically (no flag): (N, signal_dim) broadcasts a global label across positions; (N, sequence_length, signal_dim) adds a per-position label element-wise. K562 / HepG2 / DeepSTARR use the global form, FANTOM5 promoter uses the per-position form. See d3_dna/models/transformer.py:EmbeddingLayer.forward for the dispatch.
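The shape-based dispatch can be pictured with a small numpy sketch. A 2-D label (N, signal_dim) is projected and broadcast over all L positions, while a 3-D label (N, L, signal_dim) is projected and added position by position. This is an illustration under assumptions (a single linear projection `W` into a hidden size `d_model`, both hypothetical), not the actual code in EmbeddingLayer.forward.

```python
import numpy as np


def add_label_embedding(h, y, W):
    """Add a projected conditioning label to hidden states.

    h: (N, L, d_model) token embeddings
    y: (N, signal_dim) global label, or (N, L, signal_dim) per-position label
    W: (signal_dim, d_model) hypothetical projection matrix
    """
    if y.ndim == 2:
        # Global label: project to (N, d_model), then broadcast over L.
        return h + (y @ W)[:, None, :]
    if y.ndim == 3:
        # Per-position label: project to (N, L, d_model) and add element-wise.
        if y.shape[1] != h.shape[1]:
            raise ValueError("per-position label must match sequence length")
        return h + y @ W
    raise ValueError(f"unsupported label shape {y.shape}")
```

Under this picture, K562-style labels of shape (N, 1) take the first branch and FANTOM5-style labels of shape (N, 1024, 1) take the second.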
Train
D3Trainer wraps a PyTorch Lightning trainer around D3LightningModule. Pass the config (a path or an OmegaConf object), then .fit(train_dataset, val_dataset).
```python
from d3_dna import D3Trainer

trainer = D3Trainer('config.yaml')
trainer.fit(
    train_dataset=MyDNADataset('data.h5', 'train'),
    val_dataset=MyDNADataset('data.h5', 'valid'),
)
```
Pass resume_from='path/to/last.ckpt' to resume from a checkpoint, and callbacks=[...] to attach user callbacks. The training config (including dataset metadata) is embedded in the saved checkpoint via save_hyperparameters, so a checkpoint is self-describing for inference.
Sample
D3Sampler loads a checkpoint plus a user-instantiated model, generates one-hot sequences with a predictor-corrector (PC) sampler, and writes NPZ + FASTA to disk.
```python
from d3_dna import D3Sampler
from d3_dna.models import TransformerModel
from omegaconf import OmegaConf

cfg = OmegaConf.load('config.yaml')
model = TransformerModel(cfg)
sampler = D3Sampler(cfg)
sequences = sampler.generate(
    checkpoint='outputs/last.ckpt',
    model=model,
    num_samples=1000,
)
sampler.save(sequences, 'generated.fasta')
```
For long runs, use .generate_batched(...) and tune cfg.sampling.batch_size against your GPU memory.
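sampler.save already writes FASTA. When you post-process samples yourself, the conversion is roughly the following numpy sketch, which assumes one-hot channels ordered A, C, G, T (an assumption; check your dataset's layout). The helper names and the >sample_i header format are hypothetical.

```python
import numpy as np

ALPHABET = np.array(list("ACGT"))  # assumed channel order


def onehot_to_strings(onehot):
    """(N, L, 4) one-hot (or probability) array -> list of N DNA strings."""
    idx = np.asarray(onehot).argmax(axis=-1)  # (N, L) token indices
    return ["".join(row) for row in ALPHABET[idx]]


def to_fasta(seqs, prefix="sample"):
    """Format strings as FASTA records with hypothetical >prefix_i headers."""
    return "".join(f">{prefix}_{i}\n{s}\n" for i, s in enumerate(seqs))
```

Taking the argmax also makes the helper tolerant of soft (probability) outputs, at the cost of silently hiding ties.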
Evaluate
D3Evaluator is a dataset-agnostic dispatcher over four metrics: paired oracle MSE (mse), per-feature KS on oracle predictions (ks), k-mer Jensen–Shannon distance (js), and discriminator AUROC (auroc). The caller supplies pre-loaded samples, real data, and an oracle exposing a .predict(x) method.
```python
from d3_dna import D3Evaluator

ev = D3Evaluator(tests=['mse', 'ks', 'js', 'auroc'], device='cuda')
results = ev.evaluate(
    samples=generated_one_hot,      # (N, L, 4) ndarray or torch tensor
    real_data=real_test_one_hot,    # (N, L, 4)
    oracle=my_oracle,               # must implement .predict
    kmer_ks=[6],
)
```
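To make the js metric concrete, here is a minimal sketch of a k-mer Jensen-Shannon distance between two sets of sequences. It counts every overlapping k-mer in each set, normalizes the counts into frequency distributions over all 4^k k-mers, and takes the square root of the base-2 Jensen-Shannon divergence, so the result lies in [0, 1]. This is an illustrative reimplementation under those assumptions, not D3Evaluator's internal code, whose normalization and log base may differ.

```python
import itertools
import math
from collections import Counter


def kmer_distribution(seqs, k):
    """Frequency of each of the 4**k k-mers over all overlapping windows.

    Assumes uppercase A/C/G/T strings (decode one-hot arrays first).
    """
    counts = Counter()
    for s in seqs:
        for i in range(len(s) - k + 1):
            counts[s[i:i + k]] += 1
    total = sum(counts.values())
    kmers = ("".join(p) for p in itertools.product("ACGT", repeat=k))
    return [counts[km] / total for km in kmers]


def js_distance(p, q):
    """Square root of the base-2 Jensen-Shannon divergence of p and q."""
    def kl(a, b):
        # Terms with a_i == 0 contribute nothing; b_i >= a_i / 2 > 0 otherwise.
        return sum(ai * math.log2(ai / bi) for ai, bi in zip(a, b) if ai > 0)
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return math.sqrt(max(0.0, 0.5 * kl(p, m) + 0.5 * kl(q, m)))


def kmer_js(generated, real, k=6):
    """k-mer JS distance between generated and real sequence sets."""
    return js_distance(kmer_distribution(generated, k), kmer_distribution(real, k))
```

With k=6 there are 4096 possible k-mers, so small sample sets give sparse distributions. This is one reason the metric is usually computed over thousands of sequences.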
Dataset-specific oracle loading, masking, and strand averaging live in examples/<name>/, never in the core library.
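D3Evaluator only requires that the oracle expose .predict(x). While wiring up a new dataset, a trivial stand-in oracle can exercise the evaluation path before a trained oracle is available. The sketch below is hypothetical. It scores each one-hot sequence by GC content (assuming channel order A, C, G, T) and returns an (N, 1) array. This README does not say whether D3Evaluator expects NumPy arrays or torch tensors back from .predict, so compare against the example oracles before relying on it.

```python
import numpy as np


class GCContentOracle:
    """Toy stand-in oracle: predicts the GC fraction of each sequence.

    Only the .predict(x) method name comes from the README; everything
    else here is an illustrative assumption.
    """

    def predict(self, x):
        # Accept torch tensors as well as NumPy arrays.
        if hasattr(x, "detach"):
            x = x.detach().cpu().numpy()
        x = np.asarray(x, dtype=np.float64)   # (N, L, 4), channels A, C, G, T
        gc = x[..., 1] + x[..., 2]             # C and G channels
        return gc.mean(axis=1, keepdims=True)  # (N, 1)
```

Passing oracle=GCContentOracle() to ev.evaluate is then a quick smoke test of the plumbing. The resulting numbers say nothing about regulatory activity.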
Train-time SP-MSE
BaseSPMSEValidationCallback is the abstract callback for periodic SP-MSE validation against an oracle during training. Subclass it once per dataset and override get_default_sampling_steps() and get_oracle_predictions(samples). Each example directory ships a concrete subclass, e.g. K562MSECallback in examples/k562/callbacks.py.
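The README does not spell out the formula. A paired oracle MSE of the kind the mse metric refers to can be sketched as follows: run the oracle on generated sample i and on the real sequence it was conditioned on, then average the squared differences. The pairing by index and the averaging over all output heads are assumptions here, and `paired_oracle_mse` is a hypothetical helper. The per-example callbacks define what is actually compared.

```python
import numpy as np


def paired_oracle_mse(oracle, generated, real):
    """Mean squared difference between oracle outputs on paired sequences.

    generated[i] is assumed to be the sample conditioned on real[i]'s label;
    oracle is any object with a .predict(x) method.
    """
    if len(generated) != len(real):
        raise ValueError("generated and real must be paired one-to-one")
    pred_gen = np.asarray(oracle.predict(generated), dtype=np.float64)
    pred_real = np.asarray(oracle.predict(real), dtype=np.float64)
    return float(np.mean((pred_gen - pred_real) ** 2))
```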
Examples
Each example reproduces a published D3 configuration end-to-end. Data, oracle weights, and pretrained transformer + convolutional checkpoints auto-download from Zenodo on first run.
| Example | Sequences | Conditioning | Zenodo record |
|---|---|---|---|
| k562 | 230 bp MPRA | global activity (N, 1) | 19774653 |
| hepg2 | 230 bp MPRA | global activity (N, 1) | 19774653 |
| deepstarr | 249 bp enhancers | dual-head activity (N, 2) | 19774653 |
| promoter | 1024 bp FANTOM5 | per-position CAGE (N, 1024, 1) | 19738941 |
| minimal | scaffold for a new dataset | n/a | n/a |
Each examples/<name>/ is a self-contained train.py / sample.py / evaluate.py flow plus a Dataset, oracle, validation callback, and config YAMLs for both architectures. Per-example READMEs carry the reference numbers and reproduction recipe.
System requirements
Linux (POSIX); verified on RHEL 8 (kernel 4.18). Not tested on Windows or macOS. Python ≥3.9, verified on 3.11 and 3.12. Any CUDA-capable GPU works for training and sampling; benchmarks here use one NVIDIA H100 NVL (driver 580, CUDA 13). The [flash] extra additionally requires an Ampere-or-newer GPU and a CUDA toolchain (nvcc, CUDA_HOME) at install time. Python dependencies (auto-installed by pip): torch≥2.0 (verified on 2.12), pytorch-lightning≥2.0 (2.6), omegaconf≥2.3, numpy≥1.23, scipy≥1.10, h5py≥3.7, tqdm≥4.64, einops≥0.6; flash-attn≥2.0 and wandb≥0.16 for the optional extras.
Sampling performance
Single-GPU ballpark from the K562 example (230 bp, 20-step Euler predictor, transformer, bf16-mixed, no flash-attn) on one H100 NVL: ~6 step/s at batch 512, ~130 seq/s end-to-end, and the full 39,340-sequence test set in ~5 min. Sampling is compute-bound: the per-step rate is roughly batch-size-invariant until GPU memory becomes the limit, so doubling the batch roughly doubles end-to-end throughput. Wall time scales linearly with num_samples × steps, and roughly linearly with sequence length (sub-linearly when flash-attn is enabled).
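The figures above can be turned into a back-of-the-envelope estimate. If the step rate stays roughly constant across batch sizes, the ideal throughput is step_rate × batch / steps. At 6 step/s, batch 512, and 20 steps, that gives about 154 seq/s, a little above the ~130 seq/s measured end to end (the gap is presumably overhead outside the per-step loop). The helpers below are hypothetical and only as good as that linear-scaling assumption.

```python
def ideal_throughput(step_rate, batch_size, steps):
    """Sequences per second if every step processes one full batch."""
    return step_rate * batch_size / steps


def wall_time_seconds(num_samples, seq_per_s):
    """Wall time for num_samples at a measured end-to-end throughput."""
    return num_samples / seq_per_s


# Numbers from the K562 H100 NVL measurement above.
print(round(ideal_throughput(6, 512, 20), 1))          # 153.6 (seq/s upper bound)
print(round(wall_time_seconds(39_340, 130) / 60, 1))   # 5.0 (minutes)
```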
Related
d3-dna extracts the core training, sampling, and evaluation components from the D3-DNA-Discrete-Diffusion research codebase, which holds the full ablations, analysis pipelines, and experiment scripts.
Download files
File details
Details for the file d3_dna-0.1.2.tar.gz.
File metadata
- Download URL: d3_dna-0.1.2.tar.gz
- Size: 40.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0a80ecb2c3dcfee66f3c472fac439e6176e883d19a9eb070e4097271a0e27fa4 |
| MD5 | b44359b5776767d88cd382e3b75e7675 |
| BLAKE2b-256 | 97b9a21f4aa150f7e45a20267c438c1e4807cea14a6a337e700066c7814a4d98 |
File details
Details for the file d3_dna-0.1.2-py3-none-any.whl.
File metadata
- Download URL: d3_dna-0.1.2-py3-none-any.whl
- Size: 42.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | aa0ad324d8978f5e66e1994abe7a035b669417f99121bb659659dee85995ad74 |
| MD5 | 55f56461b422039ac51c8a13b3962357 |
| BLAKE2b-256 | 9f0c3b6aa3bf1ef45eb4cadba482dd2ef3a3e59a91325d5ee50cb31560aba4a6 |