Energy-based Machine Learners
Learnergy: Energy-based Machine Learners
Welcome to Learnergy
Learnergy is a PyTorch-based framework for energy-based machine learning, providing ready-to-use implementations of Restricted Boltzmann Machines (RBMs) and Deep Belief Networks (DBNs). It is designed for researchers and practitioners who need a clean, modular library for unsupervised feature learning, generative modeling, and classification with energy-based models.
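In an energy-based model, every joint configuration of visible units v and hidden units h gets a scalar energy, and its probability is proportional to exp(-E(v, h)). Low-energy configurations are therefore the likely ones. For a Bernoulli RBM with weights W, visible biases a, and hidden biases b, the textbook energy is E(v, h) = -a^T v - b^T h - v^T W h. The stdlib-only sketch below evaluates that formula on tiny hand-picked parameters. The function name and the list-based layout are illustrative assumptions and do not reflect Learnergy's internals.

```python
from typing import Sequence


def rbm_energy(
    v: Sequence[float],
    h: Sequence[float],
    W: Sequence[Sequence[float]],
    a: Sequence[float],
    b: Sequence[float],
) -> float:
    """Energy of a Bernoulli RBM: E(v, h) = -a.v - b.h - v^T W h.

    W has shape (n_visible, n_hidden); a holds visible biases, b hidden biases.
    """
    visible_term = sum(a_i * v_i for a_i, v_i in zip(a, v))
    hidden_term = sum(b_j * h_j for b_j, h_j in zip(b, h))
    # Pairwise visible-hidden interaction v^T W h.
    interaction = sum(
        v[i] * W[i][j] * h[j] for i in range(len(v)) for j in range(len(h))
    )
    return -(visible_term + hidden_term + interaction)


# Two visible and two hidden units with hand-picked parameters.
W = [[1.0, -0.5], [0.25, 2.0]]
a = [0.1, -0.2]
b = [0.3, 0.0]
print(rbm_energy([1, 0], [1, 1], W, a, b))  # approximately -0.9
```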
What you can do
- Train RBMs with various unit types: Bernoulli, Gaussian, Sigmoid, ReLU, SeLU
- Apply regularization: Dropout, DropConnect, and Energy-based Dropout
- Build deep architectures: stack RBMs into DBNs and Convolutional DBNs
- Use residual learning: ResidualDBN with skip connections for improved information flow
- Classify: Discriminative and Hybrid Discriminative RBMs for supervised tasks
- Visualize: convergence plots, weight mosaics, and tensor images
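Dropout and DropConnect differ in what they mask. Dropout silences whole hidden units, while DropConnect removes individual weights. The sketch below illustrates that distinction when computing p(h_j = 1 | v) for a Bernoulli RBM. It is a generic stdlib illustration with hypothetical names: Learnergy's DropoutRBM and DropConnectRBM may sample masks or rescale activations differently, and Energy-based Dropout is not shown.

```python
import math
import random
from typing import List, Optional, Sequence


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def hidden_probs(
    v: Sequence[float],
    W: Sequence[Sequence[float]],
    b: Sequence[float],
    dropout: float = 0.0,
    dropconnect: float = 0.0,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """p(h_j = 1 | v) with optional Dropout (unit masks) or DropConnect (weight masks)."""
    rng = rng or random.Random(0)
    n_visible, n_hidden = len(v), len(b)
    probs = []
    for j in range(n_hidden):
        # DropConnect: each weight W[i][j] is kept with probability 1 - dropconnect.
        activation = b[j] + sum(
            v[i] * W[i][j]
            for i in range(n_visible)
            if rng.random() >= dropconnect
        )
        p = sigmoid(activation)
        # Dropout: the whole hidden unit j is silenced with probability `dropout`.
        if rng.random() < dropout:
            p = 0.0
        probs.append(p)
    return probs


v = [1.0, 0.0, 1.0]
W = [[0.5, -1.0], [2.0, 0.3], [-0.2, 0.8]]
b = [0.1, -0.1]
print(hidden_probs(v, W, b))                   # no regularization
print(hidden_probs(v, W, b, dropout=0.5))      # some hidden units zeroed
print(hidden_probs(v, W, b, dropconnect=0.5))  # some weights ignored
```

With `dropout=1.0` every hidden probability becomes 0. With `dropconnect=1.0` every weight is dropped, so each unit falls back to `sigmoid(b[j])`.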
Quick start
```python
import torchvision

from learnergy.models.bernoulli import RBM

# Load MNIST
train = torchvision.datasets.MNIST(
    root="./data", train=True, download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Train a Bernoulli RBM
model = RBM(n_visible=784, n_hidden=128, steps=1, learning_rate=0.1)
mse, pl = model.fit(train, batch_size=128, epochs=5)

# Reconstruct
rec_mse, visible_probs = model.reconstruct(train)
```
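In energy-based training, a `steps` argument like the one above usually sets the number of Gibbs sampling steps in contrastive divergence, so `steps=1` most likely corresponds to CD-1. The following stdlib-only sketch performs one CD-1 update for a Bernoulli RBM on a single binary sample. It assumes the common recipe of using hidden probabilities, rather than samples, in the weight statistics. It is a teaching illustration, not Learnergy's training loop. It also returns the squared reconstruction error, which is similar in spirit to the MSE reported by `fit` and `reconstruct`.

```python
import math
import random
from typing import List


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def cd1_update(
    v0: List[float],
    W: List[List[float]],
    a: List[float],
    b: List[float],
    lr: float,
    rng: random.Random,
) -> float:
    """One CD-1 step on a single binary sample; updates W, a, b in place.

    Returns the squared reconstruction error between v0 and p(v | h0).
    """
    n_v, n_h = len(a), len(b)
    # Positive phase: hidden probabilities and a binary sample given the data.
    ph0 = [sigmoid(b[j] + sum(v0[i] * W[i][j] for i in range(n_v))) for j in range(n_h)]
    h0 = [1.0 if rng.random() < p else 0.0 for p in ph0]
    # Negative phase: one Gibbs step down to the visible layer and back up.
    pv1 = [sigmoid(a[i] + sum(W[i][j] * h0[j] for j in range(n_h))) for i in range(n_v)]
    v1 = [1.0 if rng.random() < p else 0.0 for p in pv1]
    ph1 = [sigmoid(b[j] + sum(v1[i] * W[i][j] for i in range(n_v))) for j in range(n_h)]
    # Gradient estimate: <v h>_data - <v h>_model.
    for i in range(n_v):
        for j in range(n_h):
            W[i][j] += lr * (v0[i] * ph0[j] - v1[i] * ph1[j])
    for i in range(n_v):
        a[i] += lr * (v0[i] - v1[i])
    for j in range(n_h):
        b[j] += lr * (ph0[j] - ph1[j])
    return sum((x - p) ** 2 for x, p in zip(v0, pv1))


rng = random.Random(42)
n_visible, n_hidden = 4, 2
W = [[rng.gauss(0.0, 0.01) for _ in range(n_hidden)] for _ in range(n_visible)]
a = [0.0] * n_visible
b = [0.0] * n_hidden
pattern = [1.0, 1.0, 0.0, 0.0]
errors = [cd1_update(pattern, W, a, b, lr=0.1, rng=rng) for _ in range(200)]
print(errors[0], errors[-1])
```

When the model is repeatedly shown the same toy pattern, the reconstruction error should shrink from about 1.0 as the biases and weights adapt. A larger `steps` value would run more Gibbs steps in the negative phase before the update.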
For a Gaussian RBM with continuous inputs:
```python
from learnergy.models.gaussian import GaussianRBM

# `train` is the MNIST dataset loaded in the quick start above
model = GaussianRBM(n_visible=784, n_hidden=256, steps=1, learning_rate=0.005)
mse, pl = model.fit(train, batch_size=128, epochs=10)
```
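In a Gaussian RBM, the visible units are real-valued. This sketch assumes the common unit-variance formulation, in which the energy is E(v, h) = sum_i (v_i - a_i)^2 / 2 - b^T h - v^T W h. Learnergy also ships VarianceGaussianRBM, which learns per-unit variances, so this is only one variant. Under that assumption, the reconstruction given h is the linear mean a + W h, optionally with N(0, 1) noise added, instead of a sigmoid. The helper name below is hypothetical.

```python
import random
from typing import List, Optional, Sequence


def gaussian_visible(
    h: Sequence[float],
    W: Sequence[Sequence[float]],
    a: Sequence[float],
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Mean (or a sample) of v | h for unit-variance Gaussian visible units.

    Mean: a_i + sum_j W[i][j] * h_j. If `rng` is given, adds N(0, 1) noise.
    """
    means = [a[i] + sum(W[i][j] * h[j] for j in range(len(h))) for i in range(len(a))]
    if rng is None:
        return means
    return [m + rng.gauss(0.0, 1.0) for m in means]


W = [[0.5, -1.0], [2.0, 0.25]]
a = [0.1, -0.3]
print(gaussian_visible([1.0, 0.0], W, a))                        # mean reconstruction
print(gaussian_visible([1.0, 0.0], W, a, rng=random.Random(0)))  # noisy sample
```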
For a Deep Belief Network:
```python
from learnergy.models.deep import DBN

# Reuses the `train` dataset from the quick start; each tuple holds one value per RBM layer
model = DBN(
    model=("gaussian", "sigmoid"),
    n_visible=784, n_hidden=(256, 128),
    steps=(1, 1), learning_rate=(0.01, 0.01),
    momentum=(0, 0), decay=(0, 0), temperature=(1, 1),
)
mse, pl = model.fit(train, batch_size=128, epochs=(5, 5))
```
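In the DBN call above, each per-layer argument is a tuple with one entry per stacked RBM. Layer k takes the hidden size of layer k-1 as its visible size. DBNs are classically pretrained greedily: each RBM is fit on the hidden representations produced by the layer below it, which is presumably why `epochs` is also given per layer. The sketch below shows that bookkeeping. The `plan_layers` helper and the caller-supplied `fit_layer` callback are hypothetical and are not part of Learnergy's API.

```python
from typing import Callable, Dict, List, Sequence

Layer = Dict[str, object]


def plan_layers(
    n_visible: int,
    n_hidden: Sequence[int],
    **per_layer: Sequence[object],
) -> List[Layer]:
    """Expand DBN-style tuples into one config dict per stacked RBM (hypothetical helper)."""
    for name, values in per_layer.items():
        if len(values) != len(n_hidden):
            raise ValueError(f"{name} has {len(values)} entries, expected {len(n_hidden)}")
    sizes = [n_visible, *n_hidden]
    return [
        {
            "n_visible": sizes[k],
            "n_hidden": sizes[k + 1],
            **{name: values[k] for name, values in per_layer.items()},
        }
        for k in range(len(n_hidden))
    ]


def greedy_pretrain(
    data: List[List[float]],
    layers: List[Layer],
    fit_layer: Callable[[Layer, List[List[float]]], List[List[float]]],
) -> List[List[float]]:
    """Fit layers bottom-up; each layer trains on the previous layer's outputs."""
    inputs = data
    for config in layers:
        inputs = fit_layer(config, inputs)  # returns hidden representations
    return inputs


layers = plan_layers(
    784, (256, 128),
    model=("gaussian", "sigmoid"), steps=(1, 1),
    learning_rate=(0.01, 0.01), epochs=(5, 5),
)
for config in layers:
    print(config)


def fake_fit(config: Layer, inputs: List[List[float]]) -> List[List[float]]:
    # Stand-in for training an RBM: truncates each vector to n_hidden features.
    return [row[: config["n_hidden"]] for row in inputs]


out = greedy_pretrain([[0.0] * 784], layers, fake_fit)
print(len(out[0]))  # 128
```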
Browse the examples/ directory for more use cases, including classification, convolutional models, and fine-tuning.
Learnergy is compatible with Python 3.9+ and PyTorch 1.8+.
Architecture
For a detailed walkthrough of the codebase design, class hierarchy, and design patterns, see ARCHITECTURE.md.
```text
learnergy/
├── core/        # Dataset and Model base classes
├── math/        # SSIM metrics, scaling utilities
├── models/
│   ├── bernoulli/   # RBM, ConvRBM, DiscriminativeRBM, Dropout/DropConnect, EDropout
│   ├── gaussian/    # GaussianRBM (+ ReLU, SeLU, Variance), GaussianConvRBM
│   ├── extra/       # SigmoidRBM
│   └── deep/        # DBN, ConvDBN, ResidualDBN
├── utils/       # Constants, custom exceptions, logging
└── visual/      # Convergence plots, image mosaics, tensor display
```
Available models
| Family | Models |
|---|---|
| Bernoulli | RBM, ConvRBM, DiscriminativeRBM, HybridDiscriminativeRBM, DropoutRBM, DropConnectRBM, EDropoutRBM |
| Gaussian | GaussianRBM, GaussianReluRBM, GaussianSeluRBM, VarianceGaussianRBM, GaussianConvRBM |
| Extra | SigmoidRBM |
| Deep | DBN, ConvDBN, ResidualDBN |
Installation
```bash
pip install learnergy
```
Or install from source for the latest version:
```bash
git clone https://github.com/gugarosa/learnergy.git
cd learnergy
pip install -e .
```
Dependencies
| Package | Version | Purpose |
|---|---|---|
| PyTorch | ≥ 1.8.0 | Core tensor operations and GPU support |
| torchvision | ≥ 0.9.0 | Dataset loading and transforms |
| matplotlib | ≥ 3.3.4 | Visualization |
| Pillow | ≥ 8.1.2 | Image mosaic creation |
| scikit-image | ≥ 0.17.2 | SSIM metric |
| tqdm | ≥ 4.49.0 | Progress bars |
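To check an existing environment against the minimum versions in the table, you can use the standard library's `importlib.metadata` to read installed versions. The sketch below deliberately uses a simple numeric comparison. It keeps only the leading numeric release segment (for example, the 2.1.0 in 2.1.0+cu118) and ignores pre-release tags. For exact PEP 440 semantics, the `packaging` library's `Version` class is the more robust choice. The package keys are assumed to be the PyPI distribution names (`torch` for PyTorch).

```python
import re
from importlib import metadata
from typing import Dict, Optional, Tuple

# Minimum versions copied from the Dependencies table (PyPI distribution names).
MINIMUMS: Dict[str, str] = {
    "torch": "1.8.0",
    "torchvision": "0.9.0",
    "matplotlib": "3.3.4",
    "Pillow": "8.1.2",
    "scikit-image": "0.17.2",
    "tqdm": "4.49.0",
}


def release_tuple(version: str) -> Tuple[int, ...]:
    """Leading numeric release segment, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    got, need = release_tuple(installed), release_tuple(minimum)
    width = max(len(got), len(need))
    # Pad with zeros so that "1.8" compares equal to "1.8.0".
    return got + (0,) * (width - len(got)) >= need + (0,) * (width - len(need))


def check_environment() -> Dict[str, Optional[bool]]:
    """True/False per package, or None when the package is not installed."""
    report: Dict[str, Optional[bool]] = {}
    for name, minimum in MINIMUMS.items():
        try:
            report[name] = meets_minimum(metadata.version(name), minimum)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


print(check_environment())
```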
Citation
If you use Learnergy in your work, please cite:
```bibtex
@misc{roder2020learnergy,
  title={Learnergy: Energy-based Machine Learners},
  author={Mateus Roder and Gustavo Henrique de Rosa and João Paulo Papa},
  year={2020},
  eprint={2003.07443},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```
Support
If you need to report a bug or have questions, please open an issue or reach out at mateus.roder@unesp.br and gustavo.rosa@unesp.br.
File details
Details for the file learnergy-1.2.0.tar.gz.
File metadata
- Download URL: learnergy-1.2.0.tar.gz
- Size: 36.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 36315efd6cc7bad657ea31ad6f547a3b14f5ca00a265a3a9ca05645095e6b602 |
| MD5 | 5b9cf0b817f4008ef7ae8b0a3be8900c |
| BLAKE2b-256 | aa421570ec18f599fe0d00a544105368192f2179d03586d7f5c0bc9f7e615ed1 |
File details
Details for the file learnergy-1.2.0-py3-none-any.whl.
File metadata
- Download URL: learnergy-1.2.0-py3-none-any.whl
- Size: 48.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 050c85a4c19540f84edd5a0318dee8438ff9a42ac3bcbfec8d14cac00ebdc773 |
| MD5 | 7d8d99106c49fbddb79cea85ae96e9a9 |
| BLAKE2b-256 | 341cab86873addaaecc2ca67879dddc436addf341dfb2ab78e349e7f00ccd8aa |