Components and algorithms for energy-based models

Project description

⚡ A PyTorch library for energy-based modeling, with support for flow and diffusion methods.

What is ∇ TorchEBM 🍓?

Energy-based models define distributions through a scalar energy function, where lower energy means higher probability. This formulation is very general, and many generative approaches, from MCMC sampling to score matching to flow-based generation, can be understood through this lens.
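In the usual Gibbs form, an energy E_θ(x) defines a normalized density through a normalizing constant Z_θ that is generally intractable. The score (the gradient of the log-density) does not depend on Z_θ, which is why score matching and gradient-based samplers can work with the energy alone. The notation below is ours, not taken from the TorchEBM documentation:

```latex
p_\theta(x) = \frac{\exp\!\big(-E_\theta(x)\big)}{Z_\theta},
\qquad
Z_\theta = \int \exp\!\big(-E_\theta(x)\big)\,dx,
\qquad
\nabla_x \log p_\theta(x) = -\nabla_x E_\theta(x).
```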

TorchEBM is a PyTorch library that gives you composable tools for this entire spectrum. You can define energy landscapes, train models with various learning objectives, and sample via MCMC, optimization, or learned continuous-time dynamics (ODEs/SDEs). The library handles classical EBM training (contrastive divergence, score matching) as well as modern interpolant-based and equilibrium-based generation methods.
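Score matching sidesteps the normalizing constant by fitting the model score, minus the gradient of the energy, directly. The sketch below is a hedged, pure-Python illustration of denoising score matching (one common variant, not necessarily the exact one TorchEBM implements). It fits a one-parameter quadratic energy to standard normal data. The function name `fit_dsm_quadratic` and all of its defaults are hypothetical, and no TorchEBM API is used:

```python
import random


def fit_dsm_quadratic(n=100_000, sigma=0.5, seed=0):
    """Fit E_theta(x) = theta * x**2 / 2 to N(0, 1) data by denoising score matching.

    The model score is s_theta(x) = -dE/dx = -theta * x. DSM regresses it onto
    the score of the Gaussian corruption kernel, -(x_noisy - x) / sigma**2.
    Because s_theta is linear in theta, the empirical DSM loss
        mean((s_theta(x_noisy) - target) ** 2)
    has a closed-form least-squares minimizer, accumulated below.
    """
    rng = random.Random(seed)
    num = 0.0
    den = 0.0
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)             # clean data sample
        eps = rng.gauss(0.0, 1.0)           # corruption noise
        x_noisy = x + sigma * eps
        target = -(x_noisy - x) / sigma**2  # equals -eps / sigma
        # Normal equation of sum((-theta * x_noisy - target) ** 2) in theta
        num += -x_noisy * target
        den += x_noisy * x_noisy
    return num / den


print(round(fit_dsm_quadratic(), 3))
```

With `sigma = 0.5` the fitted `theta` should come out close to 1 / (1 + sigma**2) = 0.8 rather than 1: DSM recovers the score of the noise-smoothed distribution N(0, 1 + sigma**2), so small noise levels are needed to approximate the clean-data score.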

📚 For the full documentation, please visit the official website of TorchEBM 🍓.

Features

  • Energy models with built-in analytical potentials and support for custom neural network energy functions
  • MCMC and optimization-based samplers for drawing samples from energy landscapes
  • Flow and diffusion samplers that generate via ODE/SDE integration of learned velocity or score fields
  • Training objectives including contrastive divergence variants, score matching variants, and equilibrium matching
  • Interpolation schemes for specifying noise-to-data paths in flow and diffusion models
  • Numerical integrators for SDE, ODE, and Hamiltonian dynamics
  • Neural network architectures ready for conditional generation
  • Synthetic datasets for rapid prototyping and benchmarking
  • Hyperparameter schedulers for step sizes, noise scales, and other training parameters
  • CUDA acceleration and mixed precision support
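The flow samplers listed above generate data by integrating a velocity field from noise at t = 0 to data at t = 1. The minimal sketch below assumes a linear noise-to-data path and uses the exact, closed-form velocity for a 1D Gaussian target in place of the learned network a real model would provide. `gaussian_path_velocity` and `euler_sample` are hypothetical names, not TorchEBM API:

```python
def gaussian_path_velocity(x, t, mu=3.0, s=0.5):
    """Exact velocity E[x1 - x0 | x_t = x] for the linear path
    x_t = (1 - t) * x0 + t * x1, with independent x0 ~ N(0, 1) (noise)
    and x1 ~ N(mu, s**2) (a stand-in "data" distribution)."""
    var_t = (1.0 - t) ** 2 + (t * s) ** 2  # Var(x_t)
    return mu + (t * s * s - (1.0 - t)) / var_t * (x - t * mu)


def euler_sample(x0, n_steps=200, velocity=gaussian_path_velocity):
    """Integrate dx/dt = velocity(x, t) from t = 0 (noise) to t = 1 (data)
    with fixed-step forward Euler."""
    x, dt = x0, 1.0 / n_steps
    for i in range(n_steps):
        x = x + dt * velocity(x, i * dt)
    return x


# Push a few noise values through the flow; the exact map is x0 -> mu + s * x0.
for z in (-2.0, 0.0, 1.0):
    print(z, round(euler_sample(z), 3))
```

For this path the exact ODE sends `x0` to `mu + s * x0`, so `euler_sample(1.0)` should land close to 3.5. Lowering `n_steps` makes the Euler discretization error visible, which is the trade-off an ODE sampler's step count controls.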

Figures: 8 Gaussians flow; Gaussian, Double Well, Rastrigin, and Rosenbrock energy landscapes; Gaussian Mixture, Two Moons, Swiss Roll, and Checkerboard datasets.

Installation

pip install torchebm

Dependencies

Usage Examples

MCMC Sampling

import torch
from torchebm.core import GaussianModel
from torchebm.samplers import LangevinDynamics

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Analytical energy model: a standard 2D Gaussian
model = GaussianModel(mean=torch.zeros(2), cov=torch.eye(2), device=device)

sampler = LangevinDynamics(model=model, step_size=0.01, device=device)
# Run 100 Langevin steps starting from 500 random initial points
samples = sampler.sample(x=torch.randn(500, 2, device=device), n_steps=100)
print(samples.shape)  # torch.Size([500, 2])
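Langevin sampling can be read as noisy gradient descent on the energy. The pure-Python sketch below runs the unadjusted Langevin update on the same standard 2D Gaussian. It assumes the textbook parameterization with noise scale sqrt(2 * step_size); TorchEBM's `LangevinDynamics` may scale its noise differently, and `grad_energy`, `langevin_step`, and `langevin_sample` are hypothetical names:

```python
import math
import random


def grad_energy(x):
    # Standard Gaussian energy E(x) = 0.5 * sum(x_i**2), so grad E(x) = x
    return list(x)


def langevin_step(x, step_size, rng):
    # Unadjusted Langevin: x - step_size * grad E(x) + sqrt(2 * step_size) * N(0, I)
    noise_scale = math.sqrt(2.0 * step_size)
    g = grad_energy(x)
    return [xi - step_size * gi + noise_scale * rng.gauss(0.0, 1.0)
            for xi, gi in zip(x, g)]


def langevin_sample(n_chains=1000, n_steps=600, step_size=0.05, seed=0):
    rng = random.Random(seed)
    # Start every chain far from the mode to show the drift toward low energy
    chains = [[5.0, -5.0] for _ in range(n_chains)]
    for _ in range(n_steps):
        chains = [langevin_step(c, step_size, rng) for c in chains]
    return chains


samples = langevin_sample()
for d in range(2):
    vals = [s[d] for s in samples]
    mean = sum(vals) / len(vals)
    var = sum((u - mean) ** 2 for u in vals) / len(vals)
    print(f"dim {d}: mean={mean:.3f} var={var:.3f}")
```

For this Gaussian the chain's stationary variance is 2 / (2 - step_size), about 1.026 here rather than exactly 1. This discretization bias is what a Metropolis correction, as in HMC, removes.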

Training with Contrastive Divergence

import torch
from torchebm.core import BaseModel
from torchebm.samplers import LangevinDynamics
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import GaussianMixtureDataset
from torch.utils.data import DataLoader

class MLPEnergy(BaseModel):
    """Neural energy function: maps each input point to a scalar energy."""

    def __init__(self, dim):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, 64), torch.nn.SiLU(),
            torch.nn.Linear(64, 64), torch.nn.SiLU(),
            torch.nn.Linear(64, 1),
        )

    def forward(self, x):
        # (batch, dim) -> (batch,)
        return self.net(x).squeeze(-1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MLPEnergy(dim=2).to(device)
sampler = LangevinDynamics(model=model, step_size=0.01, device=device)
# Negative samples come from k_steps of the Langevin sampler per loss call
cd_loss = ContrastiveDivergence(model=model, sampler=sampler, k_steps=10)

# Training data: a 4-component Gaussian mixture
data = GaussianMixtureDataset(n_samples=1000, n_components=4).get_data()
loader = DataLoader(data, batch_size=64, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    for batch in loader:
        optimizer.zero_grad()
        # The second return value is not needed here
        loss, _ = cd_loss(batch.to(device))
        loss.backward()
        optimizer.step()
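The core of contrastive divergence is the gradient "energy derivative on data minus energy derivative on sampler output". The toy, pure-Python sketch below shows it on a one-parameter quadratic energy whose derivatives can be written by hand. It initializes the negative chains at the data batch, following the classic CD-k recipe; TorchEBM's `ContrastiveDivergence` may initialize or persist chains differently, and `train_cd_quadratic` with its defaults is hypothetical:

```python
import math
import random


def train_cd_quadratic(n_iters=300, batch_size=200, k_steps=10,
                       step_size=0.01, lr=5.0, data_std=0.5, seed=0):
    """CD-k on a 1D energy E_theta(x) = theta * x**2 / 2, where theta is a precision.

    dE/dtheta = x**2 / 2, so the CD-k gradient is
        mean(x_pos**2) / 2 - mean(x_neg**2) / 2,
    with negatives treated as constants (no gradient through the sampler).
    """
    rng = random.Random(seed)
    theta = 1.0  # start too broad: the data precision is 1 / data_std**2 = 4
    noise_scale = math.sqrt(2.0 * step_size)
    for _ in range(n_iters):
        pos = [rng.gauss(0.0, data_std) for _ in range(batch_size)]
        # CD-k: k Langevin steps on the current energy, initialized at the data
        neg = list(pos)
        for _ in range(k_steps):
            neg = [x - step_size * theta * x + noise_scale * rng.gauss(0.0, 1.0)
                   for x in neg]
        grad = (sum(x * x for x in pos) - sum(x * x for x in neg)) / (2.0 * batch_size)
        theta -= lr * grad  # gradient descent on the CD objective
    return theta


print(round(train_cd_quadratic(), 2))
```

`theta` should end up near the data precision of 4, up to sampling noise and a small bias from the discretized, short Langevin chains. If negatives are more spread out than the data, the gradient is negative and `theta` grows, which sharpens the energy.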

Hamiltonian Monte Carlo

import torch
from torchebm.core import GaussianModel
from torchebm.samplers import HamiltonianMonteCarlo

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Standard 10-dimensional Gaussian target
model = GaussianModel(mean=torch.zeros(10), cov=torch.eye(10), device=device)

# Each HMC proposal uses 10 leapfrog steps of size 0.1
hmc = HamiltonianMonteCarlo(model=model, step_size=0.1, n_leapfrog_steps=10, device=device)
samples = hmc.sample(dim=10, n_steps=500, n_samples=1000)
print(samples.shape)  # torch.Size([1000, 10])
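The `step_size` and `n_leapfrog_steps` arguments correspond to the two ingredients of HMC: a leapfrog integrator for Hamiltonian dynamics and a Metropolis accept/reject step. The sketch below is the textbook algorithm with an identity mass matrix on a standard 2D Gaussian (2 dimensions instead of 10 keeps it cheap in pure Python). It is not TorchEBM's implementation, and `hmc_step` and `run_hmc` are hypothetical names:

```python
import math
import random


def energy(x):
    # Standard Gaussian energy E(x) = 0.5 * ||x||^2
    return 0.5 * sum(xi * xi for xi in x)


def grad_energy(x):
    return list(x)


def hmc_step(x, step_size, n_leapfrog, rng):
    """One HMC transition: resample momentum, run leapfrog, Metropolis accept/reject."""
    p0 = [rng.gauss(0.0, 1.0) for _ in x]
    q = list(x)
    # Leapfrog: half momentum step, alternating full steps, final half momentum step
    p = [pi - 0.5 * step_size * gi for pi, gi in zip(p0, grad_energy(q))]
    for i in range(n_leapfrog):
        q = [qi + step_size * pi for qi, pi in zip(q, p)]
        scale = step_size if i < n_leapfrog - 1 else 0.5 * step_size
        p = [pi - scale * gi for pi, gi in zip(p, grad_energy(q))]
    # Hamiltonian H = E(q) + ||p||^2 / 2 before and after the trajectory
    h_old = energy(x) + 0.5 * sum(pi * pi for pi in p0)
    h_new = energy(q) + 0.5 * sum(pi * pi for pi in p)
    if rng.random() < math.exp(min(0.0, h_old - h_new)):
        return q, True
    return list(x), False


def run_hmc(dim=2, n_steps=3000, burn_in=500, step_size=0.1, n_leapfrog=10, seed=0):
    rng = random.Random(seed)
    x = [3.0] * dim
    samples, n_accepted = [], 0
    for t in range(n_steps):
        x, accepted = hmc_step(x, step_size, n_leapfrog, rng)
        n_accepted += accepted
        if t >= burn_in:
            samples.append(x)
    return samples, n_accepted / n_steps


samples, acceptance = run_hmc()
print(len(samples), round(acceptance, 3))
```

Because the leapfrog integrator nearly conserves H for small step sizes, the acceptance rate should be close to 1 here. Larger steps or longer trajectories trade acceptance for faster exploration.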

Library Structure

torchebm/
├── core/           # Base classes, energy models, schedulers, device management
├── samplers/       # MCMC, optimization, and flow/diffusion samplers
├── losses/         # Training objectives (CD, score matching, equilibrium matching)
├── interpolants/   # Noise-to-data interpolation schemes
├── integrators/    # Numerical integrators for SDE/ODE/Hamiltonian dynamics
├── models/         # Neural network architectures
├── datasets/       # Synthetic data generators
├── utils/          # Visualization and training utilities
└── cuda/           # CUDA-accelerated implementations

Visualization Examples

Figures: Langevin dynamics sampling, Langevin dynamics trajectory, and parallel sampling.

Flow Comparison
Equilibrium Matching: Linear, VP, and Cosine interpolants transforming noise into data.
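An interpolant fixes how each point moves from noise to data over t in [0, 1], and its time derivative gives the velocity target a flow model regresses onto. The hedged sketch below uses the form x_t = alpha(t) * x1 + sigma(t) * x0, with noise x0 at t = 0 and data x1 at t = 1. This convention, and the function names `linear_coeffs`, `cosine_coeffs`, and `interpolate`, are assumptions for illustration; TorchEBM's interpolants may use a different time direction or parameterization, and the VP schedule is omitted because its exact form is not given here:

```python
import math


def linear_coeffs(t):
    # Linear path: x_t = t * x1 + (1 - t) * x0
    return t, 1.0 - t, 1.0, -1.0  # alpha, sigma, d_alpha/dt, d_sigma/dt


def cosine_coeffs(t):
    # Trigonometric path with alpha**2 + sigma**2 = 1, which keeps unit variance
    # when x0 and x1 are independent with unit variance
    half_pi = 0.5 * math.pi
    alpha, sigma = math.sin(half_pi * t), math.cos(half_pi * t)
    return alpha, sigma, half_pi * sigma, -half_pi * alpha


def interpolate(coeffs, x0, x1, t):
    """Return the interpolated point x_t and the velocity target dx_t/dt."""
    alpha, sigma, d_alpha, d_sigma = coeffs(t)
    x_t = alpha * x1 + sigma * x0
    v_t = d_alpha * x1 + d_sigma * x0
    return x_t, v_t


for name, coeffs in (("linear", linear_coeffs), ("cosine", cosine_coeffs)):
    print(name, [round(interpolate(coeffs, 2.0, 6.0, t)[0], 3) for t in (0.0, 0.5, 1.0)])
```

Both paths share endpoints, but they move through intermediate noise levels at different rates, which changes the velocity field a model has to learn.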

Check out the examples/ directory for sample scripts.

Contributing

Contributions are welcome! Step-by-step instructions for contributing to the project can be found on the contributing.md page on the website.

Please check the issues page for current tasks or create a new issue to discuss proposed changes.

Show your Support for ∇ TorchEBM 🍓

Please ⭐️ this repository if ∇ TorchEBM helped you, and spread the word.

Thank you! 🚀

Citation

If TorchEBM is useful in your research, please cite it:

@misc{torchebm_library_2025,
  author       = {Ghaderi, Soran and Contributors},
  title        = {{TorchEBM}: A PyTorch Library for Training Energy-Based Models},
  year         = {2025},
  url          = {https://github.com/soran-ghaderi/torchebm},
}

Changelog

See CHANGELOG for version history.

License

MIT License. See LICENSE for details.

Research Collaboration

If you are interested in collaborating on research around energy-based, flow-based, or diffusion models, feel free to reach out. Contributions to TorchEBM 🍓 and discussions that push the field forward are always welcome.


Download files

Source Distribution

torchebm-0.5.7.tar.gz (31.4 MB)

Uploaded Source

Built Distribution

torchebm-0.5.7-py3-none-any.whl (31.4 MB)

Uploaded Python 3

File details

Details for the file torchebm-0.5.7.tar.gz.

File metadata

  • Download URL: torchebm-0.5.7.tar.gz
  • Size: 31.4 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for torchebm-0.5.7.tar.gz
Algorithm Hash digest
SHA256 400961bb413c0038f51280cebd13bed3050036c2085cf547676d3dce7e8f7c0c
MD5 54db9fbf82b7bcab5cde1a693a8162a3
BLAKE2b-256 17a1acbd10784a2576cce6bf2bd94b7c39d15f764f40aac918912b28356b3ef5

File details

Details for the file torchebm-0.5.7-py3-none-any.whl.

File metadata

  • Download URL: torchebm-0.5.7-py3-none-any.whl
  • Size: 31.4 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for torchebm-0.5.7-py3-none-any.whl
Algorithm Hash digest
SHA256 a22bb07f2aa3a152bbbadbaa63d61af1750493f3dd272de08d80830aeffca21e
MD5 0cc65a2a51a90a2b4f112802ac94e12a
BLAKE2b-256 db9d7ebbd0a2f117416e1b7e3256c852e0efed055d6fb3bbe31aea85ab94a6c3
