D3: DNA Discrete Diffusion -- train and sample generative models for DNA sequences

d3-dna

Discrete diffusion of DNA sequences. d3-dna is a standalone, pip-installable library for training and sampling SEDD-style discrete diffusion models on nucleotide sequences, packaged together with self-contained reproducibility examples for the K562, HepG2, DeepSTARR, and FANTOM5 promoter benchmarks.

Models are config-driven: the same D3Trainer / D3Sampler / D3Evaluator API works for any dataset whose sequences fit in a fixed window, with global per-sample labels, per-position labels, or no conditioning at all (the conditioning mode is determined by the shape of y, not by a flag). Architectures are swappable: a 12-block diffusion transformer (DDiT) and a 256-channel dilated convolutional model ship out of the box, and the cfg.model.architecture switch selects between them.

Dataset-specific logic (oracles, masking, strand averaging, real-data layout) lives in the per-example directories, never in the core library, so adopting d3-dna on a new dataset amounts to writing a Dataset class and an oracle.

The fully-populated K562 example under examples/k562/ is the best place to start: it reproduces the published transformer and convolutional configurations end-to-end (training, sampling, evaluation) against the pretrained checkpoints on Zenodo.

Installation

pip install d3-dna

Extras: [flash] adds flash-attention for faster training on long sequences (otherwise the transformer falls back to PyTorch SDPA — identical model quality, slower at long sequence lengths); [logging] adds Weights & Biases; [all] installs both. Flash-attention compiles from source and imports torch during its build, so install it on a machine with a CUDA toolchain and disable build isolation so the existing torch install is visible:

pip install d3-dna
pip install flash-attn --no-build-isolation

Cold install on a fresh Python env takes about 60–90 s on a 1 Gbit link (dominated by the ~3 GB PyTorch + CUDA-runtime wheel download), under 10 s when the wheels are already cached. The [flash] source build adds another 5–15 min on a single GPU host.

Demo

After installation, clone this repo and run a self-contained K562 sampling demo:

git clone https://github.com/anirbansarkar-cs/d3-dna.git
cd d3-dna/examples/k562
python sample.py --random-labels --num-samples 100 --steps 20 --replicates 1

The first run downloads the pretrained K562 transformer checkpoint from Zenodo (~1.4 GB, one-time, cached in examples/k562/cache/). After that, sampling itself takes about 5 s on an NVIDIA H100 NVL. Output (in examples/k562/generated/): sample_0.npz of shape (100, 230, 4) and sample_0.fasta.
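
The NPZ output holds one-hot arrays of shape (N, L, 4), while the FASTA file holds the same sequences as text. As a rough illustration of how the two relate, the sketch below decodes one-hot rows into nucleotide strings and renders FASTA records. The helper names and the A, C, G, T channel order are assumptions made for this example, not something the library is confirmed to use; check the example's Dataset before relying on the order.

```python
# Assumption: channels are ordered A, C, G, T (not confirmed by the library docs).
ALPHABET = "ACGT"


def one_hot_to_sequence(one_hot):
    """Decode one (L, 4) one-hot sequence into a string via per-position argmax."""
    return "".join(
        ALPHABET[max(range(4), key=lambda channel: row[channel])]
        for row in one_hot
    )


def to_fasta(batch, prefix="sample"):
    """Render a batch of (L, 4) one-hot sequences as FASTA text."""
    lines = []
    for index, one_hot in enumerate(batch):
        lines.append(f">{prefix}_{index}")  # hypothetical record naming
        lines.append(one_hot_to_sequence(one_hot))
    return "\n".join(lines) + "\n"


# Two toy sequences of length 4, written as nested lists; a NumPy array
# of shape (N, L, 4) can be passed in the same way.
batch = [
    [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
    [[0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]],
]
fasta_text = to_fasta(batch)
```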

Usage

The four public classes are D3Trainer, D3Sampler, D3Evaluator, and BaseSPMSEValidationCallback, all re-exported from d3_dna. Each operates on standard PyTorch Dataset objects and an OmegaConf cfg, so adopting d3-dna in an existing pipeline is mostly a matter of writing the Dataset, picking a config, and instantiating one of these classes.

Define a Dataset

Each item is (X, y), where X is a LongTensor of token indices and y is the conditioning label.

import torch
from torch.utils.data import Dataset
import h5py

class MyDNADataset(Dataset):
    def __init__(self, h5_path, split='train'):
        with h5py.File(h5_path, 'r') as f:
            # one-hot (N, L, 4) -> argmax -> (N, L) token indices
            self.X = torch.from_numpy(f[f'onehot_{split}'][:]).argmax(dim=-1)
            self.y = torch.tensor(f[f'y_{split}'][:], dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

The shape of y decides the conditioning mode automatically (no flag): (N, signal_dim) broadcasts a global label across positions; (N, sequence_length, signal_dim) adds a per-position label element-wise. K562 / HepG2 / DeepSTARR use the global form, FANTOM5 promoter uses the per-position form. See d3_dna/models/transformer.py:EmbeddingLayer.forward for the dispatch.
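
To make the shape-based dispatch concrete, here is a minimal NumPy sketch of the idea: a global label is projected and broadcast across positions, while a per-position label is projected and added element-wise. This is not the code in EmbeddingLayer.forward; the function name condition, the linear projection weight, and the toy dimensions are all assumptions chosen for illustration.

```python
import numpy as np


def condition(token_emb, y, weight):
    """Add a projected label to token embeddings (illustrative only).

    token_emb: (N, L, H) token embeddings
    y:         None, (N, D) global label, or (N, L, D) per-position label
    weight:    (D, H) hypothetical linear projection of the label
    """
    if y is None:
        return token_emb                      # unconditional
    if y.ndim not in (2, 3):
        raise ValueError(f"unsupported label shape {y.shape}")
    label_emb = y @ weight                    # (N, H) or (N, L, H)
    if y.ndim == 2:
        label_emb = label_emb[:, None, :]     # (N, 1, H), broadcast over L
    return token_emb + label_emb


N, L, H, D = 2, 5, 3, 1
token_emb = np.zeros((N, L, H))
weight = np.ones((D, H))

y_global = np.array([[1.0], [2.0]])                           # (N, D)
y_position = np.arange(N * L * D, dtype=float).reshape(N, L, D)  # (N, L, D)

out_global = condition(token_emb, y_global, weight)
out_position = condition(token_emb, y_position, weight)
```

With the global label every position of a sample receives the same offset; with the per-position label each position receives its own.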

Train

D3Trainer wraps a PyTorch Lightning trainer around D3LightningModule. Pass the config (a path or an OmegaConf object), then .fit(train_dataset, val_dataset).

from d3_dna import D3Trainer

trainer = D3Trainer('config.yaml')
trainer.fit(
    train_dataset=MyDNADataset('data.h5', 'train'),
    val_dataset=MyDNADataset('data.h5', 'valid'),
)

Resume from a checkpoint with resume_from='path/to/last.ckpt', and attach user callbacks with callbacks=[...]. The training config (including dataset metadata) is embedded in the saved checkpoint via save_hyperparameters, so a checkpoint is self-describing for inference.

Sample

D3Sampler loads a checkpoint plus a user-instantiated model, generates one-hot sequences via a PC sampler, and writes NPZ + FASTA to disk.

from d3_dna import D3Sampler
from d3_dna.models import TransformerModel
from omegaconf import OmegaConf

cfg = OmegaConf.load('config.yaml')
model = TransformerModel(cfg)
sampler = D3Sampler(cfg)

sequences = sampler.generate(
    checkpoint='outputs/last.ckpt',
    model=model,
    num_samples=1000,
)
sampler.save(sequences, 'generated.fasta')

For long runs, use .generate_batched(...) and tune cfg.sampling.batch_size against your GPU memory.
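
When choosing a batch size, it can help to see how a request splits into batches. The sketch below is plain arithmetic under the assumption that samples are generated in fixed-size chunks with a smaller final chunk; how .generate_batched(...) actually partitions work is not documented here, and batch_sizes is a hypothetical helper.

```python
def batch_sizes(num_samples, batch_size):
    """Split num_samples into full batches plus one smaller remainder batch."""
    if num_samples < 0 or batch_size <= 0:
        raise ValueError("num_samples must be >= 0 and batch_size must be > 0")
    full_batches, remainder = divmod(num_samples, batch_size)
    return [batch_size] * full_batches + ([remainder] if remainder else [])


plan = batch_sizes(1000, 512)
```

A request for 1000 samples at batch size 512 gives one full batch and a final batch of 488, so peak memory is set by the full batch.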

Evaluate

D3Evaluator is a dataset-agnostic dispatcher over four metrics: paired oracle MSE (mse), per-feature KS on oracle predictions (ks), k-mer Jensen–Shannon distance (js), and discriminator AUROC (auroc). The caller supplies pre-loaded samples, real data, and an oracle exposing a .predict(x) method.

from d3_dna import D3Evaluator

ev = D3Evaluator(tests=['mse', 'ks', 'js', 'auroc'], device='cuda')
results = ev.evaluate(
    samples=generated_one_hot,        # (N, L, 4) ndarray or torch tensor
    real_data=real_test_one_hot,      # (N, L, 4)
    oracle=my_oracle,                 # must implement .predict
    kmer_ks=[6],
)
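
For intuition about the js metric, the following standard-library sketch computes a k-mer Jensen–Shannon distance between two sets of sequences. It is an illustration under stated assumptions, not the library's implementation: the function names are hypothetical, k-mers are pooled across all sequences in a set, and the distance is the square root of the base-2 divergence, so it lies in [0, 1]. The library's normalization may differ.

```python
import math
from collections import Counter


def kmer_distribution(sequences, k):
    """Pooled relative frequency of every k-mer seen in the sequences."""
    counts = Counter()
    for seq in sequences:
        for start in range(len(seq) - k + 1):
            counts[seq[start:start + k]] += 1
    total = sum(counts.values())
    if total == 0:
        raise ValueError("no k-mers found; sequences are shorter than k")
    return {kmer: count / total for kmer, count in counts.items()}


def js_distance(p, q):
    """Jensen-Shannon distance (sqrt of base-2 JS divergence) of two dicts."""
    mixture = {
        key: 0.5 * (p.get(key, 0.0) + q.get(key, 0.0))
        for key in set(p) | set(q)
    }

    def kl_to_mixture(dist):
        # Zero-probability entries contribute nothing to the KL sum.
        return sum(
            prob * math.log2(prob / mixture[key])
            for key, prob in dist.items()
            if prob > 0
        )

    divergence = 0.5 * kl_to_mixture(p) + 0.5 * kl_to_mixture(q)
    return math.sqrt(max(divergence, 0.0))  # clamp tiny negative round-off


real = kmer_distribution(["ACGTACGT", "ACGTTTGA"], k=2)
same = js_distance(real, real)
disjoint = js_distance(
    kmer_distribution(["AAAA"], k=2),
    kmer_distribution(["CCCC"], k=2),
)
```

Identical distributions give a distance of 0, and distributions with no shared k-mers give 1.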

Dataset-specific oracle loading, masking, and strand averaging live in examples/<name>/, never in the core library.
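
As an example of the kind of logic that stays in the example directories, strand averaging is commonly done by scoring a sequence and its reverse complement and averaging the two predictions. The sketch below shows that generic pattern on strings with a toy scoring function; whether the shipped examples implement it exactly this way is an assumption, and strand_averaged and toy_oracle are hypothetical names.

```python
COMPLEMENT = {"A": "T", "C": "G", "G": "C", "T": "A"}


def reverse_complement(seq):
    """Reverse the sequence and swap each base for its complement."""
    return "".join(COMPLEMENT[base] for base in reversed(seq))


def strand_averaged(predict, seq):
    """Average a scalar prediction over the forward and reverse strands."""
    return 0.5 * (predict(seq) + predict(reverse_complement(seq)))


def toy_oracle(seq):
    """Strand-sensitive stand-in for a real oracle: counts adenines."""
    return float(seq.count("A"))


forward = "AACG"
reverse = reverse_complement(forward)
score_forward = strand_averaged(toy_oracle, forward)
score_reverse = strand_averaged(toy_oracle, reverse)
```

The averaged score is the same whichever strand is supplied, which is the point of the technique.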

Train-time SP-MSE

BaseSPMSEValidationCallback is the abstract callback for periodic SP-MSE validation against an oracle during training. Subclass it once per dataset and override get_default_sampling_steps() and get_oracle_predictions(samples). Each example directory ships a concrete subclass — e.g. K562MSECallback in examples/k562/callbacks.py.
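
The subclassing pattern can be sketched without the library installed. In the block below, SPMSECallbackSketch is a stand-in written for this example and is not d3_dna.BaseSPMSEValidationCallback: only the two overridden method names come from the description above, while the sp_mse helper, the constructor-free design, the toy GC-content oracle, and the step count of 20 are assumptions.

```python
from abc import ABC, abstractmethod


class SPMSECallbackSketch(ABC):
    """Illustrative stand-in; NOT the real d3_dna base class."""

    @abstractmethod
    def get_default_sampling_steps(self):
        """Number of sampler steps to use during validation."""

    @abstractmethod
    def get_oracle_predictions(self, samples):
        """Return one oracle prediction per generated sample."""

    def sp_mse(self, samples, reference):
        # Hypothetical helper: mean squared error between oracle
        # predictions on generated samples and reference values.
        predictions = self.get_oracle_predictions(samples)
        if len(predictions) != len(reference):
            raise ValueError("samples and reference must have equal length")
        squared = [(p - r) ** 2 for p, r in zip(predictions, reference)]
        return sum(squared) / len(squared)


class ToyGCCallback(SPMSECallbackSketch):
    """Dataset-specific subclass using GC fraction as a toy oracle."""

    def get_default_sampling_steps(self):
        return 20  # assumed value for illustration

    def get_oracle_predictions(self, samples):
        return [(s.count("G") + s.count("C")) / len(s) for s in samples]


callback = ToyGCCallback()
score = callback.sp_mse(["GGCC", "AATT"], reference=[1.0, 0.5])
```

A real subclass would load the dataset's oracle in place of the GC-fraction toy; see K562MSECallback in examples/k562/callbacks.py for the shipped version.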

Examples

Each example reproduces a published D3 configuration end-to-end. Data, oracle weights, and pretrained transformer + convolutional checkpoints auto-download from Zenodo on first run.

Example     Sequences           Conditioning                     Zenodo
k562        230 bp MPRA         global activity (N, 1)           19774653
hepg2       230 bp MPRA         global activity (N, 1)           19774653
deepstarr   249 bp enhancers    dual-head activity (N, 2)        19774653
promoter    1024 bp FANTOM5     per-position CAGE (N, 1024, 1)   19738941
minimal     scaffold for a new dataset

Each examples/<name>/ is a self-contained train.py / sample.py / evaluate.py flow plus a Dataset, oracle, validation callback, and config YAMLs for both architectures. Per-example READMEs carry the reference numbers and reproduction recipe.

System requirements

Linux (POSIX); verified on RHEL 8 (kernel 4.18). Not tested on Windows or macOS. Python ≥3.9, verified on 3.11 and 3.12. Any CUDA-capable GPU works for training and sampling; benchmarks here use one NVIDIA H100 NVL (driver 580, CUDA 13). The [flash] extra additionally requires an Ampere-or-newer GPU and a CUDA toolchain (nvcc, CUDA_HOME) at install time. Python dependencies (auto-installed by pip): torch≥2.0 (verified on 2.12), pytorch-lightning≥2.0 (verified on 2.6), omegaconf≥2.3, numpy≥1.23, scipy≥1.10, h5py≥3.7, tqdm≥4.64, einops≥0.6. The optional extras add flash-attn≥2.0 and wandb≥0.16.

Sampling performance

Single-GPU ballpark from the K562 example (230 bp, 20-step Euler predictor, transformer, bf16-mixed, no flash-attn) on one H100 NVL: ~6 step/s at batch 512, ~130 seq/s end-to-end, full 39,340-sequence test set in ~5 min. The per-step rate is roughly batch-size-invariant until GPU memory becomes the limit, so doubling the batch roughly doubles end-to-end throughput; in other words, the GPU is not yet saturated at these batch sizes. At a fixed batch size, wall time scales linearly with num_samples × steps, and roughly linearly with sequence length (sub-linearly when flash-attn is enabled).
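
The quoted figures can be cross-checked with a few lines of arithmetic. The sketch below uses only the numbers stated above; the variable names are illustrative, and the gap between the sampler-only rate and the end-to-end rate is presumably overhead outside the sampling loop (an assumption, since the breakdown is not given).

```python
steps = 20                 # sampler steps per batch
step_rate = 6.0            # steps per second at batch 512
batch_size = 512
end_to_end_rate = 130.0    # sequences per second, as quoted
test_set_size = 39_340

seconds_per_batch = steps / step_rate               # about 3.3 s
sampler_only_rate = batch_size / seconds_per_batch  # about 154 seq/s
wall_minutes = test_set_size / end_to_end_rate / 60 # about 5 min

print(f"{seconds_per_batch:.2f} s per batch")
print(f"{sampler_only_rate:.1f} seq/s in the sampling loop alone")
print(f"{wall_minutes:.1f} min for the full test set")
```

The sampler-only rate comes out somewhat above the quoted end-to-end rate, and the full test set lands at about five minutes, consistent with the figures above.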

Related

d3-dna extracts the core training, sampling, and evaluation components from the D3-DNA-Discrete-Diffusion research codebase, which holds the full ablations, analysis pipelines, and experiment scripts.
