Components and algorithms for energy-based models
Energy-Based Modeling library for PyTorch, offering tools for sampling, inference, and learning in complex distributions.
What is TorchEBM?
Energy-Based Models (EBMs) offer a powerful and flexible framework for generative modeling. They assign a scalar energy E(x) to each data point, which defines an unnormalized probability density p(x) ∝ exp(-E(x)). Lower energy corresponds to higher probability.
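To make the energy-probability relationship concrete: normalizing exp(-E(x)) by the partition function Z = ∫ exp(-E(x)) dx gives p(x) = exp(-E(x)) / Z. The toy sketch below uses only the Python standard library. It evaluates a 1-D double-well energy E(x) = (x² - 1)² on a grid, approximates Z with a Riemann sum, and checks that the low-energy points x = ±1 receive the highest density. The energy form and grid are illustrative assumptions, not part of TorchEBM.

```python
import math

def energy(x: float) -> float:
    """Toy 1-D double-well energy (illustrative, not a TorchEBM API); minima at x = -1 and x = +1."""
    return (x * x - 1.0) ** 2

# Uniform grid over [-3, 3]
n_points = 601
xs = [-3.0 + 6.0 * i / (n_points - 1) for i in range(n_points)]
dx = xs[1] - xs[0]

# Unnormalized density exp(-E(x)): lower energy gives a larger value
unnormalized = [math.exp(-energy(x)) for x in xs]

# Approximate Z = ∫ exp(-E(x)) dx with a Riemann sum, then normalize
Z = sum(unnormalized) * dx
density = [u / Z for u in unnormalized]

def density_at(x: float) -> float:
    return math.exp(-energy(x)) / Z

print(f"Z ≈ {Z:.4f}")
print(f"p(0) ≈ {density_at(0.0):.4f}, E(0) = {energy(0.0):.1f}")
print(f"p(1) ≈ {density_at(1.0):.4f}, E(1) = {energy(1.0):.1f}")
```

Note that Z is only tractable here because x is one-dimensional; in high dimensions it cannot be computed this way, which is why EBM training relies on sampling and score-based objectives.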
TorchEBM simplifies working with EBMs in PyTorch. It provides a suite of tools designed for researchers and practitioners, enabling efficient implementation and exploration of:
- Defining complex energy functions: Easily create custom energy landscapes using PyTorch modules.
- Training: Loss functions and procedures for EBM parameter estimation, including score matching and contrastive divergence variants.
- Sampling: Algorithms to draw samples from the learned distribution p(x).
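Because energy functions are ordinary PyTorch modules, the gradient that samplers and score-based losses rely on comes straight from autograd. The sketch below uses plain PyTorch rather than TorchEBM's own base classes. `QuadraticEnergy` and `score` are hypothetical names for illustration. It defines a learnable quadratic energy and computes the score ∇ₓ log p(x) = -∇ₓE(x) for a batch.

```python
import torch
import torch.nn as nn

class QuadraticEnergy(nn.Module):
    """Hypothetical energy E(x) = 0.5 * ||x - mu||^2 / sigma^2 with a learnable center mu."""

    def __init__(self, dim: int, sigma: float = 1.0):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(dim))
        self.sigma = sigma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One scalar energy per sample: shape (batch,)
        return 0.5 * ((x - self.mu) ** 2).sum(dim=-1) / self.sigma**2

def score(energy_fn: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: score -grad_x E(x) for each row of x, via autograd."""
    x = x.detach().requires_grad_(True)
    # Summing is safe because each sample's energy depends only on its own row
    (grad,) = torch.autograd.grad(energy_fn(x).sum(), x)
    return -grad

energy_fn = QuadraticEnergy(dim=2)
x = torch.tensor([[1.0, -2.0], [0.5, 0.0]])
print(energy_fn(x))         # per-sample energies
print(score(energy_fn, x))  # -(x - mu) / sigma^2, which is -x while mu = 0
```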
Documentation
For detailed documentation, including installation instructions, usage examples, and API references, please visit the TorchEBM website.
Features
- Core Components:
  - Energy functions: Standard energy landscapes (Gaussian, Double Well, Rosenbrock, etc.)
  - Datasets: Data generators for training and evaluation
  - Loss functions: Contrastive Divergence, Score Matching, and more
  - Sampling algorithms: Langevin Dynamics, Hamiltonian Monte Carlo (HMC), and more
  - Evaluation metrics: Diagnostics for sampling and training
- Performance Optimizations:
  - CUDA-accelerated implementations
  - Parallel sampling capabilities
  - Extensive diagnostics
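Score matching avoids the intractable normalizer Z by fitting the gradient of the energy instead of the probability itself. The sketch below implements one common variant, denoising score matching, in plain PyTorch. It is not necessarily the variant TorchEBM implements. The loss perturbs data with Gaussian noise of scale σ and regresses the model score -∇E(x̃) onto the score of the noising kernel, -(x̃ - x)/σ². The function name, the noise scale, and the small MLP are assumptions for illustration.

```python
from typing import Callable

import torch
import torch.nn as nn

EnergyFn = Callable[[torch.Tensor], torch.Tensor]

def denoising_score_matching_loss(energy_fn: EnergyFn, x: torch.Tensor, sigma: float = 0.1) -> torch.Tensor:
    """Hypothetical DSM loss: match -grad E(x_noisy) to the score of N(x, sigma^2 I)."""
    noise = torch.randn_like(x) * sigma
    x_noisy = (x + noise).requires_grad_(True)
    # create_graph=True keeps the gradient differentiable w.r.t. the model parameters
    (grad_energy,) = torch.autograd.grad(energy_fn(x_noisy).sum(), x_noisy, create_graph=True)
    model_score = -grad_energy
    target_score = -noise / sigma**2  # grad of log N(x_noisy; x, sigma^2 I) w.r.t. x_noisy
    return 0.5 * ((model_score - target_score) ** 2).sum(dim=-1).mean()

# Usage with a small MLP energy (architecture chosen only for illustration)
torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(2, 32), nn.SiLU(), nn.Linear(32, 1))
energy_fn = lambda x: mlp(x).squeeze(-1)  # one scalar energy per sample
data = torch.randn(128, 2)
loss = denoising_score_matching_loss(energy_fn, data)
loss.backward()
print(loss.item())
```

The smooth SiLU activation matters here: the loss differentiates through ∇ₓE, so the energy network needs useful second derivatives.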
Figure: energy landscapes of the Gaussian, Double Well, Rastrigin, and Rosenbrock functions.
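For reference, the landscapes named above have standard textbook forms. The sketch below writes them as batched PyTorch functions. The parameterizations are conventional choices and are not taken from TorchEBM's own classes, whose defaults may differ: barrier height `h`, Rastrigin constant `a = 10`, and Rosenbrock constants `a = 1`, `b = 100`.

```python
import math

import torch

# Conventional textbook forms; TorchEBM's energy classes may use different defaults.

def gaussian_energy(x: torch.Tensor) -> torch.Tensor:
    """Standard normal energy 0.5 * ||x||^2 (zero mean, identity covariance)."""
    return 0.5 * (x**2).sum(dim=-1)

def double_well_energy(x: torch.Tensor, h: float = 2.0) -> torch.Tensor:
    """h * (x_i^2 - 1)^2 summed over dims: minima at x_i = +/-1, barrier h per dim at x_i = 0."""
    return h * ((x**2 - 1.0) ** 2).sum(dim=-1)

def rastrigin_energy(x: torch.Tensor, a: float = 10.0) -> torch.Tensor:
    """Highly multimodal; global minimum 0 at the origin."""
    d = x.shape[-1]
    return a * d + (x**2 - a * torch.cos(2 * math.pi * x)).sum(dim=-1)

def rosenbrock_energy(x: torch.Tensor, a: float = 1.0, b: float = 100.0) -> torch.Tensor:
    """Curved, narrow valley; for a = 1 the global minimum 0 is at (1, ..., 1)."""
    x_i, x_next = x[..., :-1], x[..., 1:]
    return ((a - x_i) ** 2 + b * (x_next - x_i**2) ** 2).sum(dim=-1)

points = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
for fn in (gaussian_energy, double_well_energy, rastrigin_energy, rosenbrock_energy):
    print(fn.__name__, fn(points).tolist())
```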
Installation
```bash
pip install torchebm
```
Dependencies
- PyTorch (with CUDA support for optimal performance)
- Other dependencies are listed in requirements.txt
Usage Examples
Common Setup
```python
import torch
from torchebm.core import GaussianEnergy, DoubleWellEnergy

# Set device for computation
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define dimensions
dim = 10
n_samples = 250
n_steps = 500
```
Energy Function Examples
```python
# Create a multivariate Gaussian energy function
gaussian_energy = GaussianEnergy(
    mean=torch.zeros(dim, device=device),  # Center at origin
    cov=torch.eye(dim, device=device),  # Identity covariance (standard normal)
)

# Create a double well potential
double_well_energy = DoubleWellEnergy(barrier_height=2.0)
```
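For intuition about what a Gaussian energy computes, the sketch below evaluates the usual quadratic form E(x) = ½ (x - μ)ᵀ Σ⁻¹ (x - μ) for a batch with a non-identity covariance. This assumes `GaussianEnergy` follows that convention, which is unconfirmed here; check the API reference for whether any log-normalizer term is included. The function name is hypothetical. It uses a linear solve rather than forming Σ⁻¹ explicitly.

```python
import torch

def gaussian_energy(x: torch.Tensor, mean: torch.Tensor, cov: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: E(x) = 0.5 * (x - mean)^T cov^{-1} (x - mean) per row of x.

    Assumed convention: no log-normalizer term is added.
    """
    diff = x - mean  # (batch, dim)
    # Solve cov @ z = diff^T instead of inverting cov
    z = torch.linalg.solve(cov, diff.T).T  # (batch, dim)
    return 0.5 * (diff * z).sum(dim=-1)

dim = 3
mean = torch.zeros(dim)
cov = torch.diag(torch.tensor([1.0, 4.0, 0.25]))
x = torch.tensor([[1.0, 2.0, 0.5], [0.0, 0.0, 0.0]])
print(gaussian_energy(x, mean, cov))  # tensor([1.5000, 0.0000])
```

With the diagonal covariance above, each coordinate contributes x_i² / σ_i², which is 1 for every entry of the first row, so its energy is 0.5 × 3 = 1.5.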
1. Training a Simple EBM on a Gaussian Mixture Using the Langevin Dynamics Sampler
```python
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchebm.core import BaseEnergyFunction
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import GaussianMixtureDataset
from torchebm.samplers import LangevinDynamics

# Uses `torch` and `device` from the common setup above
EPOCHS = 10

# Define an NN energy model
class MLPEnergy(BaseEnergyFunction):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)  # one scalar energy per sample

energy_fn = MLPEnergy(input_dim=2).to(device)
sampler = LangevinDynamics(energy_function=energy_fn, step_size=0.01, device=device)
cd_loss_fn = ContrastiveDivergence(
    energy_function=energy_fn,
    sampler=sampler,
    k_steps=10,  # MCMC steps used to generate negative samples
)
optimizer = optim.Adam(energy_fn.parameters(), lr=0.001)

mixture_dataset = GaussianMixtureDataset(n_samples=500, n_components=4, std=0.1, seed=123).get_data()
dataloader = DataLoader(mixture_dataset, batch_size=32, shuffle=True)

# Training loop
for epoch in range(EPOCHS):
    epoch_loss = 0.0
    for batch_data in dataloader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        loss, neg_samples = cd_loss_fn(batch_data)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.6f}")
```
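Conceptually, each contrastive divergence step lowers the energy of data (positive samples) and raises the energy of a few MCMC steps started from that data (negative samples). The sketch below spells this out in plain PyTorch using the unadjusted Langevin update x ← x - η∇E(x) + √(2η)·ξ. It is an illustration under those assumptions, not TorchEBM's `ContrastiveDivergence`, which may differ in chain initialization (e.g., persistent chains), step-size and noise conventions, and regularization. The helper names and the toy data are hypothetical.

```python
import torch
import torch.nn as nn

# Illustrative sketch, not TorchEBM's ContrastiveDivergence implementation.

def langevin_negatives(energy_fn, x_init, n_steps=10, step_size=0.01):
    """Hypothetical helper: run short unadjusted Langevin chains and return detached samples."""
    x = x_init.detach().clone()
    for _ in range(n_steps):
        x.requires_grad_(True)
        (grad,) = torch.autograd.grad(energy_fn(x).sum(), x)
        x = (x - step_size * grad + (2 * step_size) ** 0.5 * torch.randn_like(x)).detach()
    return x

def contrastive_divergence_loss(energy_fn, pos, n_steps=10, step_size=0.01):
    """CD-k style loss: chains start at the data batch, as in classic CD-k."""
    neg = langevin_negatives(energy_fn, pos, n_steps, step_size)
    # Gradients flow into the model parameters only; negatives are constants
    return energy_fn(pos).mean() - energy_fn(neg).mean(), neg

torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(2, 64), nn.SiLU(), nn.Linear(64, 1))
energy_fn = lambda x: mlp(x).squeeze(-1)
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)

data = torch.randn(32, 2) * 0.1 + torch.tensor([1.0, -1.0])  # toy data batch
optimizer.zero_grad()
loss, neg = contrastive_divergence_loss(energy_fn, data)
loss.backward()
optimizer.step()
print(loss.item(), neg.shape)
```

Returning `(loss, negatives)` mirrors the `loss, neg_samples = cd_loss_fn(batch_data)` pattern in the training loop above, so the negatives can be inspected or plotted.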
2. Hamiltonian Monte Carlo (HMC)
```python
from torchebm.samplers import HamiltonianMonteCarlo

# Define a 10-D Gaussian energy function on the same device as the sampler
energy_fn = GaussianEnergy(
    mean=torch.zeros(dim, device=device), cov=torch.eye(dim, device=device)
)

# Initialize HMC sampler
hmc_sampler = HamiltonianMonteCarlo(
    energy_function=energy_fn, step_size=0.1, n_leapfrog_steps=10, device=device
)

# Sample 10,000 points in 10 dimensions
final_samples = hmc_sampler.sample(
    dim=dim, n_steps=500, n_samples=10000, return_trajectory=False
)
print(final_samples.shape)  # (10000, 10) -> (n_samples, dim)

# Sample with diagnostics and the full trajectory
trajectory, diagnostics = hmc_sampler.sample(
    n_samples=n_samples,
    n_steps=n_steps,
    dim=dim,
    return_trajectory=True,
    return_diagnostics=True,
)
print(trajectory.shape)  # (250, 500, 10) -> (n_samples, n_steps, dim)
print(diagnostics.shape)  # (500, 4, 250, 10) -> (n_steps, 4, n_samples, dim)
# Along axis 1, diagnostics hold: mean (index 0), variance (index 1), energy (index 2), acceptance rate (index 3)

# Sample from a custom initialization
x_init = torch.randn(n_samples, dim, dtype=torch.float32, device=device)
samples = hmc_sampler.sample(x=x_init, n_steps=100)
print(samples.shape)  # (250, 10) -> (n_samples, dim)
```
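For intuition about what `step_size` and `n_leapfrog_steps` control, here is a minimal HMC transition in plain PyTorch. Each transition draws fresh momenta, integrates Hamilton's equations with the leapfrog scheme, and applies a Metropolis accept/reject test per chain. It is a sketch assuming an identity mass matrix, not TorchEBM's `HamiltonianMonteCarlo` code, and the function names are hypothetical.

```python
import torch

# Hypothetical helpers for illustration; not TorchEBM's HamiltonianMonteCarlo.

def grad_energy(energy_fn, x):
    x = x.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(energy_fn(x).sum(), x)
    return g

def hmc_step(energy_fn, x, step_size=0.1, n_leapfrog_steps=10):
    """One HMC transition for a batch of independent chains (identity mass matrix)."""
    p0 = torch.randn_like(x)
    x_new, p = x.clone(), p0.clone()
    # Leapfrog: half momentum step, alternating full steps, final half momentum step
    p = p - 0.5 * step_size * grad_energy(energy_fn, x_new)
    for i in range(n_leapfrog_steps):
        x_new = x_new + step_size * p
        if i < n_leapfrog_steps - 1:
            p = p - step_size * grad_energy(energy_fn, x_new)
    p = p - 0.5 * step_size * grad_energy(energy_fn, x_new)
    # Metropolis test on the Hamiltonian H = E(x) + 0.5 * ||p||^2
    with torch.no_grad():
        h_old = energy_fn(x) + 0.5 * (p0**2).sum(dim=-1)
        h_new = energy_fn(x_new) + 0.5 * (p**2).sum(dim=-1)
        accept = torch.rand_like(h_old) < torch.exp(h_old - h_new)
    x_next = torch.where(accept.unsqueeze(-1), x_new, x)
    return x_next.detach(), accept.float().mean().item()

torch.manual_seed(0)
energy_fn = lambda x: 0.5 * (x**2).sum(dim=-1)  # standard normal in 10-D
x = torch.randn(1000, 10) * 3.0  # overdispersed starting points
for _ in range(200):
    x, acc_rate = hmc_step(energy_fn, x, step_size=0.1, n_leapfrog_steps=10)
print(x.mean().item(), x.var().item(), acc_rate)
```

The trajectory length is `step_size * n_leapfrog_steps`. Larger steps explore faster but increase the integration error in H, which shows up as a lower acceptance rate.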
Library Structure
```
torchebm/
├── core/                      # Core functionality
│   ├── energy_function.py     # Energy function definitions
│   ├── basesampler.py         # Base sampler class
│   └── ...
├── samplers/                  # Sampling algorithms
│   ├── langevin_dynamics.py   # Langevin dynamics implementation
│   ├── mcmc.py                # HMC implementation
│   └── ...
├── models/                    # Neural network models
├── evaluation/                # Evaluation metrics and utilities
├── datasets/
│   └── generators.py          # Data generators for training
├── losses/                    # Loss functions for training
├── utils/                     # Utility functions
└── cuda/                      # CUDA optimizations
```
Visualization Examples
Figure: Langevin dynamics sampling, a single Langevin dynamics trajectory, and parallel Langevin dynamics sampling.
Check out the examples/ directory for sample scripts:
- samplers/: Demonstrates different sampling algorithms
- datasets/: Depicts data generation using built-in datasets
- training_models/: Shows how to train energy-based models using TorchEBM
- visualization/: Visualizes sampling results and trajectories
- and more!
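The Langevin visualizations in this section come from running many chains in parallel, each following -∇E plus Gaussian noise. Below is a minimal unadjusted Langevin sketch in plain PyTorch that records the full trajectory in the same (n_samples, steps, dim) layout as the HMC trajectory example. The function name and the toy double-well energy are assumptions, and the update x ← x - η∇E(x) + √(2η)·ξ may not match TorchEBM's step-size and noise conventions.

```python
import torch

# Hypothetical helper; the step-size/noise convention is an assumption.

def langevin_trajectory(energy_fn, x_init, n_steps=500, step_size=0.01):
    """Unadjusted Langevin chains run in parallel; returns (n_samples, n_steps, dim)."""
    x = x_init.detach().clone()
    states = []
    for _ in range(n_steps):
        x.requires_grad_(True)
        (grad,) = torch.autograd.grad(energy_fn(x).sum(), x)
        x = (x - step_size * grad + (2 * step_size) ** 0.5 * torch.randn_like(x)).detach()
        states.append(x)
    return torch.stack(states, dim=1)

torch.manual_seed(0)
double_well = lambda x: ((x**2 - 1.0) ** 2).sum(dim=-1)  # toy energy, minima at x_i = +/-1
trajectory = langevin_trajectory(double_well, torch.zeros(250, 2), n_steps=500)
print(trajectory.shape)  # torch.Size([250, 500, 2])
```

`trajectory[0]` traces a single chain, as in the single-trajectory panel, and `trajectory[:, -1]` gives the final parallel samples.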
Contributing
Contributions are welcome! Step-by-step instructions for contributing to the project can be found on the contributing.md page on the website.
Please check the issues page for current tasks or create a new issue to discuss proposed changes.
Show Your Support for TorchEBM
If TorchEBM helped you, please star this repository and spread the word.
Thank you!
Citation
If you use TorchEBM in your research, please cite it using the following BibTeX entry:
```bibtex
@misc{torchebm_library_2025,
  author = {Ghaderi, Soran and Contributors},
  title = {{TorchEBM}: A PyTorch Library for Training Energy-Based Models},
  year = {2025},
  url = {https://github.com/soran-ghaderi/torchebm},
}
```
Changelog
For a detailed list of changes between versions, please see our CHANGELOG.
License
This project is licensed under the MIT License - see the LICENSE file for details.
Research Collaboration
If you are interested in collaborating on research projects (diffusion-, flow-, or energy-based models) or have any questions about the library, please feel free to reach out. I am open to discussions and collaborations that can enhance the capabilities of TorchEBM and contribute to the field of generative modeling.
Download files
File details
Details for the file torchebm-0.4.0.tar.gz.
File metadata
- Download URL: torchebm-0.4.0.tar.gz
- Upload date:
- Size: 26.2 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bef9d56e4d66249fc455989fe3b21b1147eac2e8c518aa83f99780fec4ab4205 |
| MD5 | a115d43d727a739527449ec301bb307a |
| BLAKE2b-256 | 9fcc39d004b9595f68dbc47e32d9bcd58e977ec378df4afb82601124c3846897 |
File details
Details for the file torchebm-0.4.0-py3-none-any.whl.
File metadata
- Download URL: torchebm-0.4.0-py3-none-any.whl
- Upload date:
- Size: 26.2 MB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0a3a953e6e1a8853f47474aa45509bd568e08f44c3418e62facedf00b8b536ef |
| MD5 | ddd6d6b8d17b19f180643fc7ab2a0184 |
| BLAKE2b-256 | 628558e42f711003aa44c378229610279666b0520d994351b4904b85fd2e1d98 |