DiffBio: End-to-end differentiable bioinformatics pipelines built on Datarax, Artifex, Opifex, and Calibrax

DiffBio

End-to-End Differentiable Bioinformatics Pipelines

Built on Datarax, Artifex, Opifex, and Calibrax | Powered by JAX & Flax NNX


Overview

DiffBio is a framework for building end-to-end differentiable bioinformatics pipelines. By replacing discrete operations with differentiable relaxations, DiffBio enables gradient-based optimization through entire analysis workflows.
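To make "differentiable relaxation" concrete, here is a minimal JAX sketch of two common relaxations: a softmax-weighted soft argmax and a log-sum-exp smooth maximum. The function names `soft_argmax` and `smooth_max` are illustrative assumptions and are not taken from the DiffBio API.

```python
import jax
import jax.numpy as jnp


def soft_argmax(scores, temperature=1.0):
    """Expected index under softmax(scores / temperature) (illustrative helper)."""
    probs = jax.nn.softmax(scores / temperature)
    return jnp.sum(probs * jnp.arange(scores.shape[0]))


def smooth_max(scores, temperature=1.0):
    """temperature * logsumexp(scores / temperature), a smooth upper bound on max."""
    return temperature * jax.nn.logsumexp(scores / temperature)


scores = jnp.array([1.0, 3.0, 2.0])

# The hard decision returns an integer index and carries no useful gradient.
hard_index = jnp.argmax(scores)

# A low temperature pushes the relaxation toward the hard decision.
soft_index = soft_argmax(scores, temperature=0.1)

# Both relaxations are differentiable with respect to the scores.
grad_soft_argmax = jax.grad(soft_argmax)(scores)
grad_smooth_max = jax.grad(smooth_max)(scores)  # equals softmax(scores)
```

Lower temperatures approximate the discrete operation more closely but concentrate the gradient on fewer entries, which is the usual trade-off when choosing a temperature.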

DiffBio is the biology-specific differentiable operator layer of a wider JAX/NNX scientific ML ecosystem. It uses:

  • Datarax for operator and dataflow contracts
  • Artifex for reusable model-building and transformer components
  • Opifex for scientific ML and advanced optimization primitives
  • Calibrax for metrics, benchmarking, comparison, and regression control

Traditional bioinformatics pipelines use discrete operations (hard thresholds, argmax decisions) that block gradient flow. DiffBio addresses this with:

  • Soft quality filtering using sigmoid-based weights instead of hard cutoffs
  • Differentiable pileup with soft position assignments via temperature-controlled softmax
  • Soft alignment scoring replacing discrete Smith-Waterman with continuous relaxations
  • End-to-end training of complete pipelines using gradient descent

This makes it possible to learn pipeline parameters directly from data instead of tuning them by hand.
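As an illustration of soft position assignment, the sketch below builds a pileup in which every base is spread over reference positions by a temperature-controlled softmax. It is one possible formulation under stated assumptions (a squared-distance kernel, float start positions, per-read weights), not the implementation behind `DifferentiablePileup`; the name `soft_pileup` is hypothetical.

```python
import jax
import jax.numpy as jnp


def soft_pileup(reads, positions, weights, reference_length, temperature=1.0):
    """Hypothetical soft pileup.

    reads:     (num_reads, read_length, 4) one-hot bases
    positions: (num_reads,) float start coordinate of each read
    weights:   (num_reads,) soft per-read weights
    returns:   (reference_length, 4) base distribution per reference position
    """
    read_length = reads.shape[1]
    ref_index = jnp.arange(reference_length)
    # Reference coordinate of every base in every read: (num_reads, read_length)
    base_pos = positions[:, None] + jnp.arange(read_length)[None, :]
    # Soft assignment of each base to reference positions (assumed kernel)
    logits = -((base_pos[..., None] - ref_index) ** 2) / temperature
    assign = jax.nn.softmax(logits, axis=-1)
    # Weighted base counts per reference position: (reference_length, 4)
    counts = jnp.einsum("nlr,nlc,n->rc", assign, reads, weights)
    return counts / (counts.sum(axis=-1, keepdims=True) + 1e-8)


reads = jax.nn.one_hot(jnp.array([[0, 1, 2, 3]] * 2), 4)
positions = jnp.array([0.0, 2.0])
weights = jnp.ones(2)

pileup = soft_pileup(reads, positions, weights, reference_length=6, temperature=0.1)

# Because assignment is soft, gradients reach the read start positions.
position_grad = jax.grad(
    lambda p: soft_pileup(reads, p, weights, 6, temperature=1.0)[2, 0]
)(positions)
```

With a hard integer assignment the last line would have no meaningful gradient; the softmax is what lets alignment positions take part in optimization.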

Features

  • 40+ Differentiable Operators covering alignment, variant calling, single-cell analysis, epigenomics, RNA-seq, preprocessing, normalization, multi-omics, drug discovery, and protein/RNA structure
  • 6 End-to-End Pipelines for variant calling, enhanced variant calling, single-cell analysis, differential expression, perturbation, and preprocessing
  • GPU-Accelerated computation via JAX's XLA compilation
  • Composable Architecture built on the Datarax, Artifex, Opifex, and Calibrax stack
  • Training Utilities with gradient clipping, custom loss functions, and synthetic data generation

For complete operator and pipeline listings, see the Operators Overview and Pipelines Overview in the documentation.

Installation

# Clone the repository
git clone https://github.com/avitai/DiffBio.git
cd DiffBio

# Install with uv
uv sync

Quick Start

Using Individual Operators

import jax
import jax.numpy as jnp
from flax import nnx

from diffbio.operators import DifferentiableQualityFilter
from diffbio.operators.variant.pileup import DifferentiablePileup
from diffbio.operators.alignment.smith_waterman import SmoothSmithWaterman

# Quality filtering with learnable threshold
quality_filter = DifferentiableQualityFilter(
    threshold=20.0,
    temperature=1.0,
    rngs=nnx.Rngs(0),
)

# Apply to reads
quality_scores = jnp.array([35.0, 15.0, 28.0, 10.0])
reads = jax.nn.one_hot(jnp.array([[0, 1, 2, 3]] * 4), 4)
data = {"reads": reads, "quality": quality_scores}

filtered_data, _, _ = quality_filter.apply(data, {}, None)
# filtered_data["weights"] contains soft weights for each read
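The operator's internals are not shown here. Assuming the soft weights follow the sigmoid form described in the overview, `sigmoid((quality - threshold) / temperature)`, the values for the example above can be reproduced in plain JAX. The helper `soft_quality_weights` is a hypothetical stand-in, not a DiffBio function.

```python
import jax
import jax.numpy as jnp


def soft_quality_weights(quality, threshold, temperature):
    """Assumed form of the soft filter: a sigmoid centered on the threshold."""
    return jax.nn.sigmoid((quality - threshold) / temperature)


quality_scores = jnp.array([35.0, 15.0, 28.0, 10.0])

# Soft weights in (0, 1) versus the hard keep/drop mask they relax
soft_weights = soft_quality_weights(quality_scores, threshold=20.0, temperature=1.0)
hard_mask = quality_scores >= 20.0

# The total retained weight is differentiable with respect to the threshold,
# which is what would make the threshold learnable.
threshold_grad = jax.grad(
    lambda t: jnp.sum(soft_quality_weights(quality_scores, t, 1.0))
)(20.0)
```

Reads far above the threshold get weights close to 1, reads far below get weights close to 0, and raising the threshold lowers the total retained weight, so `threshold_grad` is negative.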

Using the Variant Calling Pipeline

from diffbio.pipelines import (
    VariantCallingPipeline,
    VariantCallingPipelineConfig,
    create_variant_calling_pipeline,
)

# Create pipeline with default configuration
pipeline = create_variant_calling_pipeline(
    reference_length=100,
    num_classes=3,  # ref, SNP, indel
    hidden_dim=32,
    seed=42,
)

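# Process reads. `jnp` and `reads` come from the previous example;
# `positions` and `quality` are placeholder inputs with the expected shapes.
positions = jnp.array([0, 10, 20, 30])   # (num_reads,)
quality = jnp.full((4, 4), 30.0)         # (num_reads, read_length)

batch_data = {
    "reads": reads,           # (num_reads, read_length, 4)
    "positions": positions,   # (num_reads,)
    "quality": quality,       # (num_reads, read_length)
}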

result, _, _ = pipeline.apply(batch_data, {}, None)
# result["logits"] contains per-position variant predictions
# result["probabilities"] contains class probabilities
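Turning the pipeline output into calls is a matter of a softmax and an argmax. The sketch below uses a small hand-written `logits` array as a stand-in for `result["logits"]`, and assumes a `(reference_length, num_classes)` layout with the class order ref, SNP, indel from the comment above.

```python
import jax
import jax.numpy as jnp

CLASS_NAMES = ("ref", "SNP", "indel")  # assumed class order

# Hypothetical stand-in for result["logits"]: (reference_length, num_classes)
logits = jnp.array(
    [
        [4.0, 0.5, 0.1],
        [0.2, 3.0, 0.3],
        [0.1, 0.2, 2.5],
    ]
)

probabilities = jax.nn.softmax(logits, axis=-1)  # rows sum to 1
calls = jnp.argmax(probabilities, axis=-1)       # most likely class per position
confidence = jnp.max(probabilities, axis=-1)     # probability of that class

for pos, (cls, prob) in enumerate(zip(calls.tolist(), confidence.tolist())):
    print(f"position {pos}: {CLASS_NAMES[cls]} (p={prob:.2f})")
```

The argmax here is only used for reporting; during training the loss would be computed on the logits so that gradients keep flowing.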

Training a Pipeline

from diffbio.utils import (
    Trainer,
    TrainingConfig,
    cross_entropy_loss,
    create_synthetic_training_data,
    data_iterator,
)

# Generate synthetic training data
inputs, targets = create_synthetic_training_data(
    num_samples=100,
    num_reads=10,
    read_length=50,
    reference_length=100,
    variant_rate=0.1,
)

# Configure training
config = TrainingConfig(
    learning_rate=1e-3,
    num_epochs=50,
    log_every=10,
    grad_clip_norm=1.0,
)

# Create trainer
trainer = Trainer(pipeline, config)

# Define loss function
def loss_fn(predictions, targets):
    return cross_entropy_loss(
        predictions["logits"],
        targets["labels"],
        num_classes=3,
    )

# Train
trainer.train(
    data_iterator_fn=lambda: data_iterator(inputs, targets),
    loss_fn=loss_fn,
)

# Access trained pipeline
trained_pipeline = trainer.pipeline
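How `Trainer` performs its updates is not shown in this README. As a rough picture of what a cross-entropy loss combined with `grad_clip_norm` implies, here is a self-contained JAX sketch of one clipped gradient step on a toy linear classifier. Every name in it (`cross_entropy`, `clip_by_global_norm`, `train_step`, the toy model and data) is an assumption made for illustration.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1
MAX_GRAD_NORM = 1.0


def cross_entropy(logits, labels, num_classes):
    """Mean softmax cross-entropy for integer class labels."""
    one_hot = jax.nn.one_hot(labels, num_classes)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))


def clip_by_global_norm(grads, max_norm):
    """Rescale the whole gradient pytree so its global L2 norm is <= max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g**2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-8))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


def loss_fn(params, features, labels):
    logits = features @ params["w"] + params["b"]
    return cross_entropy(logits, labels, num_classes=3)


@jax.jit
def train_step(params, features, labels):
    loss, grads = jax.value_and_grad(loss_fn)(params, features, labels)
    grads = clip_by_global_norm(grads, MAX_GRAD_NORM)
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return params, loss


# Toy, linearly separable data: the label is the largest of the first 3 features.
features = jax.random.normal(jax.random.PRNGKey(0), (64, 8))
labels = jnp.argmax(features[:, :3], axis=-1)
params = {"w": jnp.zeros((8, 3)), "b": jnp.zeros(3)}

losses = []
for _ in range(50):
    params, loss = train_step(params, features, labels)
    losses.append(float(loss))
```

Plain gradient descent is used here to keep the sketch dependency-free; the requirements list Optax, so the library may well use an Optax optimizer instead.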

Architecture

DiffBio sits on a layered ecosystem rather than standing alone:

Layer                   | Library  | Role in DiffBio
Execution contracts     | Datarax  | Operator, data-source, and pipeline contracts
Modeling substrate      | Artifex  | Reusable transformer and generative-model components
Scientific ML substrate | Opifex   | Scientific optimization, operator learning, and advanced training methods
Evaluation substrate    | Calibrax | Metrics, benchmarking, comparison, profiling, and regression checks
Biology-specific layer  | DiffBio  | Differentiable biological operators and domain compositions

Each DiffBio operator inherits from Datarax's OperatorModule and implements:

apply(data, state, metadata) -> (output_data, output_state, output_metadata)

This enables:

  • Composition: Chain operators into pipelines
  • Batch processing: Automatic vectorization via apply_batch()
  • Gradient flow: End-to-end differentiability through the pipeline
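The contract can be illustrated without Datarax. The toy class below follows the same `apply(data, state, metadata)` shape and derives a batched version with `jax.vmap`. It is a plain Python class, not a subclass of `OperatorModule`, and its `apply_batch` signature is an assumption rather than the Datarax one.

```python
import jax
import jax.numpy as jnp


class ScaleOperator:
    """Toy operator that mimics the apply() contract (not a Datarax class)."""

    def __init__(self, scale):
        self.scale = scale

    def apply(self, data, state, metadata):
        # Return a new data dict; pass state and metadata through untouched.
        out = {**data, "values": data["values"] * self.scale}
        return out, state, metadata

    def apply_batch(self, batch):
        # Vectorize apply() over the leading axis of every array in the dict.
        return jax.vmap(lambda d: self.apply(d, {}, None)[0])(batch)


op = ScaleOperator(scale=2.0)

single, _, _ = op.apply({"values": jnp.array([1.0, 2.0])}, {}, None)
batched = op.apply_batch({"values": jnp.ones((3, 2))})


def loss(scale):
    out, _, _ = ScaleOperator(scale).apply({"values": jnp.array([1.0, 2.0])}, {}, None)
    return jnp.sum(out["values"] ** 2)


# Gradient of the loss with respect to the operator parameter
scale_grad = jax.grad(loss)(2.0)
```

Because `apply` is a pure function of its inputs and parameters, the same code path serves single examples, batches, and gradient computation.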

Operator Composition

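Operators are chained by threading the (data, state, metadata) triple returned by each apply() call into the next operator. The snippet assumes quality_filter, pileup, and classifier have already been constructed: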

data, state, metadata = quality_filter.apply(batch_data, {}, None)
data, state, metadata = pileup.apply(data, state, metadata)
data, state, metadata = classifier.apply(data, state, metadata)

# `data` is a dict of JAX arrays — read out the per-position predictions
predictions = data["logits"]
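The same threading pattern can be wrapped in a small helper. The sketch below uses `functools.reduce` and two toy operators to show that the triple, and gradients, pass through the whole chain. `chain`, `Shift`, and `Square` are hypothetical names introduced for this example and are not part of DiffBio or Datarax.

```python
from functools import reduce

import jax
import jax.numpy as jnp


def chain(operators, data, state=None, metadata=None):
    """Thread (data, state, metadata) through each operator's apply() in order."""
    initial = (data, {} if state is None else state, metadata)
    return reduce(lambda triple, op: op.apply(*triple), operators, initial)


class Shift:
    def __init__(self, offset):
        self.offset = offset

    def apply(self, data, state, metadata):
        new_state = {**state, "steps": state.get("steps", 0) + 1}
        return {**data, "x": data["x"] + self.offset}, new_state, metadata


class Square:
    def apply(self, data, state, metadata):
        new_state = {**state, "steps": state.get("steps", 0) + 1}
        return {**data, "x": data["x"] ** 2}, new_state, metadata


data, state, metadata = chain([Shift(1.0), Square()], {"x": jnp.array([1.0, 2.0])})


def total(offset):
    out, _, _ = chain([Shift(offset), Square()], {"x": jnp.array([1.0, 2.0])})
    return jnp.sum(out["x"])


# Gradient with respect to a parameter of the first operator in the chain
offset_grad = jax.grad(total)(1.0)
```

Here `state` simply counts how many operators ran, which shows that each operator can add to the state it receives without disturbing the data path.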

Testing

# Run all tests
uv run pytest -vv

# Run with coverage
uv run pytest -vv --cov=src/ --cov-report=term-missing

# Run specific test modules
uv run pytest tests/operators/ -vv
uv run pytest tests/pipelines/ -vv
uv run pytest tests/integration/ -vv

Project Structure

DiffBio/
├── src/diffbio/
│   ├── core/               # Base operators, graph utils, soft ops
│   ├── operators/           # 40+ differentiable operators
│   │   ├── alignment/       # Smith-Waterman, profile HMM, soft MSA
│   │   ├── variant/         # Pileup, classifiers, CNV segmentation
│   │   ├── singlecell/      # Clustering, trajectory, velocity, GRN, ...
│   │   ├── drug_discovery/  # Fingerprints, property prediction, ADMET
│   │   ├── epigenomics/     # Peak calling, chromatin state
│   │   ├── normalization/   # VAE normalizer, UMAP, PHATE
│   │   ├── statistical/     # HMM, NB GLM, EM quantification
│   │   ├── multiomics/      # Hi-C, spatial deconvolution
│   │   └── ...              # preprocessing, protein, RNA, assembly, ...
│   ├── pipelines/           # End-to-end pipelines
│   ├── losses/              # Alignment, single-cell, statistical losses
│   ├── sources/             # Data loaders (FASTA, BAM, MolNet, ...)
│   ├── splitters/           # Dataset splitting strategies
│   └── utils/               # Training utilities
├── tests/                   # Unit, integration, and benchmark tests
├── benchmarks/              # Domain benchmarks with training + baselines
└── docs/                    # MkDocs documentation

Requirements

  • Python 3.11+
  • JAX 0.6.1+
  • Flax 0.12+
  • Optax 0.1.4+
  • jaxtyping 0.2.20+
  • Datarax, Artifex, Opifex, and Calibrax (installed automatically from PyPI)

License

MIT License. See LICENSE for details.

Acknowledgments

DiffBio builds on ideas from:

  • SMURF: Differentiable Smith-Waterman for end-to-end MSA learning
  • Datarax: Composable data processing framework
  • Artifex: Generative-model and transformer substrate
  • Opifex: Scientific ML and advanced optimization substrate
  • Calibrax: Benchmarking, comparison, and regression substrate
  • Flax NNX: Neural network library for JAX
