
D3: DNA Discrete Diffusion -- train and sample generative models for DNA sequences


D3-DNA: DNA Discrete Diffusion

d3-dna is a standalone, pip-installable library for training and sampling discrete diffusion models on DNA sequences — the core D3Trainer / D3Sampler / D3Evaluator API plus the transformer and convolutional architectures, with all dataset-specific logic factored out.

This repository ships two things:

  1. The d3_dna/ package (published on PyPI as d3-dna) — model architectures, diffusion math, training loop, sampler, and dataset-agnostic evaluation metrics.
  2. A small set of minimal reproducibility examples under examples/ — one self-contained directory per benchmark dataset (K562, HepG2, DeepSTARR, FANTOM5 promoter, plus a minimal/ scaffold for new datasets), each providing a Dataset, oracle, validation callback, and train.py / sample.py / evaluate.py scripts that reproduce the published configuration end-to-end against the data and pretrained checkpoints on Zenodo.

The full research codebase, ablations, and analysis pipelines live separately at D3-DNA-Discrete-Diffusion.

System requirements

Operating system. Linux (POSIX). Verified on RHEL 8 (kernel 4.18). Not tested on Windows or macOS.

Python. ≥3.9. Verified on 3.11 and 3.12.

Python dependencies (auto-installed by pip):

Package                     Required  Verified
torch                       ≥2.0      2.12.0 (CUDA 13.0 build)
pytorch-lightning           ≥2.0      2.6.1
omegaconf                   ≥2.3      2.3
numpy                       ≥1.23     2.4
scipy                       ≥1.10     1.17
h5py                        ≥3.7      3.16
tqdm                        ≥4.64     4.67
einops                      ≥0.6      0.8
flash-attn (extra [flash])  ≥2.0      2.5.8–2.8.x
wandb (extra [logging])     ≥0.16     0.27

Hardware. Any CUDA-capable GPU; benchmarks in this README use an NVIDIA H100 NVL (driver 580, CUDA 13.0). The transformer falls back to PyTorch SDPA when flash-attn is not installed, so pre-Ampere GPUs are supported but slower at long sequence lengths. The [flash] extra additionally requires an Ampere-or-newer GPU and a CUDA toolchain (nvcc, CUDA_HOME) at install time. CPU-only operation works for imports and small-scale sampling but is impractical for training.
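
Whether a given GPU meets the Ampere-or-newer requirement can be read off its CUDA compute capability (Ampere is 8.x). The sketch below is illustrative only: both helper names are hypothetical and not part of d3-dna, and the device query assumes torch is installed.

```python
def is_ampere_or_newer(capability):
    """True if a (major, minor) CUDA compute capability is 8.0 or higher."""
    major, _minor = capability
    return major >= 8


def current_gpu_supports_flash():
    """Check the first visible GPU via torch; False when no CUDA device is present."""
    import torch  # imported lazily so the pure check above works without torch

    if not torch.cuda.is_available():
        return False
    return is_ampere_or_newer(torch.cuda.get_device_capability(0))


print(is_ampere_or_newer((9, 0)))  # H100 (Hopper) reports 9.0 -> True
print(is_ampere_or_newer((7, 5)))  # Turing reports 7.5 -> False
```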

Installation

# Core package
pip install d3-dna

# With flash attention (faster training on long sequences)
# Quote the extras so shells such as zsh do not expand the brackets
pip install "d3-dna[flash]"

# With Weights & Biases logging
pip install "d3-dna[logging]"

# Everything
pip install "d3-dna[all]"

GPU acceleration: d3-dna[flash] installs flash attention for faster, more memory-efficient training on long sequences. Without it, the package uses PyTorch's built-in scaled dot-product attention (SDPA) — same model quality, just slower for long inputs.

Flash-attention compiles from source and imports torch during its build, so install it on a machine with a CUDA toolchain (CUDA_HOME set, nvcc available) and disable build isolation so the existing torch install is visible:

pip install d3-dna
pip install flash-attn --no-build-isolation
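
After installation, a quick way to see which attention path is likely to be used is to check whether the flash_attn module can be found. This is a minimal sketch of the fallback described above, not the detection code d3-dna itself uses; the function name is hypothetical.

```python
import importlib.util


def attention_backend():
    """Return 'flash-attn' if the flash_attn module is importable, else 'sdpa'."""
    found = importlib.util.find_spec("flash_attn") is not None
    return "flash-attn" if found else "sdpa"


print(attention_backend())
```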

Typical install time. Cold install into a fresh Python env on a 1 Gbit link is dominated by the ~3 GB PyTorch + CUDA-runtime wheel download — about 60–90 s end-to-end on a clean conda env, under 10 s if the wheels are already cached locally. The [flash] extra additionally compiles flash-attention from source, which takes 5–15 min the first time on a single GPU host.

Demo

After installing the package, clone this repo to get the example scripts and run a self-contained K562 sampling demo:

git clone https://github.com/anirbansarkar-cs/d3-dna.git
cd d3-dna/examples/k562
python sample.py --random-labels --num-samples 100 --steps 20 --replicates 1

The first run downloads the pretrained K562 transformer checkpoint from Zenodo (~1.4 GB, one-time, cached in examples/k562/cache/). After that, the sampling itself takes about 5 s on an NVIDIA H100 NVL.

Expected output (in examples/k562/generated/):

  • sample_0.npz — one-hot tensor of shape (100, 230, 4)
  • sample_0.fasta — the same 100 sequences in FASTA format
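
To inspect the .npz output directly, the one-hot tensor can be decoded back into strings. The sketch below assumes the channel order is A, C, G, T and that the sequences are the first array in the archive; neither is stated in this README, so check both against sample_0.fasta. The helper names are hypothetical.

```python
def onehot_to_sequences(onehot, alphabet="ACGT"):
    """Decode an (N, L, 4) one-hot array-like into N nucleotide strings."""
    sequences = []
    for seq in onehot:
        letters = []
        for position in seq:
            # Index of the largest channel at this position.
            best = max(range(len(alphabet)), key=lambda k: position[k])
            letters.append(alphabet[best])
        sequences.append("".join(letters))
    return sequences


def load_generated(path="generated/sample_0.npz"):
    """Load the first array stored in the .npz archive and decode it."""
    import numpy as np  # lazy import: only needed when reading the file

    with np.load(path) as archive:
        array = archive[archive.files[0]]  # key name is an assumption: take the first
    return onehot_to_sequences(array)


toy = [[[1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0]]]  # one sequence, three positions
print(onehot_to_sequences(toy))  # ['ATC']
```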

python sample.py --help lists every override (config path, predictor, batch size, output dir, etc.).

Quickstart

1. Define your dataset

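import h5py
import torch
from torch.utils.data import Dataset

class MyDNADataset(Dataset):
    """DNA sequences (as token indices) paired with their conditioning labels."""

    def __init__(self, h5_path, split='train'):
        with h5py.File(h5_path, 'r') as f:
            # X_{split} is read as one-hot with the channel axis at dim 1, i.e. (N, 4, L);
            # argmax over that axis gives token indices of shape (N, L).
            # If your file stores (N, L, 4), use dim=-1 instead.
            self.X = torch.tensor(f[f'X_{split}'][:]).argmax(dim=1)
            self.y = torch.tensor(f[f'Y_{split}'][:])

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]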

2. Write a config

dataset:
  name: my_dataset
  sequence_length: 249
  num_classes: 4
  signal_dim: 2

ngpus: 1
tokens: 4

model:
  architecture: transformer
  hidden_size: 256
  cond_dim: 128
  n_blocks: 8
  n_heads: 8
  dropout: 0.1
  class_dropout_prob: 0.1

training:
  batch_size: 128
  accum: 1
  max_epochs: 300
  ema: 0.9999

# ... see examples/minimal/config.yaml for full template
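
Before launching a long run it can help to sanity-check the config against the dataset. The sketch below uses a plain dict with the values from the config above (an OmegaConf config supports the same item access). The divisibility check reflects how multi-head attention conventionally splits hidden_size across heads. The shape check assumes signal_dim is the length of each label vector, which this README does not confirm. Both helper names are hypothetical and not part of d3-dna.

```python
def head_dim(cfg):
    """Per-head width; raises if hidden_size is not divisible by n_heads."""
    hidden = cfg["model"]["hidden_size"]
    heads = cfg["model"]["n_heads"]
    if hidden % heads != 0:
        raise ValueError(f"hidden_size={hidden} is not divisible by n_heads={heads}")
    return hidden // heads


def item_matches_config(seq_len, label_len, cfg):
    """Compare one dataset item's lengths with the dataset section of the config.

    Assumption: signal_dim is the length of the label vector y.
    """
    ds = cfg["dataset"]
    return seq_len == ds["sequence_length"] and label_len == ds["signal_dim"]


cfg = {
    "dataset": {"sequence_length": 249, "num_classes": 4, "signal_dim": 2},
    "model": {"hidden_size": 256, "n_heads": 8},
}
print(head_dim(cfg))                     # 32
print(item_matches_config(249, 2, cfg))  # True
```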

3. Train

from d3_dna import D3Trainer

trainer = D3Trainer('config.yaml')
trainer.fit(
    train_dataset=MyDNADataset('data.h5', 'train'),
    val_dataset=MyDNADataset('data.h5', 'valid'),
)

4. Sample

from d3_dna import D3Sampler
from d3_dna.models import TransformerModel
from omegaconf import OmegaConf

cfg = OmegaConf.load('config.yaml')
model = TransformerModel(cfg)
sampler = D3Sampler(cfg)

sequences = sampler.generate(
    checkpoint='experiments/checkpoints/last.ckpt',
    model=model,
    num_samples=1000,
)
sampler.save(sequences, 'generated.fasta')

API Reference

Class               Purpose
D3Trainer           Train a D3 model: trainer.fit(train_ds, val_ds)
D3Sampler           Generate sequences: sampler.generate(ckpt, model, n)
D3Evaluator         Evaluate with oracle: subclass and implement load_oracle_model()
TransformerModel    D3-Tran architecture (config-driven)
ConvolutionalModel  D3-Conv architecture (config-driven)

Sampling performance

Rough single-GPU ballpark from the K562 example (230 bp sequences, 20-step Euler predictor, transformer backbone, bf16-mixed, no flash-attn) on one H100 NVL:

Sequences                                 Batch  Steps  Wall time  Throughput
39,340 (full K562 test set, 1 replicate)  512    20     ~5 min     ~130 seq/s

Sampling is compute-bound — the per-step rate (~6 step/s at batch 512) stays roughly constant as you change batch size, so doubling the batch ~doubles end-to-end throughput until you hit GPU memory. Wall time scales linearly with num_samples × steps, and roughly linearly with sequence length without flash-attn (sub-linearly with it).
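
The table row can be cross-checked with a back-of-the-envelope estimate from the quoted per-step rate. This sketch only does the arithmetic; the function name is hypothetical, and the ~6 step/s input is the approximate figure above, so the result is a ballpark rather than a measurement.

```python
import math


def estimate_sampling(num_samples, batch_size, steps, steps_per_second):
    """Estimate wall time and throughput from a measured per-step rate."""
    batches = math.ceil(num_samples / batch_size)  # last batch may be partial
    total_steps = batches * steps                  # each batch runs every diffusion step
    seconds = total_steps / steps_per_second
    return {
        "batches": batches,
        "total_steps": total_steps,
        "seconds": seconds,
        "seq_per_second": num_samples / seconds,
    }


est = estimate_sampling(39_340, batch_size=512, steps=20, steps_per_second=6.0)
print(est["batches"], est["total_steps"])                   # 77 1540
print(round(est["seconds"]), round(est["seq_per_second"]))  # 257 153
```

The estimate (about 4.3 min, about 150 seq/s) is in the same range as the ~5 min and ~130 seq/s reported in the table.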

Mixed precision

Both training and sampling default to a per-architecture autocast policy, picked by d3_dna.modules.precision.precision_for_cfg:

Architecture   Lightning precision  Autocast dtype  GradScaler
transformer    bf16-mixed           torch.bfloat16  not used (bf16 has fp32 range)
convolutional  16-mixed             torch.float16   installed automatically by Lightning

The same dtype flows through get_score_fn so the loss path and the sampling path share a single autocast policy. LayerNorm opts out of autocast and runs in fp32; score.exp() in the predictor path is on PyTorch's fp32-cast list, so the score and all post-model sampler arithmetic land in fp32 regardless of architecture.

To override (e.g. when a checkpoint was trained under a different policy), set cfg.training.precision to '16-mixed' or 'bf16-mixed' in the config. The promoter example is a known exception: examples/promoter/config_transformer.yaml overrides the default back to 16-mixed because the public D3_Tran_Promoter.ckpt on Zenodo was trained and validated under fp16 and degrades sharply if sampled in bf16. New transformer training runs in other examples should keep the bf16-mixed default.
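
The selection rule described in this section (per-architecture default, with an explicit training.precision taking priority) can be sketched as a small lookup. This is a stand-in written from the table above, not the implementation of d3_dna.modules.precision.precision_for_cfg; the function name and the plain-dict config are assumptions made for illustration.

```python
# Defaults copied from the table in this section.
DEFAULT_PRECISION = {
    "transformer": "bf16-mixed",
    "convolutional": "16-mixed",
}


def pick_precision(cfg):
    """Return the explicit training.precision if set, else the architecture default."""
    override = cfg.get("training", {}).get("precision")
    if override is not None:
        return override
    return DEFAULT_PRECISION[cfg["model"]["architecture"]]


default_cfg = {"model": {"architecture": "transformer"}}
promoter_like_cfg = {
    "model": {"architecture": "transformer"},
    "training": {"precision": "16-mixed"},  # explicit override, as in the promoter config
}
print(pick_precision(default_cfg))        # bf16-mixed
print(pick_precision(promoter_like_cfg))  # 16-mixed
```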

Architecture

d3_dna/
├── models/          # TransformerModel, ConvolutionalModel, EMA
├── diffusion.py     # Noise schedules, transition graphs, losses
├── sampling.py      # PC sampler, predictors, D3Sampler
├── trainer.py       # Lightning module, D3Trainer
├── evaluator.py     # SP-MSE callback, D3Evaluator
└── io.py            # Checkpoint loading, data utilities
