Energy-based Machine Learners

Learnergy: Energy-based Machine Learners

Welcome to Learnergy

Learnergy is a PyTorch-based framework for energy-based machine learning, providing ready-to-use implementations of Restricted Boltzmann Machines (RBMs) and Deep Belief Networks (DBNs). It is designed for researchers and practitioners who need a clean, modular library for unsupervised feature learning, generative modeling, and classification with energy-based models.
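
All of these models are defined through an energy function. For a Bernoulli RBM with visible vector v, hidden vector h, weights W, visible biases a and hidden biases b, the textbook form is E(v, h) = -a.v - b.h - v^T W h. Summing out h analytically gives the free energy F(v) = -a.v - sum_j log(1 + exp(b_j + sum_i v_i W_ij)). The toy sketch below uses only the Python standard library to illustrate those two formulas; it is not Learnergy's code, and every name and parameter value in it is hypothetical. It checks the closed-form free energy against brute-force marginalization over all hidden states.

```python
# Illustrative sketch only: not Learnergy's implementation; names and values are hypothetical.
import itertools
import math

# Toy Bernoulli RBM with 3 visible and 2 hidden units.
# W[i][j] connects visible unit i to hidden unit j; a = visible biases, b = hidden biases.
W = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.2]]
a = [0.1, -0.1, 0.0]
b = [0.2, -0.3]


def energy(v, h):
    """E(v, h) = -a.v - b.h - v^T W h for binary vectors v and h."""
    av = sum(ai * vi for ai, vi in zip(a, v))
    bh = sum(bj * hj for bj, hj in zip(b, h))
    vwh = sum(v[i] * W[i][j] * h[j] for i in range(len(v)) for j in range(len(h)))
    return -av - bh - vwh


def free_energy(v):
    """F(v) = -a.v - sum_j log(1 + exp(b_j + sum_i v_i W_ij)): E with h summed out analytically."""
    av = sum(ai * vi for ai, vi in zip(a, v))
    hidden_terms = 0.0
    for j in range(len(b)):
        x = b[j] + sum(v[i] * W[i][j] for i in range(len(v)))
        hidden_terms += math.log1p(math.exp(x))
    return -av - hidden_terms


def free_energy_brute_force(v):
    """Same quantity by enumerating every hidden configuration: -log sum_h exp(-E(v, h))."""
    hidden_configs = itertools.product([0, 1], repeat=len(b))
    return -math.log(sum(math.exp(-energy(v, h)) for h in hidden_configs))


v = [1, 0, 1]
print(free_energy(v), free_energy_brute_force(v))  # the two values agree up to rounding
```

Brute-force enumeration needs 2^n_hidden terms, so the closed form is what makes layers such as n_hidden=128 tractable.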

What you can do

  • Train RBMs with various unit types: Bernoulli, Gaussian, Sigmoid, ReLU, SeLU
  • Apply regularization: Dropout, DropConnect, and Energy-based Dropout
  • Build deep architectures: stack RBMs into DBNs and Convolutional DBNs
  • Use residual learning: ResidualDBN with skip connections for improved information flow
  • Classify: Discriminative and Hybrid Discriminative RBMs for supervised tasks
  • Visualize: convergence plots, weight mosaics, and tensor images
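
RBM training is commonly done with contrastive divergence (CD-k): the log-likelihood gradient is approximated by the difference between statistics driven by the data and statistics taken after k alternating Gibbs updates. We assume this is what the steps argument in the Quick start examples controls (the number of sampling steps per update). The pure-Python sketch below shows one CD-k update for a single binary sample; cd_k_step and its helpers are hypothetical names, and using visible probabilities instead of binary samples in the negative phase is one common choice, not necessarily the library's.

```python
# Illustrative sketch only: not Learnergy's implementation; all names are hypothetical.
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def hidden_probs(v, W, b):
    # p(h_j = 1 | v) = sigmoid(b_j + sum_i v_i W_ij)
    return [sigmoid(b[j] + sum(v[i] * W[i][j] for i in range(len(v)))) for j in range(len(b))]


def visible_probs(h, W, a):
    # p(v_i = 1 | h) = sigmoid(a_i + sum_j W_ij h_j)
    return [sigmoid(a[i] + sum(W[i][j] * h[j] for j in range(len(h)))) for i in range(len(a))]


def sample(probs, rng):
    # Draw a binary vector with P(unit k = 1) = probs[k]
    return [1.0 if rng.random() < p else 0.0 for p in probs]


def cd_k_step(v0, W, a, b, k=1, lr=0.1, rng=random):
    """Apply one CD-k update for sample v0 (in place) and return its reconstruction MSE."""
    if k < 1:
        raise ValueError("k must be at least 1")
    ph0 = hidden_probs(v0, W, b)  # positive phase, driven by the data
    h = sample(ph0, rng)
    for _ in range(k):  # negative phase: k alternating updates (visible probs, sampled hiddens)
        vk = visible_probs(h, W, a)
        phk = hidden_probs(vk, W, b)
        h = sample(phk, rng)
    # Approximate gradient: <v h>_data - <v h>_model
    for i in range(len(a)):
        for j in range(len(b)):
            W[i][j] += lr * (v0[i] * ph0[j] - vk[i] * phk[j])
        a[i] += lr * (v0[i] - vk[i])
    for j in range(len(b)):
        b[j] += lr * (ph0[j] - phk[j])
    return sum((x - y) ** 2 for x, y in zip(v0, vk)) / len(v0)


# Usage: a few epochs on two toy 4-pixel patterns, with a fixed seed
rng = random.Random(0)
W = [[rng.gauss(0.0, 0.01) for _ in range(2)] for _ in range(4)]
a, b = [0.0] * 4, [0.0] * 2
data = [[1, 1, 0, 0], [0, 0, 1, 1]]
for _ in range(50):
    errors = [cd_k_step(v, W, a, b, k=1, lr=0.1, rng=rng) for v in data]
print("mean reconstruction MSE:", sum(errors) / len(errors))
```

Raising k makes each update more expensive but brings the negative-phase statistics closer to the model distribution.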

Quick start

import torchvision
from learnergy.models.bernoulli import RBM

# Load MNIST
train = torchvision.datasets.MNIST(
    root="./data", train=True, download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Train a Bernoulli RBM
model = RBM(n_visible=784, n_hidden=128, steps=1, learning_rate=0.1)
mse, pl = model.fit(train, batch_size=128, epochs=5)

# Reconstruct
rec_mse, visible_probs = model.reconstruct(train)
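
fit returns two training metrics, mse and pl. We assume mse is the mean squared reconstruction error and pl a stochastic estimate of the log pseudo-likelihood, the usual way to monitor RBMs because the exact likelihood needs an intractable partition function. The standard estimator flips one random bit i of v to get v', uses p(v_i | v_-i) = exp(-F(v)) / (exp(-F(v)) + exp(-F(v'))) = sigmoid(F(v') - F(v)), and scales by the number of visible units. The pure-Python sketch below implements that estimator and the MSE with hypothetical names; it is not the library's code.

```python
# Illustrative sketch only: not Learnergy's implementation; all names are hypothetical.
import math
import random


def softplus(x):
    # Numerically stable log(1 + exp(x))
    return x + math.log1p(math.exp(-x)) if x > 0 else math.log1p(math.exp(x))


def free_energy(v, W, a, b):
    # F(v) = -a.v - sum_j softplus(b_j + sum_i v_i W_ij) for a Bernoulli RBM
    av = sum(ai * vi for ai, vi in zip(a, v))
    return -av - sum(softplus(b[j] + sum(v[i] * W[i][j] for i in range(len(v)))) for j in range(len(b)))


def pseudo_likelihood(batch, W, a, b, rng=random):
    """Mean stochastic log pseudo-likelihood over a batch (one random bit flipped per sample)."""
    n_visible = len(a)
    total = 0.0
    for v in batch:
        v = [round(x) for x in v]  # binarize, since the estimator assumes binary visible units
        flipped = list(v)
        i = rng.randrange(n_visible)
        flipped[i] = 1 - flipped[i]
        # log sigmoid(F(flipped) - F(v)) == -softplus(F(v) - F(flipped))
        total += n_visible * -softplus(free_energy(v, W, a, b) - free_energy(flipped, W, a, b))
    return total / len(batch)


def reconstruction_mse(batch, reconstructions):
    """Mean over samples of the per-sample mean squared error."""
    per_sample = [sum((x - y) ** 2 for x, y in zip(v, r)) / len(v) for v, r in zip(batch, reconstructions)]
    return sum(per_sample) / len(per_sample)


# Usage with toy parameters and made-up reconstructions
batch = [[1, 0, 1], [0, 1, 1]]
W = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.2]]
a, b = [0.1, -0.1, 0.0], [0.2, -0.3]
print(pseudo_likelihood(batch, W, a, b, random.Random(0)))
print(reconstruction_mse(batch, [[0.9, 0.2, 0.8], [0.1, 0.7, 0.9]]))
```

Because the log pseudo-likelihood is always negative, values closer to zero indicate a better fit.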

For a Gaussian RBM with continuous inputs (reusing the train dataset loaded above):

from learnergy.models.gaussian import GaussianRBM

model = GaussianRBM(n_visible=784, n_hidden=256, steps=1, learning_rate=0.005)
mse, pl = model.fit(train, batch_size=128, epochs=10)
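
In a Gaussian-Bernoulli RBM the visible units are real-valued. With unit variance (sigma = 1, a common simplification when every input feature is standardized to zero mean and unit variance) the energy becomes E(v, h) = sum_i (v_i - a_i)^2 / 2 - b.h - v^T W h. Hidden units therefore keep the sigmoid form, while v given h is Gaussian with mean a_i + sum_j W_ij h_j. The pure-Python sketch below illustrates those conditionals; whether GaussianRBM fixes sigma = 1 or expects standardized inputs is an assumption here, and the name VarianceGaussianRBM suggests a variant that learns sigma instead. All function names are hypothetical.

```python
# Illustrative sketch only: assumes unit-variance Gaussian visible units; names are hypothetical.
import math
import random
import statistics


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def standardize(values):
    """Z-score one feature so that unit-variance visible units are a reasonable model."""
    mu = statistics.fmean(values)
    sd = statistics.pstdev(values) or 1.0  # guard against constant features
    return [(x - mu) / sd for x in values]


def hidden_probs(v, W, b):
    # Binary hidden units, same form as in a Bernoulli RBM: sigmoid(b_j + sum_i v_i W_ij)
    return [sigmoid(b[j] + sum(v[i] * W[i][j] for i in range(len(v)))) for j in range(len(b))]


def visible_mean(h, W, a):
    # Gaussian visible units: the mean a_i + sum_j W_ij h_j is linear, with no sigmoid squashing
    return [a[i] + sum(W[i][j] * h[j] for j in range(len(h))) for i in range(len(a))]


def sample_visible(h, W, a, rng=random):
    # v_i ~ N(mean_i, 1)
    return [rng.gauss(m, 1.0) for m in visible_mean(h, W, a)]


# Usage with toy values
W = [[0.2, -0.1], [0.4, 0.3]]
a, b = [0.0, 0.5], [0.1, -0.1]
v = [0.5, -1.2]  # assumed to be already standardized
h = hidden_probs(v, W, b)
print(visible_mean(h, W, a))
print(sample_visible(h, W, a, random.Random(0)))
print(standardize([3.0, 5.0, 7.0]))
```

The example above pairs GaussianRBM with a much smaller learning_rate (0.005) than the Bernoulli RBM (0.1), which matches the common advice that unbounded Gaussian visible units need smaller steps to train stably.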

For a Deep Belief Network:

from learnergy.models.deep import DBN

model = DBN(
    model=("gaussian", "sigmoid"),
    n_visible=784, n_hidden=(256, 128),
    steps=(1, 1), learning_rate=(0.01, 0.01),
    momentum=(0, 0), decay=(0, 0), temperature=(1, 1),
)
mse, pl = model.fit(train, batch_size=128, epochs=(5, 5))
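
A DBN is usually pre-trained greedily, one layer at a time: the first RBM is trained on the data, its hidden activations become the training data for the second RBM, and so on up the stack. We assume the per-layer tuples in the example above (n_hidden, steps, learning_rate, epochs) give one setting per RBM in that stack. The pure-Python sketch below mirrors the structure with hypothetical helpers; for brevity every layer is a binary RBM trained with a deterministic mean-field CD-1 variant (probabilities instead of samples), a simplification rather than the library's procedure, and it does not reproduce the Gaussian first layer used above.

```python
# Illustrative sketch only: greedy layer-wise pre-training with hypothetical helpers.
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def up(v, W, b):
    # Hidden probabilities p(h_j = 1 | v) of one layer
    return [sigmoid(b[j] + sum(v[i] * W[i][j] for i in range(len(v)))) for j in range(len(b))]


def down(h, W, a):
    # Visible probabilities p(v_i = 1 | h) of one layer
    return [sigmoid(a[i] + sum(W[i][j] * h[j] for j in range(len(h)))) for i in range(len(a))]


def train_rbm(data, n_hidden, epochs, lr, rng):
    """Train one binary RBM with mean-field CD-1 (probabilities instead of samples)."""
    n_visible = len(data[0])
    W = [[rng.gauss(0.0, 0.01) for _ in range(n_hidden)] for _ in range(n_visible)]
    a, b = [0.0] * n_visible, [0.0] * n_hidden
    for _ in range(epochs):
        for v0 in data:
            h0 = up(v0, W, b)
            v1 = down(h0, W, a)
            h1 = up(v1, W, b)
            for i in range(n_visible):
                for j in range(n_hidden):
                    W[i][j] += lr * (v0[i] * h0[j] - v1[i] * h1[j])
                a[i] += lr * (v0[i] - v1[i])
            for j in range(n_hidden):
                b[j] += lr * (h0[j] - h1[j])
    return W, b  # the visible bias is not needed for the upward pass


def train_dbn(data, n_hidden=(4, 2), epochs=(20, 20), lr=(0.1, 0.1), seed=0):
    """Greedy layer-wise pre-training: each RBM learns from the previous layer's output."""
    rng = random.Random(seed)
    layers, x = [], data
    for units, n_epochs, rate in zip(n_hidden, epochs, lr):
        W, b = train_rbm(x, units, n_epochs, rate, rng)
        layers.append((W, b))
        x = [up(v, W, b) for v in x]  # hidden probabilities feed the next RBM
    return layers, x


# Usage: 6-pixel toy data, two stacked layers (6 -> 4 -> 2)
data = [[1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 1], [1, 0, 1, 0, 1, 0]]
layers, features = train_dbn(data, n_hidden=(4, 2), epochs=(30, 30), lr=(0.1, 0.1))
print(features)
```

Each layer only ever sees the output of the layer below it, which is why one hyperparameter tuple entry per layer is a natural interface.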

Browse the examples/ directory for more use cases, including classification, convolutional models, and fine-tuning.

Learnergy is compatible with Python 3.9+ and PyTorch 1.8+.


Architecture

For a detailed walkthrough of the codebase design, class hierarchy, and design patterns, see ARCHITECTURE.md.

learnergy/
├── core/          # Dataset and Model base classes
├── math/          # SSIM metrics, scaling utilities
├── models/
│   ├── bernoulli/ # RBM, ConvRBM, DiscriminativeRBM, Dropout/DropConnect, EDropout
│   ├── gaussian/  # GaussianRBM (+ ReLU, SeLU, Variance), GaussianConvRBM
│   ├── extra/     # SigmoidRBM
│   └── deep/      # DBN, ConvDBN, ResidualDBN
├── utils/         # Constants, custom exceptions, logging
└── visual/        # Convergence plots, image mosaics, tensor display

Available models

Family      Models
Bernoulli   RBM, ConvRBM, DiscriminativeRBM, HybridDiscriminativeRBM, DropoutRBM, DropConnectRBM, EDropoutRBM
Gaussian    GaussianRBM, GaussianReluRBM, GaussianSeluRBM, VarianceGaussianRBM, GaussianConvRBM
Extra       SigmoidRBM
Deep        DBN, ConvDBN, ResidualDBN
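
DropoutRBM and DropConnectRBM differ in what they mask: dropout zeroes whole hidden units, while DropConnect zeroes individual weights W_ij. The pure-Python sketch below contrasts the two masks when computing hidden activations; the function names and the drop probability p are hypothetical, and details such as rescaling, where in training the library applies the masks, and how EDropoutRBM derives its energy-based mask are not covered.

```python
# Illustrative sketch only: contrasts the two masking schemes; names are hypothetical.
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def hidden_probs_dropout(v, W, b, p, rng):
    """Dropout: each hidden unit is zeroed with probability p."""
    probs = [sigmoid(b[j] + sum(v[i] * W[i][j] for i in range(len(v)))) for j in range(len(b))]
    keep = [0.0 if rng.random() < p else 1.0 for _ in b]  # one mask entry per hidden unit
    return [pj * kj for pj, kj in zip(probs, keep)]


def hidden_probs_dropconnect(v, W, b, p, rng):
    """DropConnect: each weight W_ij is zeroed with probability p."""
    keep = [[0.0 if rng.random() < p else 1.0 for _ in row] for row in W]  # one entry per weight
    return [sigmoid(b[j] + sum(v[i] * W[i][j] * keep[i][j] for i in range(len(v)))) for j in range(len(b))]


# Usage: same input and weights, different masking
W = [[0.5, -0.2, 0.1], [0.3, 0.4, -0.6]]
b = [0.0, 0.1, -0.1]
v = [1.0, 1.0]
rng = random.Random(0)
print(hidden_probs_dropout(v, W, b, p=0.5, rng=rng))
print(hidden_probs_dropconnect(v, W, b, p=0.5, rng=rng))
```

A dropped unit outputs exactly zero under dropout, whereas under DropConnect a unit still receives its bias and any surviving connections, so its activation falls back toward sigmoid(b_j) rather than to zero.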

Installation

pip install learnergy

Or install from source for the latest version:

git clone https://github.com/gugarosa/learnergy.git
cd learnergy
pip install -e .

Dependencies

Package        Version    Purpose
PyTorch        ≥ 1.8.0    Core tensor operations and GPU support
torchvision    ≥ 0.9.0    Dataset loading and transforms
matplotlib     ≥ 3.3.4    Visualization
Pillow         ≥ 8.1.2    Image mosaic creation
scikit-image   ≥ 0.17.2   SSIM metric
tqdm           ≥ 4.49.0   Progress bars
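
SSIM compares two images through their means, variances and covariance: SSIM(x, y) = ((2 mu_x mu_y + C1)(2 sigma_xy + C2)) / ((mu_x^2 + mu_y^2 + C1)(sigma_x^2 + sigma_y^2 + C2)), with C1 = (0.01 L)^2 and C2 = (0.03 L)^2 for a data range L. The pure-Python sketch below evaluates that formula over one global window, for intuition only; scikit-image's structural_similarity averages it over local sliding windows, so its values will generally differ, and global_ssim is a hypothetical name.

```python
# Illustrative sketch only: single-window SSIM, unlike scikit-image's sliding-window version.
import statistics


def global_ssim(x, y, data_range=1.0, k1=0.01, k2=0.03):
    """SSIM of two equally sized flat pixel lists, computed over one global window."""
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    mx, my = statistics.fmean(x), statistics.fmean(y)
    vx, vy = statistics.pvariance(x, mx), statistics.pvariance(y, my)
    cov = sum((xi - mx) * (yi - my) for xi, yi in zip(x, y)) / len(x)
    return ((2 * mx * my + c1) * (2 * cov + c2)) / ((mx ** 2 + my ** 2 + c1) * (vx + vy + c2))


# Usage: an image compared with itself, with a noisy copy and with its inverse
image = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
noisy = [0.05, 0.15, 0.45, 0.55, 0.85, 0.95]
inverse = [1.0 - p for p in image]
print(global_ssim(image, image), global_ssim(image, noisy), global_ssim(image, inverse))
```

Identical images score 1, small perturbations score slightly below 1, and anti-correlated images can score below 0, which makes SSIM a structure-aware complement to the reconstruction MSE.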

Citation

If you use Learnergy in your research or projects, please cite us:

@misc{roder2020learnergy,
    title={Learnergy: Energy-based Machine Learners},
    author={Mateus Roder and Gustavo Henrique de Rosa and João Paulo Papa},
    year={2020},
    eprint={2003.07443},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Support

If you need to report a bug or have questions, please open an issue or contact mateus.roder@unesp.br or gustavo.rosa@unesp.br.


Download files

Source Distribution

learnergy-1.2.0.tar.gz (36.2 kB)

Uploaded Source

Built Distribution

learnergy-1.2.0-py3-none-any.whl (48.5 kB)

Uploaded Python 3

File details

Details for the file learnergy-1.2.0.tar.gz.

File metadata

  • Download URL: learnergy-1.2.0.tar.gz
  • Upload date:
  • Size: 36.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for learnergy-1.2.0.tar.gz
Algorithm Hash digest
SHA256 36315efd6cc7bad657ea31ad6f547a3b14f5ca00a265a3a9ca05645095e6b602
MD5 5b9cf0b817f4008ef7ae8b0a3be8900c
BLAKE2b-256 aa421570ec18f599fe0d00a544105368192f2179d03586d7f5c0bc9f7e615ed1

File details

Details for the file learnergy-1.2.0-py3-none-any.whl.

File metadata

  • Download URL: learnergy-1.2.0-py3-none-any.whl
  • Upload date:
  • Size: 48.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for learnergy-1.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 050c85a4c19540f84edd5a0318dee8438ff9a42ac3bcbfec8d14cac00ebdc773
MD5 7d8d99106c49fbddb79cea85ae96e9a9
BLAKE2b-256 341cab86873addaaecc2ca67879dddc436addf341dfb2ab78e349e7f00ccd8aa
