DEEM - Deep Ensemble Energy Models
RBM-based ensemble aggregation for crowd learning and classifier combination
DEEM is a Python library for training Restricted Boltzmann Machines (RBMs) on ensemble predictions from multiple classifiers. It provides a scikit-learn compatible API for unsupervised ensemble aggregation, crowd learning, and model combination.
Features
- 🚀 Simple 3-line API - Fit and predict in just a few lines of code
- 🔬 Unsupervised Learning - No labels required for training (though they can be used for evaluation)
- 🧮 Energy-Based Models - Uses RBMs to learn the joint distribution of classifier predictions
- 🎯 Hungarian Alignment - Automatic label permutation handling via Hungarian algorithm
- ⚡ GPU Acceleration - Full PyTorch backend with CUDA support
- 🔧 Scikit-learn Compatible - Standard `.fit()`, `.predict()`, `.score()` interface
- 📊 Automatic Hyperparameters - Optional meta-learning for hyperparameter selection
Installation
```
pip install deem
```
From Source
```
git clone https://github.com/Rem4rkable/rbm_python.git
cd rbm_python
pip install -e .
```
Quick Start
```python
import numpy as np
from deem import DEEM

# Ensemble predictions from 15 classifiers on 100 samples with 3 classes
predictions = np.random.randint(0, 3, (100, 15))

# Train and predict in 3 lines!
model = DEEM()
model.fit(predictions)
consensus = model.predict(predictions)
```
With Evaluation
```python
# If you have true labels, evaluate with automatic label alignment
model = DEEM(n_classes=3, epochs=50)
model.fit(train_predictions)

# Automatically handles the label permutation problem
accuracy = model.score(test_predictions, test_labels)
print(f"Consensus accuracy: {accuracy:.2%}")
```
Custom Configuration
```python
model = DEEM(
    n_classes=5,
    hidden_dim=2,        # Number of hidden units
    learning_rate=0.01,
    epochs=100,
    batch_size=64,
    cd_k=10,             # Contrastive divergence steps
    deterministic=True,  # Use probabilities (more stable)
    device='cuda'        # Use GPU
)
model.fit(predictions)
```
Use Cases
1. Crowd Learning / Multi-Annotator Aggregation
Aggregate noisy labels from multiple human annotators:
```python
# annotator_labels: (n_samples, n_annotators) with values 0 to k-1
model = DEEM(n_classes=k)
model.fit(annotator_labels)
consensus_labels = model.predict(annotator_labels)
```
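As a sanity check, a simple per-sample majority vote over annotators gives a baseline to compare the learned consensus against. The sketch below uses plain NumPy and is not part of the DEEM API:

```python
import numpy as np

def majority_vote(annotator_labels, n_classes):
    """Per-sample majority vote over annotator columns (baseline aggregator)."""
    counts = np.apply_along_axis(
        lambda row: np.bincount(row, minlength=n_classes), 1, annotator_labels
    )
    return counts.argmax(axis=1)

# Three annotators labelling four samples with k = 2 classes
labels = np.array([
    [0, 0, 1],
    [1, 1, 1],
    [0, 1, 1],
    [0, 0, 0],
])
baseline = majority_vote(labels, n_classes=2)
print(baseline)  # [0 1 1 0]
```

If DEEM's consensus does not beat this baseline on held-out labels, the ensemble may carry too little complementary signal for the RBM to exploit.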
2. Ensemble Model Aggregation
Combine predictions from multiple trained classifiers:
# Get predictions from multiple models
```python
# Get predictions from multiple models
predictions = np.column_stack([
    model1.predict(X),
    model2.predict(X),
    model3.predict(X),
    # ... more models
])

# Learn optimal aggregation
ensemble = DEEM()
ensemble.fit(predictions)
final_predictions = ensemble.predict(predictions)
```
3. Missing Predictions
DEEM automatically handles cases where some classifiers don't provide predictions (use -1 for missing):
```python
predictions = np.array([
    [0, 1, -1, 2, 1],  # Classifier 3 missing
    [1, 1, 1, -1, 1],  # Classifier 4 missing
    # ...
])
model = DEEM(n_classes=3)
model.fit(predictions)  # Missing values handled automatically
```
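The effect of skipping missing entries can be illustrated with a plain-NumPy vote that ignores the `-1` sentinel; this is only an illustrative sketch, not DEEM's actual internal mechanism:

```python
import numpy as np

def vote_ignoring_missing(predictions, n_classes, missing=-1):
    """Majority vote per row, skipping entries equal to `missing`."""
    out = np.empty(len(predictions), dtype=int)
    for i, row in enumerate(predictions):
        valid = row[row != missing]  # drop missing classifiers for this sample
        out[i] = np.bincount(valid, minlength=n_classes).argmax()
    return out

predictions = np.array([
    [0, 1, -1, 2, 1],   # classifier 3 missing
    [1, 1, 1, -1, 1],   # classifier 4 missing
])
print(vote_ignoring_missing(predictions, n_classes=3))  # [1 1]
```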
How It Works
DEEM uses Restricted Boltzmann Machines (RBMs) - energy-based models that learn the joint probability distribution over classifier predictions and hidden representations. The key insight is that multiple weak classifiers contain complementary information that can be combined through unsupervised learning.
Key Components
- Energy Function: Models compatibility between visible (predictions) and hidden (consensus) states
- Contrastive Divergence: Trains the RBM using DLP/GWG sampling
- Hungarian Algorithm: Solves the label permutation problem during evaluation
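The label-permutation step can be sketched with SciPy's `linear_sum_assignment` (the README does not show DEEM's exact routine, so treat this as an illustrative sketch): build a confusion matrix between predicted and true labels, then pick the label permutation that maximises agreement.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def hungarian_accuracy(pred, true, n_classes):
    """Accuracy after optimally permuting predicted labels onto true labels."""
    # Confusion counts: rows = predicted label, cols = true label
    conf = np.zeros((n_classes, n_classes), dtype=int)
    for p, t in zip(pred, true):
        conf[p, t] += 1
    # Maximise total agreement = minimise negated counts
    rows, cols = linear_sum_assignment(-conf)
    mapping = dict(zip(rows, cols))
    aligned = np.array([mapping[p] for p in pred])
    return (aligned == np.asarray(true)).mean()

# Predictions are a consistent relabelling (0→1, 1→2, 2→0) of the truth,
# so after alignment the accuracy is perfect
pred = [1, 2, 0, 1, 2]
true = [0, 1, 2, 0, 1]
print(hungarian_accuracy(pred, true, n_classes=3))  # 1.0
```

This is why an unsupervised aggregator can be scored at all: the clusters it finds are meaningful even though their integer names are arbitrary.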
Architecture
```
Classifier Predictions → RBM → Hidden Representation → Consensus Label
    (visible layer)               (hidden layer)
```
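For intuition, a Bernoulli RBM assigns energy E(v, h) = −vᵀWh − bᵀv − cᵀh to a visible/hidden pair, and the hidden units activate with probability sigmoid(c + vᵀW). The toy NumPy sketch below illustrates these two formulas only; DEEM's actual implementation uses PyTorch and one-hot-encoded predictions:

```python
import numpy as np

rng = np.random.default_rng(0)
n_visible, n_hidden = 6, 2
W = rng.normal(0, 0.1, (n_visible, n_hidden))  # visible-hidden weights
b = np.zeros(n_visible)                        # visible bias
c = np.zeros(n_hidden)                         # hidden bias

def energy(v, h):
    """E(v, h) = -v.W.h - b.v - c.h for a Bernoulli RBM."""
    return -v @ W @ h - b @ v - c @ h

def hidden_probs(v):
    """P(h_j = 1 | v) = sigmoid(c_j + v.W[:, j])."""
    return 1.0 / (1.0 + np.exp(-(c + v @ W)))

v = np.array([1, 0, 1, 1, 0, 1], dtype=float)
print(hidden_probs(v))  # two activation probabilities in (0, 1)
```

Lower-energy configurations are more probable, so training pushes the weights toward making the ensemble's actual prediction patterns low-energy; the hidden activations then serve as the consensus representation.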
API Reference
DEEM
Main class for ensemble aggregation.
Parameters:
- `n_classes` (int, optional): Number of classes. Auto-detected if not specified.
- `hidden_dim` (int, default=1): Number of hidden units.
- `cd_k` (int, default=10): Contrastive divergence steps.
- `deterministic` (bool, default=True): Use probabilities instead of sampling.
- `learning_rate` (float, default=0.001): Learning rate for SGD.
- `momentum` (float, default=0.9): SGD momentum.
- `epochs` (int, default=100): Training epochs.
- `batch_size` (int, default=128): Batch size.
- `device` (str, default='auto'): Device ('cpu', 'cuda', or 'auto').
- `random_state` (int, optional): Random seed.
Methods:
- `fit(predictions, labels=None, **kwargs)`: Train the model
- `predict(predictions, return_probs=False)`: Get consensus predictions
- `predict_with_hungarian(predictions, true_labels)`: Predict with label alignment
- `score(predictions, true_labels)`: Compute accuracy with Hungarian alignment
- `save(path)`: Save model to disk
- `load(path)`: Load model from disk
- `get_params()`: Get parameters (sklearn compatibility)
- `set_params(**params)`: Set parameters (sklearn compatibility)
Advanced Features
Automatic Hyperparameter Selection
```python
model = DEEM(
    auto_hyperparameters=True,
    model_dir='saved_hyp_models_v1'  # Path to trained predictor
)
model.fit(predictions)  # Hyperparameters automatically selected
```
Save and Load Models
```python
# Save trained model
model.save('my_ensemble.pt')

# Load later
model = DEEM()
model.load('my_ensemble.pt')
predictions = model.predict(new_data)
```
Soft Labels (Probability Distributions)
DEEM can also work with soft predictions (probability distributions):
```python
# soft_predictions: (n_samples, n_classes, n_classifiers)
model = DEEM(n_classes=3)
model.fit(soft_predictions)
```
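Hard predictions can be lifted into that `(n_samples, n_classes, n_classifiers)` soft format with a one-hot encoding, for example (a plain-NumPy helper, not part of the DEEM API):

```python
import numpy as np

def to_soft(hard_predictions, n_classes):
    """One-hot encode (n_samples, n_classifiers) hard labels into
    (n_samples, n_classes, n_classifiers) probability tensors."""
    n_samples, n_classifiers = hard_predictions.shape
    soft = np.zeros((n_samples, n_classes, n_classifiers))
    rows = np.arange(n_samples)[:, None]       # broadcast over classifiers
    cols = np.arange(n_classifiers)[None, :]   # broadcast over samples
    soft[rows, hard_predictions, cols] = 1.0
    return soft

hard = np.array([[0, 2], [1, 1]])
soft = to_soft(hard, n_classes=3)
print(soft.shape)  # (2, 3, 2)
```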
Requirements
- Python >= 3.8
- PyTorch >= 1.9
- NumPy >= 1.19
- SciPy >= 1.7
Optional:
- scikit-learn >= 0.24 (for automatic hyperparameter selection)
Citation
If you use DEEM in your research, please cite:
```bibtex
@software{deem2026,
  title={DEEM: Deep Ensemble Energy Models for Classifier Aggregation},
  author={[Your Name]},
  year={2026},
  url={https://github.com/Rem4rkable/rbm_python}
}
```
License
This project is licensed under the MIT License - see the LICENSE file for details.
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
Acknowledgments
- Based on research in energy-based models and crowd learning
- Built with PyTorch and inspired by scikit-learn's API design
Links
- GitHub: https://github.com/Rem4rkable/rbm_python
- Documentation: [Coming soon]
- Issues: https://github.com/Rem4rkable/rbm_python/issues
Made with ❤️ for the machine learning community