A Benchmark for Categorical-State Schrödinger Bridges and Entropic Optimal Transport

Entering the Era of Discrete Diffusion Models: A Benchmark for Schrödinger Bridges and Entropic Optimal Transport

Xavier Aramayo, Grigoriy Ksenofontov, Aleksei Leonov, Iaroslav Koshelev, Alexander Korotin


This repository contains the official implementation of the paper "Entering the Era of Discrete Diffusion Models: A Benchmark for Schrödinger Bridges and Entropic Optimal Transport", accepted at ICLR 2026.

📌 TL;DR

This paper proposes a benchmark for entropic optimal transport (EOT) and Schrödinger Bridge (SB) methods on discrete spaces, and adapts several continuous EOT/SB approaches to the discrete setting.
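As background, EOT between two discrete marginals can be solved exactly (at small scale) with Sinkhorn iterations, which is what makes ground-truth couplings tractable on discrete spaces. The sketch below is purely illustrative, using a toy cost matrix and marginals that are not part of catsbench or the paper:

```python
import numpy as np

def sinkhorn(p0, p1, cost, eps, iters=500):
    """Entropy-regularized OT coupling between discrete marginals p0 and p1.

    Standard Sinkhorn iterations on the Gibbs kernel K = exp(-cost / eps);
    the returned matrix pi has row sums p0 and column sums p1.
    """
    K = np.exp(-cost / eps)
    u = np.ones_like(p0)
    for _ in range(iters):
        v = p1 / (K.T @ u)  # rescale to match the column marginal
        u = p0 / (K @ v)    # rescale to match the row marginal
    return u[:, None] * K * v[None, :]

p0 = np.array([0.5, 0.5])
p1 = np.array([0.3, 0.7])
cost = np.array([[0.0, 1.0],
                 [1.0, 0.0]])
pi = sinkhorn(p0, p1, cost, eps=0.1)
print(pi.sum(axis=1), pi.sum(axis=0))  # both ≈ the target marginals
```

Smaller `eps` concentrates the coupling toward the unregularized OT plan; larger `eps` pushes it toward the independent coupling $p_0 p_1^\top$.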

📦 CatSBench (Package)

catsbench is the standalone benchmark package. It provides benchmark definitions, evaluation metrics, and reusable utilities, including a Triton-optimized log-sum-exp (LSE) matmul kernel.
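For reference, the operation such a kernel accelerates is log-space matrix multiplication, $C_{ij} = \mathrm{logsumexp}_k (A_{ik} + B_{kj})$, the numerically stable way to multiply probability matrices stored in log space. A plain NumPy sketch of the math (illustrative only; the packaged Triton kernel's actual API may differ):

```python
import numpy as np

def lse_matmul(log_a, log_b):
    """Log-space matmul: C[i, j] = logsumexp_k(log_a[i, k] + log_b[k, j]).

    Equivalent to log(exp(log_a) @ exp(log_b)), but stable when the
    linear-space values would overflow or underflow.
    """
    # Broadcast to shape [I, K, J], then reduce over the shared axis K.
    s = log_a[:, :, None] + log_b[None, :, :]
    m = s.max(axis=1, keepdims=True)  # subtract the max for stability
    return (m + np.log(np.exp(s - m).sum(axis=1, keepdims=True))).squeeze(1)
```

This materializes the full `[I, K, J]` tensor, which is exactly the memory blow-up a fused Triton kernel avoids.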

📥 Installation

Install the benchmark package via pip:

pip install catsbench

🚀 Quickstart

Load a benchmark definition and its assets from a pretrained repository:

from catsbench import BenchmarkHD

bench = BenchmarkHD.from_pretrained(
    "gregkseno/catsbench",
    "hd_d2_s50_gaussian_a0.02_gaussian",
    init_benchmark=False,  # skip heavy initialization at load time
)

To sample marginals $p_0$ and $p_1$:

x_start = bench.sample_input(32) # [B=32, D=2]
x_end = bench.sample_target(32)  # [B=32, D=2]

[!IMPORTANT] This samples independently from the marginals, i.e., $(x_0, x_1) \sim p_0(x_0)p_1(x_1)$.

To sample from the ground-truth EOT/SB coupling, i.e., $(x_0, x_1) \sim p_0(x_0) p^*(x_1 | x_0)$, use:

x_start, x_end = bench.sample_input_target(32) # ([B=32, D=2], [B=32, D=2])

Or sample them separately:

x_start = bench.sample_input(32) # [B=32, D=2]
x_end = bench.sample(x_start)    # [B=32, D=2]
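The difference between the independent and coupled sampling modes can be seen on a toy joint distribution (the coupling matrix below is hypothetical and not taken from any benchmark):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical joint coupling pi(x0, x1) over two binary states
# (rows index x0, columns index x1); mass concentrates on x1 == x0.
pi = np.array([[0.45, 0.05],
               [0.05, 0.45]])
p0, p1 = pi.sum(axis=1), pi.sum(axis=0)

n = 10_000
# Independent sampling, (x0, x1) ~ p0(x0) p1(x1):
# analogous to calling sample_input and sample_target separately.
x0_ind = rng.choice(2, size=n, p=p0)
x1_ind = rng.choice(2, size=n, p=p1)

# Coupled sampling, x1 ~ pi(x1 | x0): analogous to sample_input_target.
x0 = rng.choice(2, size=n, p=p0)
cond = pi / pi.sum(axis=1, keepdims=True)  # row-normalize: pi(x1 | x0)
x1 = np.array([rng.choice(2, p=cond[i]) for i in x0])

corr_ind = np.corrcoef(x0_ind, x1_ind)[0, 1]   # near zero
corr_coup = np.corrcoef(x0, x1)[0, 1]          # strongly positive
print(corr_ind, corr_coup)
```

Both modes produce the same marginals; only the coupled mode carries the dependence structure that EOT/SB methods are evaluated against.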

[!NOTE] See the end-to-end benchmark workflow (initialization, evaluation, metrics, plotting) in notebooks/benchmark_usage.ipynb.


Reproducing Experiments

This section describes how to run the full training and evaluation pipeline to reproduce the paper's results. It explains how to launch experiments for the provided methods (DLightSB, DLightSB-M, CSBM, $\alpha$-CSBM) and how to evaluate them on the benchmarks.

|-- configs
|   |-- config.yaml   # main Hydra entrypoint
|   |-- callbacks     # Lightning callbacks: benchmark metrics + visualization
|   |-- data          # datamodule/dataset configs
|   |-- experiment    # experiment presets (override bundles)
|   |-- hydra         # Hydra runtime/output settings
|   |-- logger        # logging backends (Comet, W&B, TensorBoard)
|   |-- method        # method-level configs (e.g., CSBM, DLightSB)
|   |-- model         # model architecture configs
|   |-- prior         # reference process configs
|   `-- trainer       # trainer, hardware, precision, runtime configs
|-- logs              # logs, checkpoints, and run artifacts
|-- notebooks         # analysis and baselines
|-- scripts           # bash (+ SLURM) launch scripts
`-- src
    |-- catsbench     # benchmark package code
    |-- data          # Lightning datamodules + reference process implementation
    |-- methods       # training/inference methods (e.g., CSBM, DLightSB)
    |-- metrics       # callbacks computing benchmark metrics
    |-- plotter       # callbacks for plotting samples and trajectories
    |-- utils         # instantiation, logging, common helpers
    `-- run.py        # main entrypoint for training and testing

📦 Dependencies

Create the conda environment using the following command:

conda env update -f environment.yml

and activate it:

conda activate catsbench

🏋️ Training

To start training, pick an experiment config under configs/experiment/<method_name>/benchmark_hd/<exp_name>.yaml and launch it with:

python -m src.run experiment=<method_name>/benchmark_hd/<exp_name>

Example:

python -m src.run experiment=dlight_sb/benchmark_hd/d2_g002

📊 Evaluation

Use the same experiment config as in training and set a checkpoint:

  • Manual path: logs/runs/<method_name>/benchmark_hd/<exp_name>/<seed>/<date>/epoch_<...>.ckpt
  • Automatic: set ckpt_path=auto to load the latest checkpoint matching the config.

Then run:

python -m src.run task_name=test ckpt_path=auto \
  experiment=<method_name>/benchmark_hd/<exp_name>

Example:

python -m src.run task_name=test ckpt_path=auto \
  experiment=dlight_sb/benchmark_hd/d2_g002

🎓 Citation

@inproceedings{carrasco2026entering,
  title={Entering the Era of Discrete Diffusion Models: A Benchmark for Schr\"odinger Bridges and Entropic Optimal Transport},
  author={Xavier Aramayo Carrasco and Grigoriy Ksenofontov and Aleksei Leonov and Iaroslav Sergeevich Koshelev and Alexander Korotin},
  booktitle={The Fourteenth International Conference on Learning Representations},
  year={2026},
  url={https://openreview.net/forum?id=XcPDT615Gd}
}

🙏 Credits

