
DEEM - Deep Ensemble Energy Models


DEEM is a Python library for training Restricted Boltzmann Machines (RBMs) on ensemble predictions from multiple classifiers. It provides a scikit-learn compatible API for unsupervised ensemble aggregation, crowd learning, and model combination.

Features

  • 🚀 Simple 3-line API - Fit and predict in just a few lines of code
  • 🔬 Unsupervised Learning - No labels required for training (though they can be used for evaluation)
  • 🧮 Energy-Based Models - Uses RBMs to learn the joint distribution of classifier predictions
  • 🎯 Hungarian Alignment - Automatic label permutation handling via Hungarian algorithm
  • GPU Acceleration - Full PyTorch backend with CUDA support
  • 🔧 Scikit-learn Compatible - Standard .fit(), .predict(), .score() interface
  • 📊 Automatic Hyperparameters - Optional meta-learning for hyperparameter selection

Installation

pip install deem

From Source

git clone https://github.com/Rem4rkable/rbm_python.git
cd rbm_python
pip install -e .

Quick Start

import numpy as np
from deem import DEEM

# Ensemble predictions from 15 classifiers on 100 samples with 3 classes
predictions = np.random.randint(0, 3, (100, 15))

# Train and predict in 3 lines!
model = DEEM()
model.fit(predictions)
consensus = model.predict(predictions)

With Evaluation

# If you have true labels, evaluate with automatic label alignment
model = DEEM(n_classes=3, epochs=50)
model.fit(train_predictions)

# Automatically handles label permutation problem
accuracy = model.score(test_predictions, test_labels)
print(f"Consensus accuracy: {accuracy:.2%}")

Custom Configuration

model = DEEM(
    n_classes=5,
    hidden_dim=2,           # Number of hidden units
    learning_rate=0.01,
    epochs=100,
    batch_size=64,
    cd_k=10,                # Contrastive divergence steps
    deterministic=True,     # Use probabilities (more stable)
    device='cuda'           # Use GPU
)
model.fit(predictions)

Use Cases

1. Crowd Learning / Multi-Annotator Aggregation

Aggregate noisy labels from multiple human annotators:

# annotator_labels: (n_samples, n_annotators) with values 0 to k-1
model = DEEM(n_classes=k)
model.fit(annotator_labels)
consensus_labels = model.predict(annotator_labels)

2. Ensemble Model Aggregation

Combine predictions from multiple trained classifiers:

# Get predictions from multiple models
predictions = np.column_stack([
    model1.predict(X),
    model2.predict(X),
    model3.predict(X),
    # ... more models
])

# Learn optimal aggregation
ensemble = DEEM()
ensemble.fit(predictions)
final_predictions = ensemble.predict(predictions)
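To make the stacking step above concrete, here is a sketch that builds the (n_samples, n_classifiers) prediction matrix from three real base models on synthetic data. It uses scikit-learn only for the base classifiers (an optional dependency of DEEM itself); the model choices and dataset are illustrative, and the resulting matrix is what you would pass to DEEM().fit().

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

# Synthetic 3-class problem
X, y = make_classification(n_samples=200, n_classes=3, n_informative=5,
                           random_state=0)

# Three arbitrary base classifiers standing in for model1, model2, model3
models = [LogisticRegression(max_iter=1000),
          DecisionTreeClassifier(random_state=0),
          KNeighborsClassifier()]
for m in models:
    m.fit(X, y)

# Stack hard predictions column-wise: one column per classifier
predictions = np.column_stack([m.predict(X) for m in models])
print(predictions.shape)  # (200, 3) -- ready for DEEM().fit(predictions)
```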

3. Missing Predictions

DEEM automatically handles cases where some classifiers don't provide predictions (use -1 for missing):

predictions = np.array([
    [0, 1, -1, 2, 1],  # Classifier 3 missing
    [1, 1, 1, -1, 1],  # Classifier 4 missing
    # ...
])
model = DEEM(n_classes=3)
model.fit(predictions)  # Missing values handled automatically
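For intuition about what "handling missing values" means at the data level, here is a baseline sketch that masks out the -1 entries and takes a per-sample majority over the remaining votes. This is only an illustration of the masking idea, not DEEM's actual RBM treatment of missing entries.

```python
import numpy as np

predictions = np.array([
    [0, 1, -1, 2, 1],  # classifier 3 missing
    [1, 1, 1, -1, 1],  # classifier 4 missing
])
n_classes = 3

consensus = []
for row in predictions:
    valid = row[row >= 0]                             # drop missing votes
    counts = np.bincount(valid, minlength=n_classes)  # tally remaining votes
    consensus.append(int(counts.argmax()))            # majority class
print(consensus)  # [1, 1]
```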

How It Works

DEEM uses Restricted Boltzmann Machines (RBMs) - energy-based models that learn the joint probability distribution over classifier predictions and hidden representations. The key insight is that multiple weak classifiers contain complementary information that can be combined through unsupervised learning.
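The "complementary information" point can be seen in a toy example: three classifiers that are each wrong on a different sample, so even a plain majority vote already recovers the truth. (An RBM can additionally learn per-classifier reliability; this sketch only shows why combining helps at all. All values are made up for illustration.)

```python
import numpy as np

true = np.array([0, 1, 2, 0, 1, 2])
preds = np.array([
    [1, 1, 2, 0, 1, 2],   # classifier 1: wrong only on sample 0
    [0, 2, 2, 0, 1, 2],   # classifier 2: wrong only on sample 1
    [0, 1, 0, 0, 1, 2],   # classifier 3: wrong only on sample 2
]).T  # shape (n_samples, n_classifiers)

# Each classifier alone: 5/6 correct
per_clf_acc = (preds == true[:, None]).mean(axis=0)

# Majority vote per sample: errors never overlap, so the vote is perfect
majority = np.array([np.bincount(row).argmax() for row in preds])
print(per_clf_acc)                 # [0.833..., 0.833..., 0.833...]
print((majority == true).mean())   # 1.0
```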

Key Components

  1. Energy Function: Models compatibility between visible (predictions) and hidden (consensus) states
  2. Contrastive Divergence: Trains the RBM via contrastive divergence with DLP/GWG (Discrete Langevin Proposal / Gibbs-With-Gradients) sampling
  3. Hungarian Algorithm: Solves the label permutation problem during evaluation
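To see what the label-permutation problem is and how the Hungarian algorithm solves it, here is a conceptual sketch using SciPy's linear_sum_assignment (SciPy is a listed dependency; the exact procedure inside DEEM may differ). The consensus labels are correct up to renaming, and the assignment finds the renaming that maximizes agreement with the ground truth.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

pred = np.array([2, 2, 0, 0, 1, 1])   # consensus labels, arbitrary numbering
true = np.array([0, 0, 1, 1, 2, 2])   # ground truth

n_classes = 3
# confusion[i, j] = how often predicted class i co-occurs with true class j
confusion = np.zeros((n_classes, n_classes), dtype=int)
for p, t in zip(pred, true):
    confusion[p, t] += 1

# Hungarian algorithm on negated counts -> max-agreement class mapping
row, col = linear_sum_assignment(-confusion)
mapping = dict(zip(row, col))
aligned = np.array([mapping[p] for p in pred])

accuracy = (aligned == true).mean()
print(accuracy)  # 1.0: the predictions were a pure relabeling of the truth
```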

Architecture

Classifier Predictions → RBM → Hidden Representation → Consensus Label
     (visible layer)           (hidden layer)
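A minimal sketch of the visible layer and the energy function from component 1 above, assuming the standard binary-RBM form E(v, h) = -vᵀWh - bᵀv - cᵀh with each classifier's vote one-hot encoded into the visible units. The weights here are random placeholders; all names are illustrative, not DEEM's internals.

```python
import numpy as np

rng = np.random.default_rng(0)
n_classifiers, n_classes, hidden_dim = 5, 3, 2

# One sample: each classifier votes for a class
votes = np.array([0, 1, 0, 0, 2])

# Visible layer: concatenated one-hot codes, shape (n_classifiers * n_classes,)
v = np.eye(n_classes)[votes].reshape(-1)

# Random parameters standing in for learned weights and biases
W = rng.normal(size=(n_classifiers * n_classes, hidden_dim))
b = rng.normal(size=n_classifiers * n_classes)
c = rng.normal(size=hidden_dim)

def energy(v, h):
    """E(v, h) = -v @ W @ h - b @ v - c @ h (standard RBM energy)."""
    return -(v @ W @ h) - b @ v - c @ h

h = np.array([1.0, 0.0])  # one hidden configuration
print(energy(v, h))       # lower energy = more compatible (v, h) pair
```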

API Reference

DEEM

Main class for ensemble aggregation.

Parameters:

  • n_classes (int, optional): Number of classes. Auto-detected if not specified.
  • hidden_dim (int, default=1): Number of hidden units.
  • cd_k (int, default=10): Contrastive divergence steps.
  • deterministic (bool, default=True): Use probabilities instead of sampling.
  • learning_rate (float, default=0.001): Learning rate for SGD.
  • momentum (float, default=0.9): SGD momentum.
  • epochs (int, default=100): Training epochs.
  • batch_size (int, default=128): Batch size.
  • device (str, default='auto'): Device ('cpu', 'cuda', or 'auto').
  • random_state (int, optional): Random seed.

Methods:

  • fit(predictions, labels=None, **kwargs): Train the model
  • predict(predictions, return_probs=False): Get consensus predictions
  • predict_with_hungarian(predictions, true_labels): Predict with label alignment
  • score(predictions, true_labels): Compute accuracy with Hungarian alignment
  • save(path): Save model to disk
  • load(path): Load model from disk
  • get_params(): Get parameters (sklearn compatibility)
  • set_params(**params): Set parameters (sklearn compatibility)

Advanced Features

Automatic Hyperparameter Selection

model = DEEM(
    auto_hyperparameters=True,
    model_dir='saved_hyp_models_v1'  # Path to trained predictor
)
model.fit(predictions)  # Hyperparameters automatically selected

Save and Load Models

# Save trained model
model.save('my_ensemble.pt')

# Load later
model = DEEM()
model.load('my_ensemble.pt')
predictions = model.predict(new_data)

Soft Labels (Probability Distributions)

DEEM can also work with soft predictions (probability distributions):

# soft_predictions: (n_samples, n_classes, n_classifiers)
model = DEEM(n_classes=3)
model.fit(soft_predictions)
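The following sketch shows how to build a tensor in the documented (n_samples, n_classes, n_classifiers) shape; hard votes are just the special case of one-hot probability columns. Array names are illustrative.

```python
import numpy as np

hard = np.array([[0, 1, 1],
                 [2, 2, 0]])       # (n_samples=2, n_classifiers=3) hard votes
n_classes = 3

# One-hot each vote: (n_samples, n_classifiers, n_classes) ...
one_hot = np.eye(n_classes)[hard]
# ... then move axes to the documented (n_samples, n_classes, n_classifiers)
soft_predictions = one_hot.transpose(0, 2, 1)
print(soft_predictions.shape)  # (2, 3, 3)
```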

Requirements

  • Python >= 3.8
  • PyTorch >= 1.9
  • NumPy >= 1.19
  • SciPy >= 1.7

Optional:

  • scikit-learn >= 0.24 (for automatic hyperparameter selection)

Citation

If you use DEEM in your research, please cite:

@software{deem2026,
  title={DEEM: Deep Ensemble Energy Models for Classifier Aggregation},
  author={[Your Name]},
  year={2026},
  url={https://github.com/Rem4rkable/rbm_python}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Acknowledgments

  • Based on research in energy-based models and crowd learning
  • Built with PyTorch and inspired by scikit-learn's API design


Made with ❤️ for the machine learning community
