Components and algorithms for energy-based models
⚡ Energy-Based Modeling library for PyTorch, offering tools for 🔬 sampling, 🧠 inference, and 📊 learning in complex distributions.
| Gaussian Function | Double Well Function | Rastrigin Function | Rosenbrock Function |
About
TorchEBM is a CUDA-accelerated parallel library for Energy-Based Models (EBMs) built on PyTorch. It provides efficient implementations of sampling, inference, and learning algorithms for EBMs, with a focus on scalability and performance.
Features
- Core Components:
  - Energy functions: Standard energy landscapes (Gaussian, Double Well, Rosenbrock, etc.)
  - Base sampler interfaces and common utilities
- Advanced Samplers:
  - Langevin Dynamics: Gradient-based MCMC with stochastic updates
  - Hamiltonian Monte Carlo (HMC): Efficient exploration using Hamiltonian dynamics
- Performance Optimizations:
  - CUDA-accelerated implementations
  - Parallel sampling capabilities
  - Extensive diagnostics
Installation
pip install torchebm
Usage Examples
Common Setup
import torch
from torchebm.core import GaussianEnergy, DoubleWellEnergy
# Set device for computation
device = "cuda" if torch.cuda.is_available() else "cpu"
# Define dimensions
dim = 10
n_samples = 250
n_steps = 500
Energy Function Examples
# Create a multivariate Gaussian energy function
gaussian_energy = GaussianEnergy(
    mean=torch.zeros(dim, device=device),  # Center at origin
    cov=torch.eye(dim, device=device),  # Identity covariance (standard normal)
)
# Create a double well potential
double_well_energy = DoubleWellEnergy(barrier_height=2.0)
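To make the two landscapes above concrete, here is a minimal plain-PyTorch sketch of what such energies typically compute. The function names `gaussian_energy_sketch` and `double_well_energy_sketch` are hypothetical and are not part of TorchEBM; the formulas (a quadratic form for the Gaussian, and `barrier_height * (x^2 - 1)^2` summed over dimensions for the double well) are common textbook choices and are an assumption here, so the library's own definitions may differ in scaling or constants.

```python
import torch


def gaussian_energy_sketch(x, mean, cov):
    """Assumed form: E(x) = 0.5 * (x - mean)^T cov^{-1} (x - mean), one value per row of x."""
    diff = x - mean                                # (n_samples, dim)
    solved = torch.linalg.solve(cov, diff.T).T     # cov^{-1} (x - mean), row by row
    return 0.5 * (diff * solved).sum(dim=-1)       # (n_samples,)


def double_well_energy_sketch(x, barrier_height=2.0):
    """Assumed form: E(x) = h * sum_i (x_i^2 - 1)^2, with minima at x_i = +1 and x_i = -1."""
    return barrier_height * ((x ** 2 - 1.0) ** 2).sum(dim=-1)


# Energy is zero at the Gaussian mean and at the bottom of each well
e_gauss = gaussian_energy_sketch(torch.zeros(3, 4), torch.zeros(4), torch.eye(4))
e_well_top = double_well_energy_sketch(torch.zeros(3, 4))  # on top of the barrier
e_well_min = double_well_energy_sketch(torch.ones(3, 4))   # at a minimum
print(e_gauss, e_well_top, e_well_min)
```

Both functions return one energy per sample, which is the batched layout the samplers below rely on.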
1. Langevin Dynamics Sampling
from torchebm.samplers.langevin_dynamics import LangevinDynamics
# Define a 10D Gaussian energy function on the same device as the sampler
energy_fn = GaussianEnergy(
    mean=torch.zeros(10, device=device), cov=torch.eye(10, device=device)
)
# Initialize Langevin dynamics sampler
langevin_sampler = LangevinDynamics(
    energy_function=energy_fn, step_size=5e-3, device=device
).to(device)
# Sample 10,000 points in 10 dimensions
final_samples = langevin_sampler.sample(
    dim=10, n_steps=500, n_samples=10000, return_trajectory=False
)
print(final_samples.shape) # Result shape: (10000, 10) - (n_samples, dim)
# Sample with trajectory and diagnostics
samples, diagnostics = langevin_sampler.sample(
    dim=dim,
    n_steps=n_steps,
    n_samples=n_samples,
    return_trajectory=True,
    return_diagnostics=True,
)
print(samples.shape)  # Trajectory shape: (250, 500, 10) - (n_samples, n_steps, dim)
print(diagnostics.shape)  # Diagnostics shape: (500, 3, 250, 10) - (n_steps, 3, n_samples, dim)
# The second axis of diagnostics indexes: Mean (0), Variance (1), Energy (2)
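For readers who want to see what a Langevin update does under the hood, the following is a minimal plain-PyTorch sketch of unadjusted Langevin dynamics, x ← x − ε∇E(x) + √(2ε)·ξ with ξ ~ N(0, I). It is not TorchEBM's implementation: `langevin_sketch` and `standard_normal_energy` are hypothetical names, and the step-size convention is an assumption (some implementations scale the drift and noise differently).

```python
import torch


def standard_normal_energy(x):
    # E(x) = 0.5 * ||x||^2, returned per chain: shape (n_samples,)
    return 0.5 * (x ** 2).sum(dim=-1)


def langevin_sketch(energy_fn, x, step_size, n_steps):
    """Run n_steps of unadjusted Langevin dynamics on a batch of parallel chains."""
    for _ in range(n_steps):
        x = x.detach().requires_grad_(True)
        # Summing the per-chain energies gives each chain its own gradient
        (grad,) = torch.autograd.grad(energy_fn(x).sum(), x)
        noise = torch.randn_like(x)
        x = x - step_size * grad + (2.0 * step_size) ** 0.5 * noise
    return x.detach()


torch.manual_seed(0)
x0 = 3.0 * torch.randn(2000, 2)  # deliberately over-dispersed starting points
samples = langevin_sketch(standard_normal_energy, x0, step_size=0.05, n_steps=500)
print(samples.shape)  # torch.Size([2000, 2])
print(samples.mean().item(), samples.var().item())  # roughly 0 and roughly 1
```

All chains advance together as rows of one tensor, which is what makes this kind of sampler easy to parallelize on a GPU.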
2. Hamiltonian Monte Carlo (HMC)
from torchebm.samplers.hmc import HamiltonianMonteCarlo
# Define a 10D Gaussian energy function on the same device as the sampler
energy_fn = GaussianEnergy(
    mean=torch.zeros(10, device=device), cov=torch.eye(10, device=device)
)
# Initialize HMC sampler
hmc_sampler = HamiltonianMonteCarlo(
    energy_function=energy_fn, step_size=0.1, n_leapfrog_steps=10, device=device
)
# Sample 10,000 points in 10 dimensions
final_samples = hmc_sampler.sample(
    dim=10, n_steps=500, n_samples=10000, return_trajectory=False
)
print(final_samples.shape) # Result shape: (10000, 10) - (n_samples, dim)
# Sample with diagnostics and trajectory
final_samples, diagnostics = hmc_sampler.sample(
    n_samples=n_samples,
    n_steps=n_steps,
    dim=dim,
    return_trajectory=True,
    return_diagnostics=True,
)
print(final_samples.shape)  # Trajectory shape: (250, 500, 10) - (n_samples, n_steps, dim)
print(diagnostics.shape)  # Diagnostics shape: (500, 4, 250, 10) - (n_steps, 4, n_samples, dim)
# The second axis of diagnostics indexes: Mean (0), Variance (1), Energy (2), Acceptance rates (3)
# Sample from a custom initialization
x_init = torch.randn(n_samples, dim, dtype=torch.float32, device=device)
samples = hmc_sampler.sample(x=x_init, n_steps=100)
print(samples.shape) # Result shape: (250, 10) -> (n_samples, dim)
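As a rough illustration of what `step_size` and `n_leapfrog_steps` control, here is a minimal plain-PyTorch sketch of one HMC transition: draw a momentum, integrate with leapfrog steps, then apply a Metropolis accept/reject test. This is a textbook version with an identity mass matrix, not TorchEBM's implementation; `hmc_step_sketch`, `grad_energy`, and `standard_normal_energy` are hypothetical names.

```python
import torch


def standard_normal_energy(x):
    # E(x) = 0.5 * ||x||^2, one value per chain
    return 0.5 * (x ** 2).sum(dim=-1)


def grad_energy(energy_fn, x):
    x = x.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(energy_fn(x).sum(), x)
    return g


def hmc_step_sketch(energy_fn, x, step_size=0.1, n_leapfrog_steps=10):
    """One HMC transition for a batch of chains; returns new states and the acceptance rate."""
    p0 = torch.randn_like(x)  # fresh momentum, identity mass matrix
    x_new, p = x.clone(), p0.clone()
    # Leapfrog: half momentum step, alternating full steps, final half momentum step
    p = p - 0.5 * step_size * grad_energy(energy_fn, x_new)
    for i in range(n_leapfrog_steps):
        x_new = x_new + step_size * p
        if i < n_leapfrog_steps - 1:
            p = p - step_size * grad_energy(energy_fn, x_new)
    p = p - 0.5 * step_size * grad_energy(energy_fn, x_new)
    # Metropolis test on the change in total energy (potential + kinetic)
    h_old = energy_fn(x) + 0.5 * (p0 ** 2).sum(dim=-1)
    h_new = energy_fn(x_new) + 0.5 * (p ** 2).sum(dim=-1)
    accept = torch.rand_like(h_old) < torch.exp(h_old - h_new)
    x_out = torch.where(accept.unsqueeze(-1), x_new, x)
    return x_out.detach(), accept.float().mean().item()


torch.manual_seed(0)
x = torch.randn(250, 10)
for _ in range(20):
    x, rate = hmc_step_sketch(standard_normal_energy, x)
print(x.shape, rate)
```

A larger `step_size` or more leapfrog steps moves each proposal further but increases the integration error, which shows up as a lower acceptance rate.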
Library Structure
torchebm/
├── core/ # Core functionality
│ ├── energy_function.py # Energy function definitions
│ ├── basesampler.py # Base sampler class
│ └── ...
├── samplers/ # Sampling algorithms
│ ├── langevin_dynamics.py # Langevin dynamics implementation
│ ├── hmc.py # HMC implementation
│ └── ...
├── models/ # Neural network models
├── losses/ # Loss functions for training
├── utils/ # Utility functions
└── cuda/ # CUDA optimizations
Visualization Examples
| Langevin Dynamics Sampling | Single Langevin Dynamics Trajectory | Parallel Langevin Dynamics Sampling |
Check out the examples/ directory for sample scripts:
- langevin_dynamics_sampling.py: Demonstrates Langevin dynamics sampling
- hmc_examples.py: Demonstrates Hamiltonian Monte Carlo sampling
- energy_fn_visualization.py: Visualizes various energy functions
Contributing
Contributions are welcome! Please check the issues page for current tasks or create a new issue to discuss proposed changes.
License
This project is licensed under the MIT License - see the LICENSE file for details.