PyEnsembleCNN: A dynamic CNN ensembling framework for PyTorch
A PyTorch Implementation of Self-Learning CNN Ensembles with Integrated Visualization
Key Features
- Flexible ensemble architecture supporting any CNN backbone (ResNet, DenseNet, VGG, etc.)
- Novel trainable weighted averaging system that automatically learns optimal ensemble proportions
- Built-in Class Activation Mapping (CAM) visualization support for model interpretability
- Two distinct ensemble approaches:
  - AverageEnsemble: learns a weight distribution across models
  - StackEnsemble: concatenates features for enhanced representation power
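The difference between the two combination strategies can be sketched with plain tensor operations (an illustrative sketch under simplified assumptions, not the library's code; feature tensors are assumed to be already pooled to shape `(batch, features)`):

```python
import torch

# Three extractors' pooled feature maps for a batch of 4 images, 8 features each
feats = [torch.randn(4, 8) for _ in range(3)]

# AverageEnsemble-style: learnable logits, softmax-normalized, weighted sum
proportions = torch.nn.Parameter(torch.randn(3))
weights = torch.softmax(proportions, dim=0)   # sums to 1
averaged = sum(w * f for w, f in zip(weights, feats))   # shape (4, 8)

# StackEnsemble-style: concatenate features along the feature dimension
stacked = torch.cat(feats, dim=1)                       # shape (4, 24)

print(averaged.shape, stacked.shape)
```

Averaging preserves the feature dimensionality (so the classifier stays small), while stacking multiplies it by the number of extractors in exchange for a richer representation.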
Technical Innovation
The framework introduces a unique approach to ensemble weighting by making the model weights themselves trainable parameters. Instead of using static weights based on individual model performance, this implementation allows the ensemble to dynamically learn the optimal contribution of each model during training.
Technical Details
Usage
```python
import torch.nn as nn
from torchvision.models import resnet50, densenet121, vgg16

# Example usage with common CNN architectures
extractors = [resnet50(weights='IMAGENET1K_V1'),
              densenet121(weights='IMAGENET1K_V1'),
              vgg16(weights='IMAGENET1K_V1')]  # pre-trained backbones

# Replace the classification head of each extractor with pooling or
# interpolation (downscaling vs. upscaling) to a common feature size
for extractor in extractors:
    replace_classifier(extractor, 2048, pooling='avg')  # supports 'avg' and 'max' pooling

# Simple MLP for classification
classifier = nn.Sequential(
    nn.Linear(2048, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes)
)

# Initialize ensemble with automatic weight learning and train
model = AverageEnsemble(extractors, classifier, CAM=True)
best_epoch, best_model = train(model, train_loader, val_loader, optimizer,
                               criterion, epochs, device='cuda', verbose=True)
```
Advanced Features
- Automatic Architecture Validation: Built-in dimension checking ensures compatibility between extractors and classifiers
- GPU Support: Seamless device transfer via a comprehensive `to(device)` implementation
- Flexible Feature Handling: Supports both weighted averaging and feature stacking approaches
- Integrated Visualization: Native support for Class Activation Mapping in AverageEnsemble
- Memory Efficient: Automatic freezing of extractor weights to optimize memory usage
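Freezing the pretrained extractors, as described above, can be sketched like this (an illustrative snippet under assumed names, not the package's internals; the stand-in `extractor` replaces a real pretrained backbone):

```python
import torch
from torch import nn

# A stand-in extractor; in practice this would be a pretrained CNN backbone
extractor = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(),
                          nn.AdaptiveAvgPool2d(1), nn.Flatten())

# Freeze extractor parameters so no gradients (or optimizer state) are kept for them
for p in extractor.parameters():
    p.requires_grad_(False)

# Only the classifier head remains trainable
classifier = nn.Linear(16, 2)
trainable = [p for p in list(extractor.parameters()) + list(classifier.parameters())
             if p.requires_grad]
print(len(trainable))  # only the classifier's weight and bias remain trainable
```

Passing only the trainable parameters to the optimizer then avoids allocating gradient buffers and optimizer state for the frozen backbones.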
Implementation Highlights
The core innovation lies in the AverageEnsemble class, which computes a weighted average of the extractors' outputs through a learnable parameter:

```python
# One learnable logit per extractor, normalized to proportions via softmax
self.proportions = nn.Parameter(torch.randn(len(extractors)))
proportions = self.proportion_softmax(self.proportions)
```
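The idea can be made concrete with a self-contained sketch (a hypothetical simplified module written for illustration, not the package's actual class; the toy extractors and dimensions are assumptions):

```python
import torch
from torch import nn

class WeightedAverageEnsemble(nn.Module):
    """Minimal sketch: learnable softmax-weighted average of extractor features."""
    def __init__(self, extractors, classifier):
        super().__init__()
        self.extractors = nn.ModuleList(extractors)
        self.classifier = classifier
        # One learnable logit per extractor; softmax turns them into proportions
        self.proportions = nn.Parameter(torch.randn(len(extractors)))

    def forward(self, x):
        weights = torch.softmax(self.proportions, dim=0)
        feats = [ext(x) for ext in self.extractors]
        combined = sum(w * f for w, f in zip(weights, feats))
        return self.classifier(combined)

# Toy extractors with matching output dimensions
extractors = [nn.Sequential(nn.Flatten(), nn.Linear(12, 8)) for _ in range(3)]
model = WeightedAverageEnsemble(extractors, nn.Linear(8, 2))
out = model(torch.randn(4, 3, 2, 2))
print(out.shape)  # torch.Size([4, 2])
```

Because `proportions` is an `nn.Parameter`, backpropagation through the weighted sum adjusts each model's contribution jointly with the rest of the training objective.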
Applications
Originally developed for medical imaging, the framework is domain-agnostic and applicable to any computer vision task that benefits from ensemble methods, including:
- Medical image analysis
- Object detection
- Image classification
- Visual reasoning tasks
Future Development
- Benchmark comparisons against traditional ensemble methods
- Integration of additional visualization techniques
- Support for non-CNN architectures
- Performance optimization for large-scale deployments
Installation and Dependencies
```shell
pip install PyEnsembleCNN
```
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
File details
Details for the file pyensemblecnn-1.0.1.tar.gz.

File metadata
- Download URL: pyensemblecnn-1.0.1.tar.gz
- Size: 4.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.8.10

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 20ca8f955d01ab839a23ba3233318a6973a294e1fdcd4c184fa4f90bb323de12 |
| MD5 | eea465f12dc4a93da99d59c53e279acf |
| BLAKE2b-256 | 354e02398e5b03476c65cf6502d84be2e9ad5fb4b19069971decfc1a755f306d |
File details
Details for the file PyEnsembleCNN-1.0.1-py3-none-any.whl.

File metadata
- Download URL: PyEnsembleCNN-1.0.1-py3-none-any.whl
- Size: 4.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.8.10

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | b1cfafa43f7286abed3e749a633155161c48c7bd20aada18135fde005ef5152e |
| MD5 | d5e6afdda0a86376b0ca81e6d1fb1e45 |
| BLAKE2b-256 | 4a2d987fef51b9732d5c9934fc2e6edac89fb11d9817638bbb19e94f6f0aef95 |