A Benchmark for Categorical-State Schrödinger Bridges and Entropic Optimal Transport
Entering the Era of Discrete Diffusion Models: A Benchmark for Schrödinger Bridges and Entropic Optimal Transport
Xavier Aramayo, Grigoriy Ksenofontov, Aleksei Leonov, Iaroslav Koshelev, Alexander Korotin
This repository contains the official implementation of the paper "Entering the Era of Discrete Diffusion Models: A Benchmark for Schrödinger Bridges and Entropic Optimal Transport", accepted at ICLR 2026.
📌 TL;DR
This paper proposes a benchmark for entropic optimal transport (EOT) and Schrödinger Bridge (SB) methods on discrete spaces, and adapts several continuous EOT/SB approaches to the discrete setting.
📦 CatSBench (Package)
catsbench is the standalone benchmark package. It provides benchmark definitions, evaluation metrics, and reusable utilities, including a Triton-optimized log-sum-exp (LSE) matmul kernel.
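For readers unfamiliar with the operation, the following is a minimal NumPy reference of what a log-sum-exp matmul computes, assuming the usual definition `out[i, j] = log(sum_k exp(a[i, k] + b[k, j]))`. The function name `lse_matmul` is illustrative and is not taken from the catsbench API; the packaged Triton kernel may differ in signature and semantics.

```python
import numpy as np


def lse_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Reference log-space matmul: out[i, j] = log(sum_k exp(a[i, k] + b[k, j])).

    Illustrative only (not the catsbench kernel); assumes finite inputs.
    """
    # Broadcast to shape [M, K, N]. A fused kernel would typically avoid
    # materializing this intermediate tensor.
    s = a[:, :, None] + b[None, :, :]
    # Subtract the per-(i, j) maximum before exponentiating for stability.
    m = s.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(s - m).sum(axis=1, keepdims=True))).squeeze(axis=1)


rng = np.random.default_rng(0)
log_a = rng.normal(size=(4, 5))
log_b = rng.normal(size=(5, 3))
out = lse_matmul(log_a, log_b)

# For moderate values this agrees with the naive formula log(exp(A) @ exp(B)).
naive = np.log(np.exp(log_a) @ np.exp(log_b))

# For large values the naive formula overflows, the stabilized one does not.
big = lse_matmul(log_a + 1000.0, log_b)
```

Working in log space this way keeps products of very small or very large probabilities representable, which is presumably why the package ships a dedicated kernel for it.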
📥 Installation
Install the benchmark package via pip:
pip install catsbench
🚀 Quickstart
Load a benchmark definition and its assets from a pretrained repository:
from catsbench import BenchmarkHD
bench = BenchmarkHD.from_pretrained(
    "gregkseno/catsbench",
    "hd_d2_s50_gaussian_a0.02_gaussian",
    init_benchmark=False,  # skip heavy initialization at load time
)
To sample from the marginals $p_0$ and $p_1$:
x_start = bench.sample_input(32) # [B=32, D=2]
x_end = bench.sample_target(32) # [B=32, D=2]
[!IMPORTANT] This samples independently from the marginals, i.e., $(x_0, x_1) \sim p_0(x_0)\,p_1(x_1)$.
To sample from the ground-truth EOT/SB coupling, i.e., $(x_0, x_1) \sim p_0(x_0)\,q^*(x_1 \mid x_0)$, use:
x_start, x_end = bench.sample_input_target(32) # ([B=32, D=2], [B=32, D=2])
Or sample them separately:
x_start = bench.sample_input(32) # [B=32, D=2]
x_end = bench.sample(x_start) # [B=32, D=2]
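The difference between independent and coupled sampling can be illustrated on a toy problem that does not use catsbench at all. The sketch below computes a textbook entropic OT plan with Sinkhorn iterations on a five-state space, then samples $x_1$ from the conditional of that plan. The cost matrix, the value of `eps`, and the helper `sinkhorn_plan` are assumptions chosen for illustration; the benchmark's own ground-truth coupling may be constructed differently.

```python
import numpy as np


def sinkhorn_plan(p0, p1, cost, eps, n_iters=500):
    """Entropic OT plan between two discrete marginals via Sinkhorn iterations."""
    kernel = np.exp(-cost / eps)  # Gibbs kernel
    u = np.ones_like(p0)
    v = np.ones_like(p1)
    for _ in range(n_iters):
        u = p0 / (kernel @ v)    # rescale rows to match p0
        v = p1 / (kernel.T @ u)  # rescale columns to match p1
    return u[:, None] * kernel * v[None, :]


rng = np.random.default_rng(0)
states = np.arange(5)
p0 = np.array([0.4, 0.3, 0.15, 0.1, 0.05])
p1 = p0[::-1].copy()
cost = (states[:, None] - states[None, :]) ** 2.0  # toy squared-distance cost
plan = sinkhorn_plan(p0, p1, cost, eps=4.0)

# Coupled sampling: x0 ~ p0, then x1 ~ q(. | x0), the row-normalized plan.
cond = plan / plan.sum(axis=1, keepdims=True)
x0 = rng.choice(states, size=32, p=p0)
x1_coupled = np.array([rng.choice(states, p=cond[i]) for i in x0])

# Independent sampling ignores x0 entirely.
x1_indep = rng.choice(states, size=32, p=p1)

# Expected transport cost under the entropic plan vs. the product coupling.
cost_plan = float((plan * cost).sum())
cost_indep = float((np.outer(p0, p1) * cost).sum())
```

Both sampling schemes produce the same marginals, but only the coupled pairs carry the dependence between $x_0$ and $x_1$, which is why the expected cost under the plan is no larger than under the product coupling.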
[!NOTE] See the end-to-end benchmark workflow (initialization, evaluation, metrics, plotting) in notebooks/benchmark_usage.ipynb.
Reproducing Experiments
This section describes how to run the full training and evaluation pipeline to reproduce the paper's results. It explains how to launch experiments for the provided methods (DLightSB, DLightSB-M, CSBM, $\alpha$-CSBM) and evaluate them on the benchmarks.
|-- configs
|   |-- config.yaml    # main Hydra entrypoint
|   |-- callbacks      # Lightning callbacks: benchmark metrics + visualization
|   |-- data           # datamodule/dataset configs
|   |-- experiment     # experiment presets (override bundles)
|   |-- hydra          # Hydra runtime/output settings
|   |-- logger         # logging backends (Comet, W&B, TensorBoard)
|   |-- method         # method-level configs (e.g., CSBM, DLightSB)
|   |-- model          # model architecture configs
|   |-- prior          # reference process configs
|   `-- trainer        # trainer, hardware, precision, runtime configs
|-- logs               # logs, checkpoints, and run artifacts
|-- notebooks          # analysis and baselines
|-- scripts            # bash (+ SLURM) launch scripts
`-- src
    |-- catsbench      # benchmark package code
    |-- data           # Lightning datamodules + reference process implementation
    |-- methods        # training/inference methods (e.g., CSBM, DLightSB)
    |-- metrics        # callbacks computing benchmark metrics
    |-- plotter        # callbacks for plotting samples and trajectories
    |-- utils          # instantiation, logging, common helpers
    `-- run.py         # main entrypoint for training and testing
📦 Dependencies
Create the Anaconda environment using the following command:
conda env update -f environment.yml
and activate it:
conda activate catsbench
🏋️ Training
To start training, pick an experiment config under configs/experiment/<method_name>/benchmark_hd/<exp_name>.yaml and launch it with:
python -m src.run experiment=<method_name>/benchmark_hd/<exp_name>
Example:
python -m src.run experiment=dlight_sb/benchmark_hd/d2_g002
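When launching runs from a script rather than a shell, the command above can be assembled programmatically. The helper below is a small sketch: `build_run_command` is a hypothetical name, and it only formats Hydra-style `key=value` overrides around the `experiment=<method_name>/benchmark_hd/<exp_name>` pattern shown above. It does not validate that a given experiment config exists.

```python
import shlex
import sys


def build_run_command(method, exp_name, **overrides):
    """Build the argv list for `python -m src.run` with key=value overrides.

    Hypothetical helper, not part of the repository.
    """
    cmd = [sys.executable, "-m", "src.run"]
    # Extra overrides are passed through verbatim as key=value tokens.
    cmd += [f"{key}={value}" for key, value in overrides.items()]
    cmd.append(f"experiment={method}/benchmark_hd/{exp_name}")
    return cmd


cmd = build_run_command("dlight_sb", "d2_g002")
print(shlex.join(cmd))
# To launch from the repository root:
#   import subprocess; subprocess.run(cmd, check=True)
```

Building the argument list instead of a single string avoids shell quoting issues when the command is passed to `subprocess.run`.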
📊 Evaluation
Use the same experiment config as in training and set a checkpoint:
- Manual path: logs/runs/<method_name>/benchmark_hd/<exp_name>/<seed>/<date>/epoch_<...>.ckpt
- Or set ckpt_path=auto to automatically load the latest checkpoint based on the config.
python -m src.run task_name=test ckpt_path=auto \
experiment=<method_name>/benchmark_hd/<exp_filename>
Example:
python -m src.run task_name=test ckpt_path=auto \
experiment=dlight_sb/benchmark_hd/d2_g002
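If you prefer to pass a manual checkpoint path, a short script can locate a candidate file under the run directory. The sketch below picks the most recently modified `epoch_*.ckpt`, assuming the directory layout from the manual-path pattern above. How `ckpt_path=auto` selects a checkpoint internally is not described here, so `latest_checkpoint` is an illustrative approximation and not the repository's logic.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(run_root: Path) -> Optional[Path]:
    """Return the most recently modified epoch_*.ckpt under run_root, or None.

    Illustrative helper; selection by modification time is an assumption.
    """
    if not run_root.is_dir():
        return None
    # Search all <seed>/<date> subdirectories recursively.
    candidates = list(run_root.rglob("epoch_*.ckpt"))
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


run_root = Path("logs/runs") / "dlight_sb" / "benchmark_hd" / "d2_g002"
ckpt = latest_checkpoint(run_root)
print(ckpt if ckpt is not None else f"no checkpoints under {run_root}")
```

The resulting path can then be supplied as `ckpt_path=<path>` in place of `ckpt_path=auto`.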
🎓 Citation
@inproceedings{
carrasco2026entering,
title={Entering the Era of Discrete Diffusion Models: A Benchmark for Schr\"odinger Bridges and Entropic Optimal Transport},
author={Xavier Aramayo Carrasco and Grigoriy Ksenofontov and Aleksei Leonov and Iaroslav Sergeevich Koshelev and Alexander Korotin},
booktitle={The Fourteenth International Conference on Learning Representations},
year={2026},
url={https://openreview.net/forum?id=XcPDT615Gd}
}
🙏 Credits
- Comet ML: experiment-tracking and visualization toolkit;
- Inkscape: an excellent open-source editor for vector graphics;
- Hydra/Lightning template: project template used as a starting point.