D3: DNA Discrete Diffusion -- train and sample generative models for DNA sequences

Project description

D3-DNA: DNA Discrete Diffusion

A PyPI package for training and sampling discrete diffusion models on DNA sequences.

This package provides a clean, reusable implementation of the D3 (DNA Discrete Diffusion) framework from D3-DNA-Discrete-Diffusion. The original repository contains the full research codebase, experiment scripts, and analysis pipelines. This package extracts the core model, training, and sampling components into a pip-installable library suitable for integration into new projects.

Installation

# Core package
pip install d3-dna

# With flash attention (faster training on long sequences)
pip install "d3-dna[flash]"

# With Weights & Biases logging
pip install "d3-dna[logging]"

# Everything
pip install "d3-dna[all]"

GPU acceleration: d3-dna[flash] installs flash attention for faster, more memory-efficient training on long sequences. Without it, the package falls back to PyTorch's built-in scaled dot-product attention (SDPA), which gives the same model quality but runs slower on long inputs.

flash-attn compiles from source and imports torch during its build, so it needs a machine with a CUDA toolchain (CUDA_HOME set, nvcc available) and build isolation disabled so that the existing torch install is visible. In practice this means installing the core package first and then flash-attn separately:

pip install d3-dna
pip install flash-attn --no-build-isolation
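To see which attention path you are likely to get, a stdlib-only check can tell whether flash-attn is importable in the current environment. This is a sketch: flash_attn_available is a hypothetical helper, and it only reports whether the flash_attn module is installed, not whether d3_dna decides to use it at runtime.

```python
import importlib.util


def flash_attn_available() -> bool:
    """Return True if the flash_attn module can be found on the import path.

    Hypothetical helper: it checks installation only; whether d3_dna routes
    attention through flash_attn is decided inside the package.
    """
    return importlib.util.find_spec("flash_attn") is not None


if __name__ == "__main__":
    if flash_attn_available():
        print("flash_attn found")
    else:
        print("flash_attn not found; expect PyTorch SDPA attention")
```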

Quickstart

1. Define your dataset

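import h5py
import torch
from torch.utils.data import Dataset

class MyDNADataset(Dataset):
    """One split of an HDF5 file with keys X_<split> (one-hot DNA) and Y_<split> (targets)."""

    def __init__(self, h5_path, split='train'):
        with h5py.File(h5_path, 'r') as f:
            # Assumes X is stored one-hot as (N, 4, L): argmax over the nucleotide
            # axis gives integer token indices of shape (N, L).
            # Use argmax(dim=-1) instead if your file stores (N, L, 4).
            self.X = torch.from_numpy(f[f'X_{split}'][:]).argmax(dim=1)
            self.y = torch.from_numpy(f[f'Y_{split}'][:])

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]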
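Before pointing the dataset class at a real file, it can help to confirm the arrays match the shapes the config declares. A minimal NumPy sketch, assuming X is one-hot and channel-first, (N, 4, L), which is what argmax(dim=1) in the class above implies, and Y is (N, signal_dim). check_split is a hypothetical helper, not part of d3_dna, and mapping sequence_length, num_classes and signal_dim onto these axes is an assumption.

```python
import numpy as np

# Expected sizes, copied from the example config in step 2 (mapping assumed).
SEQUENCE_LENGTH = 249   # dataset.sequence_length
NUM_TOKENS = 4          # dataset.num_classes / tokens (A, C, G, T)
SIGNAL_DIM = 2          # dataset.signal_dim


def check_split(X_onehot, Y, seq_len=SEQUENCE_LENGTH, n_tokens=NUM_TOKENS,
                signal_dim=SIGNAL_DIM):
    """Validate one split and return integer token indices of shape (N, L).

    Raises ValueError if X is not (N, n_tokens, seq_len) one-hot or if Y is
    not (N, signal_dim).
    """
    if X_onehot.ndim != 3 or X_onehot.shape[1:] != (n_tokens, seq_len):
        raise ValueError(f"X has shape {X_onehot.shape}, expected (N, {n_tokens}, {seq_len})")
    # One-hot: every entry is 0 or 1 and each position has exactly one 1.
    if not (np.isin(X_onehot, (0, 1)).all() and np.all(X_onehot.sum(axis=1) == 1)):
        raise ValueError("every position in X must have exactly one active nucleotide")
    if Y.shape != (X_onehot.shape[0], signal_dim):
        raise ValueError(f"Y has shape {Y.shape}, expected ({X_onehot.shape[0]}, {signal_dim})")
    # Same reduction as MyDNADataset: argmax over the nucleotide axis.
    return X_onehot.argmax(axis=1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    idx = rng.integers(0, NUM_TOKENS, size=(8, SEQUENCE_LENGTH))
    X = np.eye(NUM_TOKENS, dtype=np.float32)[idx].transpose(0, 2, 1)  # (N, L, 4) -> (N, 4, L)
    Y = rng.normal(size=(8, SIGNAL_DIM)).astype(np.float32)
    print(check_split(X, Y).shape)  # (8, 249)
```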

2. Write a config

dataset:
  name: my_dataset
  sequence_length: 249
  num_classes: 4
  signal_dim: 2

ngpus: 1
tokens: 4

model:
  architecture: transformer
  hidden_size: 256
  cond_dim: 128
  n_blocks: 8
  n_heads: 8
  dropout: 0.1
  class_dropout_prob: 0.1

training:
  batch_size: 128
  accum: 1
  max_epochs: 300
  ema: 0.9999

# ... see examples/minimal/config.yaml for full template
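Two quick numbers follow from the model block. A small sketch, assuming standard multi-head attention (so hidden_size must divide evenly by n_heads) and standard transformer blocks with a 4x MLP. D3-Tran's blocks also carry conditioning layers that this count ignores, so treat the parameter figure as an order-of-magnitude estimate. head_dim and approx_block_params are hypothetical helpers, not part of d3_dna.

```python
# Values copied from the example config above.
model_cfg = {"hidden_size": 256, "n_heads": 8, "n_blocks": 8}


def head_dim(hidden_size: int, n_heads: int) -> int:
    """Per-head width; multi-head attention needs hidden_size divisible by n_heads."""
    if hidden_size % n_heads != 0:
        raise ValueError(f"hidden_size={hidden_size} is not divisible by n_heads={n_heads}")
    return hidden_size // n_heads


def approx_block_params(hidden_size: int, n_blocks: int, mlp_ratio: int = 4) -> int:
    """Rough weight count for standard transformer blocks.

    Per block: 4*d^2 for the Q, K, V and output projections plus
    2*mlp_ratio*d^2 for the two MLP matrices. Ignores embeddings, biases,
    norms and any conditioning layers.
    """
    d = hidden_size
    per_block = 4 * d * d + 2 * mlp_ratio * d * d
    return n_blocks * per_block


print(head_dim(model_cfg["hidden_size"], model_cfg["n_heads"]))              # 32
print(approx_block_params(model_cfg["hidden_size"], model_cfg["n_blocks"]))  # 6291456
```

So the example config gives 32-dimensional attention heads and roughly 6.3M weights in the transformer blocks alone.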

3. Train

from d3_dna import D3Trainer

trainer = D3Trainer('config.yaml')
trainer.fit(
    train_dataset=MyDNADataset('data.h5', 'train'),
    val_dataset=MyDNADataset('data.h5', 'valid'),
)

4. Sample

from d3_dna import D3Sampler
from d3_dna.models import TransformerModel
from omegaconf import OmegaConf

cfg = OmegaConf.load('config.yaml')
model = TransformerModel(cfg)
sampler = D3Sampler(cfg)

sequences = sampler.generate(
    checkpoint='experiments/checkpoints/last.ckpt',
    model=model,
    num_samples=1000,
)
sampler.save(sequences, 'generated.fasta')
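Once generated.fasta exists, a quick pass can confirm every sequence has the configured length and only A/C/G/T characters. A minimal sketch, assuming sampler.save writes standard FASTA (a '>' header line followed by one or more sequence lines); read_fasta and check_sequences are hypothetical helpers, and the exact header text d3_dna writes is not specified here.

```python
from pathlib import Path


def read_fasta(path):
    """Parse a FASTA file into a list of (header, sequence) tuples.

    Multi-line sequences are joined, blank lines are skipped, and any text
    before the first '>' header is ignored.
    """
    records = []
    header, chunks = None, []
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                records.append((header, "".join(chunks)))
            header, chunks = line[1:], []
        else:
            chunks.append(line)
    if header is not None:
        records.append((header, "".join(chunks)))
    return records


def check_sequences(records, length, alphabet="ACGT"):
    """Return headers whose sequence has the wrong length or characters outside alphabet."""
    allowed = set(alphabet)
    return [h for h, s in records if len(s) != length or not set(s.upper()) <= allowed]


if __name__ == "__main__" and Path("generated.fasta").exists():
    recs = read_fasta("generated.fasta")
    bad = check_sequences(recs, length=249)  # dataset.sequence_length from the config
    print(f"{len(recs)} sequences, {len(bad)} flagged")
```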

API Reference

| Class | Purpose |
|---|---|
| D3Trainer | Train a D3 model: trainer.fit(train_dataset, val_dataset) |
| D3Sampler | Generate sequences: sampler.generate(checkpoint, model, num_samples) |
| D3Evaluator | Evaluate with an oracle model: subclass it and implement load_oracle_model() |
| TransformerModel | D3-Tran architecture (config-driven) |
| ConvolutionalModel | D3-Conv architecture (config-driven) |

Sampling performance

Rough single-GPU ballpark from the K562 example (230 bp sequences, 20-step Euler predictor, transformer backbone, bf16-mixed, no flash-attn) on one H100 NVL:

| Sequences | Batch size | Steps | Wall time | Throughput |
|---|---|---|---|---|
| 39,340 (full K562 test set, 1 replicate) | 512 | 20 | ~5 min | ~130 seq/s |

Sampling at batch 512 does not yet saturate the GPU: the per-step rate (~6 steps/s) stays roughly constant as you change batch size, so doubling the batch roughly doubles end-to-end throughput until you run out of GPU memory. At a fixed batch size, wall time scales linearly with num_samples × steps, and roughly linearly with sequence length without flash-attn (sub-linearly with it).
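The scaling rules above reduce to simple arithmetic: number of batches × steps per batch ÷ per-step rate. A back-of-envelope sketch, assuming the ~6 steps/s rate holds at any batch size and ignoring model loading, host-device transfers and file I/O; estimate_sampling_seconds is a hypothetical helper.

```python
import math


def estimate_sampling_seconds(num_samples: int, batch_size: int, steps: int,
                              steps_per_second: float) -> float:
    """Back-of-envelope sampling wall time in seconds.

    Every batch (including a partial last one) runs `steps` predictor steps,
    and each step is assumed to take 1 / steps_per_second regardless of
    batch size. Model loading, transfers and file I/O are ignored.
    """
    n_batches = math.ceil(num_samples / batch_size)
    return n_batches * steps / steps_per_second


secs = estimate_sampling_seconds(39_340, batch_size=512, steps=20, steps_per_second=6.0)
print(f"{secs / 60:.1f} min, {39_340 / secs:.0f} seq/s")  # 4.3 min, 153 seq/s
```

For the K562 row this is somewhat optimistic next to the measured ~5 min and ~130 seq/s, since the excluded overheads are real. Doubling batch_size to 1024 cuts the batch count from 77 to 39, so under the constant-step-rate assumption wall time roughly halves.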

Mixed precision

Both training and sampling default to a per-architecture autocast policy, picked by d3_dna.modules.precision.precision_for_cfg:

| Architecture | Lightning precision | Autocast dtype | GradScaler |
|---|---|---|---|
| transformer | bf16-mixed | torch.bfloat16 | not used (bf16 shares fp32's exponent range) |
| convolutional | 16-mixed | torch.float16 | installed automatically by Lightning |

The same dtype flows through get_score_fn so the loss path and the sampling path share a single autocast policy. LayerNorm opts out of autocast and runs in fp32; score.exp() in the predictor path is on PyTorch's fp32-cast list, so the score and all post-model sampler arithmetic land in fp32 regardless of architecture.

To override the default (e.g. when a checkpoint was trained under a different policy), set training.precision to '16-mixed' or 'bf16-mixed' in the config. Promoter is a known exception: examples/promoter/config_transformer.yaml overrides the default back to 16-mixed because the public D3_Tran_Promoter.ckpt on Zenodo was trained and validated under fp16 and degrades sharply if sampled in bf16. New transformer training runs in other examples should keep the bf16-mixed default.
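The default-plus-override logic in the table above can be sketched as a small lookup. This is not d3_dna.modules.precision.precision_for_cfg, whose signature is not documented here; resolve_precision is a hypothetical stand-in that returns dtype names as strings to stay torch-free (in real code they correspond to torch.bfloat16 and torch.float16), and its GradScaler flag describes the training side only.

```python
from typing import Optional, Tuple

# Per-architecture defaults, copied from the table above.
_DEFAULTS = {"transformer": "bf16-mixed", "convolutional": "16-mixed"}
# Lightning precision string -> autocast dtype name.
_AUTOCAST_DTYPE = {"bf16-mixed": "bfloat16", "16-mixed": "float16"}


def resolve_precision(architecture: str,
                      override: Optional[str] = None) -> Tuple[str, str, bool]:
    """Return (lightning_precision, autocast_dtype_name, uses_grad_scaler).

    `override` plays the role of training.precision in the config.
    """
    if override is not None:
        precision = override
    elif architecture in _DEFAULTS:
        precision = _DEFAULTS[architecture]
    else:
        raise ValueError(f"unknown architecture: {architecture!r}")
    if precision not in _AUTOCAST_DTYPE:
        raise ValueError(f"unsupported precision: {precision!r}")
    # fp16 has a narrow exponent range, so 16-mixed training pairs it with a
    # GradScaler; bf16 keeps fp32's exponent range and does not need one.
    return precision, _AUTOCAST_DTYPE[precision], precision == "16-mixed"
```

For example, resolve_precision("transformer", override="16-mixed") yields the fp16 policy used for the Promoter checkpoint, while resolve_precision("transformer") keeps bf16.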

Architecture

d3_dna/
├── models/          # TransformerModel, ConvolutionalModel, EMA
├── diffusion.py     # Noise schedules, transition graphs, losses
├── sampling.py      # PC sampler, predictors, D3Sampler
├── trainer.py       # Lightning module, D3Trainer
├── evaluator.py     # SP-MSE callback, D3Evaluator
└── io.py            # Checkpoint loading, data utilities

Download files

Source Distribution

d3_dna-0.1.0.tar.gz (37.8 kB)

Uploaded Source

Built Distribution

d3_dna-0.1.0-py3-none-any.whl (40.9 kB)

Uploaded Python 3

File details

Details for the file d3_dna-0.1.0.tar.gz.

File metadata

  • Download URL: d3_dna-0.1.0.tar.gz
  • Size: 37.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.10

File hashes

Hashes for d3_dna-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | bc205f273d3359c45865914c0106551a9378cea6caa3b33532cfcc4e39590f35 |
| MD5 | 7fb1b95e5b05bacb0b2066a461c397b0 |
| BLAKE2b-256 | 7e384e670f22c17ed2a94997b7eab99ef1b6ba9f9359b7826ae239fa358481b5 |

File details

Details for the file d3_dna-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: d3_dna-0.1.0-py3-none-any.whl
  • Size: 40.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.10

File hashes

Hashes for d3_dna-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 844b8573aaae505ce147bae40a5c439526ea0ee1585b7932694645faff9efe98 |
| MD5 | 038e08441c0c8ba57b744320d02b3ceb |
| BLAKE2b-256 | 2d9bfb99a3d8f8841297c38ecd3524af28021032d4676ff3d9d13e428e4caa04 |
