
Components and algorithms for energy-based models




⚡ Energy-Based Modeling library for PyTorch, offering tools for 🔬 sampling, 🧠 inference, and 📊 learning in complex distributions.


What is ∇ TorchEBM 🍓?

Energy-Based Models (EBMs) offer a powerful and flexible framework for generative modeling by assigning a scalar energy E(x) to each data point, which defines an unnormalized probability density p(x) ∝ exp(−E(x)). Lower energy therefore corresponds to higher probability.
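As a minimal illustration of the energy/probability relationship (plain PyTorch, independent of TorchEBM's API), the sketch below uses a hand-written quadratic energy, which is an assumption chosen because it corresponds to a standard Gaussian. Ratios of unnormalized probabilities never require the normalizing constant.

```python
import torch

def quadratic_energy(x: torch.Tensor) -> torch.Tensor:
    # E(x) = 0.5 * ||x||^2: the standard-Gaussian energy, up to an additive constant
    return 0.5 * (x ** 2).sum(dim=-1)

x = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
energy = quadratic_energy(x)           # tensor([0., 1., 4.])
unnormalized_p = torch.exp(-energy)    # lower energy -> larger unnormalized probability

# p(x0) / p(x1) = exp(E(x1) - E(x0)); the normalizer Z cancels out
ratio = unnormalized_p[0] / unnormalized_p[1]
```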

TorchEBM simplifies working with EBMs in PyTorch. It provides a suite of tools designed for researchers and practitioners, enabling efficient implementation and exploration of:

  • Defining complex energy functions: Easily create custom energy landscapes using PyTorch modules.
  • Training: Loss functions and procedures for EBM parameter estimation, including score matching and contrastive divergence variants.
  • Sampling: Algorithms to draw samples from the learned distribution p(x).
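The core of contrastive divergence training can be sketched in a few lines of plain PyTorch: lower the energy of data ("positive") samples and raise the energy of model ("negative") samples. The helper `contrastive_divergence_loss` is hypothetical and shown only for illustration; TorchEBM's `ContrastiveDivergence` additionally runs an MCMC sampler to produce the negative samples, whereas here random noise stands in for them.

```python
import torch

def contrastive_divergence_loss(energy_fn, x_pos, x_neg):
    # Lower the energy of data (positive) samples, raise it for model (negative) samples.
    # x_neg is detached so no gradient flows back through the sampler.
    return energy_fn(x_pos).mean() - energy_fn(x_neg.detach()).mean()

# Usage with a small MLP energy; x_neg is noise standing in for MCMC samples
torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(2, 16), torch.nn.SiLU(), torch.nn.Linear(16, 1))

def mlp_energy(x):
    # One scalar energy per sample
    return net(x).squeeze(-1)

x_pos = torch.randn(32, 2)
x_neg = 3.0 * torch.randn(32, 2)
loss = contrastive_divergence_loss(mlp_energy, x_pos, x_neg)
loss.backward()
```

In practice the negative samples come from a short MCMC chain (for example, a few Langevin steps), which is what the `k_steps` argument controls in the training example.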

Documentation

For detailed documentation, including installation instructions, usage examples, and API references, please visit the 📚 TorchEBM Website.

Features

  • Core Components:

    • Energy functions: Standard energy landscapes (Gaussian, Double Well, Rosenbrock, etc.)
    • Datasets: Data generators for training and evaluation
    • Loss functions: Contrastive Divergence, Score Matching, and more
    • Sampling algorithms: Langevin Dynamics, Hamiltonian Monte Carlo (HMC), and more
    • Evaluation metrics: Diagnostics for sampling and training
  • Performance Optimizations:

    • CUDA-accelerated implementations
    • Parallel sampling capabilities
    • Extensive diagnostics
[Figures: Gaussian Function | Double Well Function | Rastrigin Function | Rosenbrock Function]
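For reference, the standard textbook forms of three of these landscapes are sketched below in plain PyTorch. The exact formulas and parameter names (`barrier_height`, `a`, `b`, `A`) are assumptions for illustration; TorchEBM's built-in classes such as `DoubleWellEnergy` may scale or parameterize them differently.

```python
import math
import torch

def double_well_energy(x, barrier_height=2.0):
    # Per coordinate: minima at x_i = +1 and -1, barrier of height `barrier_height` at x_i = 0
    return barrier_height * ((x ** 2 - 1.0) ** 2).sum(dim=-1)

def rosenbrock_energy(x, a=1.0, b=100.0):
    # Generalized Rosenbrock "banana" valley; with a=1 the global minimum (E = 0) is all-ones
    return ((a - x[..., :-1]) ** 2 + b * (x[..., 1:] - x[..., :-1] ** 2) ** 2).sum(dim=-1)

def rastrigin_energy(x, A=10.0):
    # Highly multimodal; global minimum E = 0 at the origin
    n = x.shape[-1]
    return A * n + (x ** 2 - A * torch.cos(2 * math.pi * x)).sum(dim=-1)

x = torch.tensor([[1.0, -1.0], [0.0, 0.0]])
print(double_well_energy(x))  # tensor([0., 4.])
```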

Installation

pip install torchebm

Dependencies

Usage Examples

Common Setup

import torch
from torchebm.core import GaussianEnergy, DoubleWellEnergy

# Set device for computation
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define dimensions
dim = 10
n_samples = 250
n_steps = 500

Energy Function Examples

# Create a multivariate Gaussian energy function
gaussian_energy = GaussianEnergy(
    mean=torch.zeros(dim, device=device),  # Center at origin
    cov=torch.eye(dim, device=device)      # Identity covariance (standard normal)
)

# Create a double well potential
double_well_energy = DoubleWellEnergy(barrier_height=2.0)
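Gradient-based samplers such as Langevin dynamics and HMC only need the gradient of the energy, so the intractable normalizer is never computed. A minimal sketch of this gradient (the score, −∇E) using `torch.autograd`; the hand-written standard-Gaussian energy is a stand-in and is not how `GaussianEnergy` is implemented internally, and `score` is a hypothetical helper.

```python
import torch

def score(energy_fn, x):
    # Score = grad_x log p(x) = -grad_x E(x); the normalizer Z drops out of the gradient
    x = x.detach().requires_grad_(True)
    # Summing over the batch gives per-sample gradients in a single backward pass
    (grad,) = torch.autograd.grad(energy_fn(x).sum(), x)
    return -grad

def standard_gaussian_energy(x):
    # E(x) = 0.5 * ||x||^2
    return 0.5 * (x ** 2).sum(dim=-1)

x = torch.tensor([[1.0, -2.0]])
s = score(standard_gaussian_energy, x)  # for a standard Gaussian the score is -x
```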

1. Training a Simple EBM on a Gaussian Mixture Using a Langevin Dynamics Sampler

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchebm.core import BaseEnergyFunction
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import GaussianMixtureDataset
from torchebm.samplers import LangevinDynamics

# Define an NN energy model
class MLPEnergy(BaseEnergyFunction):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)  # one scalar energy per sample

energy_fn = MLPEnergy(input_dim=2).to(device)
sampler = LangevinDynamics(energy_function=energy_fn, step_size=0.01, device=device)

cd_loss_fn = ContrastiveDivergence(
  energy_function=energy_fn,
  sampler=sampler,
  k_steps=10  # MCMC steps used to generate negative samples
)

optimizer = optim.Adam(energy_fn.parameters(), lr=0.001)

mixture_dataset = GaussianMixtureDataset(n_samples=500, n_components=4, std=0.1, seed=123).get_data()
dataloader = DataLoader(mixture_dataset, batch_size=32, shuffle=True)

EPOCHS = 10

# Training loop
for epoch in range(EPOCHS):
  epoch_loss = 0.0
  for batch_data in dataloader:
    batch_data = batch_data.to(device)

    optimizer.zero_grad()

    # Contrastive divergence: energy of data (positive) vs. MCMC (negative) samples
    loss, neg_samples = cd_loss_fn(batch_data)

    loss.backward()
    optimizer.step()

    epoch_loss += loss.item()

  avg_loss = epoch_loss / len(dataloader)
  print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.6f}")
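The negative samples in the loop above come from Langevin dynamics. A minimal plain-PyTorch sketch of one common discretization (the unadjusted Langevin algorithm) is shown below; `langevin_sample` is a hypothetical helper, and TorchEBM's `LangevinDynamics` may scale the gradient and noise terms differently.

```python
import math
import torch

def langevin_sample(energy_fn, x_init, step_size=0.01, n_steps=100):
    # Unadjusted Langevin: x <- x - eps * grad E(x) + sqrt(2 * eps) * N(0, I)
    x = x_init.clone()
    for _ in range(n_steps):
        x.requires_grad_(True)
        (grad,) = torch.autograd.grad(energy_fn(x).sum(), x)
        with torch.no_grad():
            x = x - step_size * grad + math.sqrt(2 * step_size) * torch.randn_like(x)
    return x.detach()

def standard_gaussian_energy(x):
    # E(x) = 0.5 * ||x||^2
    return 0.5 * (x ** 2).sum(dim=-1)

# Usage: 5,000 parallel chains started far from the mode at (3, 3)
torch.manual_seed(0)
samples = langevin_sample(standard_gaussian_energy, torch.full((5000, 2), 3.0),
                          step_size=0.05, n_steps=500)
```

Without a Metropolis correction the discretization introduces a small bias that grows with the step size; for this Gaussian target the stationary variance is 2 / (2 − step_size) rather than exactly 1.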

2. Hamiltonian Monte Carlo (HMC)

from torchebm.samplers import HamiltonianMonteCarlo

# Define a 10-D standard Gaussian energy function on the selected device
energy_fn = GaussianEnergy(mean=torch.zeros(dim, device=device), cov=torch.eye(dim, device=device))

# Initialize HMC sampler
hmc_sampler = HamiltonianMonteCarlo(
  energy_function=energy_fn, step_size=0.1, n_leapfrog_steps=10, device=device
)

# Sample 10,000 points in 10 dimensions
final_samples = hmc_sampler.sample(
  dim=10, n_steps=500, n_samples=10000, return_trajectory=False
)
print(final_samples.shape)  # Result shape: (10000, 10) -> (n_samples, dim)

# Sample with diagnostics and the full trajectory
trajectory, diagnostics = hmc_sampler.sample(
  n_samples=n_samples,
  n_steps=n_steps,
  dim=dim,
  return_trajectory=True,
  return_diagnostics=True,
)

print(trajectory.shape)   # Trajectory shape: (250, 500, 10) -> (n_samples, n_steps, dim)
print(diagnostics.shape)  # Diagnostics shape: (500, 4, 250, 10) -> (n_steps, 4, n_samples, dim)
# Along the second axis, diagnostics hold: mean (index 0), variance (index 1), energy (index 2), acceptance rate (index 3)

# Sample from a custom initialization
x_init = torch.randn(n_samples, dim, dtype=torch.float32, device=device)
samples = hmc_sampler.sample(x=x_init, n_steps=100)
print(samples.shape)  # Result shape: (250, 10) -> (n_samples, dim)
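To show what `step_size` and `n_leapfrog_steps` control, here is a minimal plain-PyTorch sketch of a single HMC transition: leapfrog integration of Hamiltonian dynamics followed by a Metropolis accept/reject step. `hmc_step` and `grad_energy` are hypothetical helpers assuming an identity mass matrix; this is not TorchEBM's `HamiltonianMonteCarlo` implementation.

```python
import torch

def grad_energy(energy_fn, x):
    # Gradient of the energy w.r.t. x (one row per chain)
    x = x.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(energy_fn(x).sum(), x)
    return g

def hmc_step(energy_fn, x, step_size=0.1, n_leapfrog_steps=10):
    # One HMC transition for every chain (row of x), identity mass matrix
    p0 = torch.randn_like(x)
    x_new, p = x.clone(), p0.clone()
    # Leapfrog: half momentum step, alternating full steps, final half momentum step
    p = p - 0.5 * step_size * grad_energy(energy_fn, x_new)
    for i in range(n_leapfrog_steps):
        x_new = x_new + step_size * p
        if i < n_leapfrog_steps - 1:
            p = p - step_size * grad_energy(energy_fn, x_new)
    p = p - 0.5 * step_size * grad_energy(energy_fn, x_new)
    # Metropolis correction on the Hamiltonian H(x, p) = E(x) + 0.5 * ||p||^2
    with torch.no_grad():
        h_old = energy_fn(x) + 0.5 * (p0 ** 2).sum(dim=-1)
        h_new = energy_fn(x_new) + 0.5 * (p ** 2).sum(dim=-1)
        accept = torch.rand(x.shape[0], device=x.device) < torch.exp(h_old - h_new)
        x_next = torch.where(accept.unsqueeze(-1), x_new, x)
    return x_next, accept.float().mean().item()

def standard_gaussian_energy(x):
    # E(x) = 0.5 * ||x||^2
    return 0.5 * (x ** 2).sum(dim=-1)

# Usage: 2,000 chains on a 2-D standard Gaussian, started at (3, 3)
torch.manual_seed(0)
x = torch.full((2000, 2), 3.0)
for _ in range(200):
    x, acceptance_rate = hmc_step(standard_gaussian_energy, x)
```

The Metropolis step corrects the leapfrog discretization error, so a larger `step_size` trades a lower acceptance rate for longer moves per transition.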

Library Structure

torchebm/
├── core/                  # Core functionality
│   ├── energy_function.py # Energy function definitions
│   ├── basesampler.py     # Base sampler class
│   └── ...
├── samplers/              # Sampling algorithms
│   ├── langevin_dynamics.py  # Langevin dynamics implementation
│   ├── mcmc.py            # HMC implementation
│   └── ...
├── models/                # Neural network models
├── evaluation/            # Evaluation metrics and utilities
├── datasets/
│   └── generators.py      # Data generators for training
├── losses/                # Loss functions for training
├── utils/                 # Utility functions
└── cuda/                  # CUDA optimizations

Visualization Examples

[Figures: Langevin Dynamics Sampling | Single Langevin Dynamics Trajectory | Parallel Langevin Dynamics Sampling]

Check out the examples/ directory for sample scripts:

  • samplers/: Demonstrates different sampling algorithms
  • datasets/: Demonstrates data generation with the built-in datasets
  • training_models/: Shows how to train energy-based models using TorchEBM
  • visualization/: Visualizes sampling results and trajectories
  • and more!

Contributing

Contributions are welcome! Step-by-step instructions for contributing to the project can be found on the contributing.md page on the website.

Please check the issues page for current tasks or create a new issue to discuss proposed changes.

Show Your Support for ∇ TorchEBM 🍓

If ∇ TorchEBM has helped you, please ⭐️ this repository and spread the word.

Thank you! 🚀

Citation

If you use ∇ TorchEBM in your research, please cite it using the following BibTeX entry:

@misc{torchebm_library_2025,
  author       = {Ghaderi, Soran and Contributors},
  title        = {{TorchEBM}: A PyTorch Library for Training Energy-Based Models},
  year         = {2025},
  url          = {https://github.com/soran-ghaderi/torchebm},
}

Changelog

For a detailed list of changes between versions, please see our CHANGELOG.

License

This project is licensed under the MIT License - see the LICENSE file for details.

Research Collaboration

If you are interested in collaborating on research projects (diffusion-, flow-, or energy-based models) or have any questions about the library, please feel free to reach out. I am open to discussions and collaborations that can enhance the capabilities of ∇ TorchEBM 🍓 and contribute to the field of generative modeling.



Download files


Source Distribution

torchebm-0.5.4.tar.gz (25.7 MB)


Built Distribution


torchebm-0.5.4-py3-none-any.whl (25.8 MB)


File details

Details for the file torchebm-0.5.4.tar.gz.

File metadata

  • Download URL: torchebm-0.5.4.tar.gz
  • Upload date:
  • Size: 25.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for torchebm-0.5.4.tar.gz
Algorithm Hash digest
SHA256 c16e8b02a17bfa8d7a4bf971c31c5f8ba3e08b15fffa947d35053e9ac09b0ad0
MD5 28df19e6f0d1bfbdd0f5068e38cbbf3d
BLAKE2b-256 b03ec6d372b44d72ef33cd763947e966c098b64ddc03e48bc977896dff735592


File details

Details for the file torchebm-0.5.4-py3-none-any.whl.

File metadata

  • Download URL: torchebm-0.5.4-py3-none-any.whl
  • Upload date:
  • Size: 25.8 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for torchebm-0.5.4-py3-none-any.whl
Algorithm Hash digest
SHA256 55372fd8e4c69af297e31e9c6c2d57c65d6b5b70a398f11f60c156e27b93c17d
MD5 17ed0f2c83142fc3ecf70470a5429c9a
BLAKE2b-256 52da9a7f14df326ebcefa307be16d114c79c3ca48231f7449d7dadff9d19b678

