
Project description

PyEnsembleCNN: A dynamic CNN ensembling framework for PyTorch

A PyTorch Implementation of Self-Learning CNN Ensembles with Integrated Visualization

Key Features

  • Flexible ensemble architecture supporting any CNN backbone (ResNet, DenseNet, VGG, etc.)

  • Novel trainable weighted averaging system that automatically learns optimal ensemble proportions
  • Built-in Class Activation Mapping (CAM) visualization support for model interpretability

Two distinct ensemble approaches:

  • AverageEnsemble: Features learned weight distribution across models
  • StackEnsemble: Concatenates features for enhanced representation power
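The practical difference between the two approaches is how features are combined; a minimal sketch in plain PyTorch (the three-model setup and the 2048-dimensional features are illustrative, not the framework's API):

```python
import torch

# Hypothetical features from three extractors, each (batch, 2048)
feats = [torch.randn(4, 2048) for _ in range(3)]

# AverageEnsemble-style: a softmax-weighted sum keeps the feature size fixed
weights = torch.softmax(torch.randn(3), dim=0)          # learned in the real framework
averaged = sum(w * f for w, f in zip(weights, feats))   # shape: (4, 2048)

# StackEnsemble-style: concatenation grows the feature size with each model
stacked = torch.cat(feats, dim=1)                       # shape: (4, 6144)
```

The averaged variant keeps the classifier input size independent of the number of models, while the stacked variant trades a larger classifier for a richer joint representation.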

Technical Innovation

The framework introduces a unique approach to ensemble weighting by making the model weights themselves trainable parameters. Instead of using static weights based on individual model performance, this implementation allows the ensemble to dynamically learn the optimal contribution of each model during training.
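The idea can be sketched in a few lines of plain PyTorch (class and attribute names here are illustrative, not the framework's actual API): registering the per-model weights as an `nn.Parameter` is enough for backpropagation to update them along with the rest of the network.

```python
import torch
import torch.nn as nn

# Sketch of trainable ensemble weighting — names are illustrative
class LearnedWeighting(nn.Module):
    def __init__(self, n_models):
        super().__init__()
        # One unconstrained weight per model, trained alongside everything else
        self.proportions = nn.Parameter(torch.randn(n_models))

    def forward(self, outputs):  # outputs: list of (batch, dim) tensors
        w = torch.softmax(self.proportions, dim=0)
        return sum(wi * o for wi, o in zip(w, outputs))

mix = LearnedWeighting(3)
outs = [torch.randn(2, 8) for _ in range(3)]
mix(outs).sum().backward()
print(mix.proportions.grad is not None)  # True: the proportions are learned by backprop
```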

Technical Details

Usage

# Example usage with common CNN architectures
import torch.nn as nn
from torchvision.models import resnet50, densenet121, vgg16

# Pre-trained feature extractors
extractors = [resnet50(weights='IMAGENET1K_V1'),
              densenet121(weights='IMAGENET1K_V1'),
              vgg16(weights='IMAGENET1K_V1')]

# Replace the classification head of each extractor so all of them emit
# 2048-dimensional features (pooling to downscale, interpolation to upscale)
for extractor in extractors:
    replace_classifier(extractor, 2048, pooling='avg')  # 'avg' and 'max' pooling are supported

# Simple MLP classification head
classifier = nn.Sequential(
    nn.Linear(2048, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes)
)

# Initialize the ensemble with automatic weight learning, then train
model = AverageEnsemble(extractors, classifier, CAM=True)
best_epoch, best_model = train(model, train_loader, val_loader, optimizer,
                               criterion, epochs, device='cuda', verbose=True)

Advanced Features

  • Automatic Architecture Validation: Built-in dimension checking ensures compatibility between extractors and classifiers
  • GPU Support: Seamless device transition with comprehensive to(device) implementation
  • Flexible Feature Handling: Supports both weighted averaging and feature stacking approaches
  • Integrated Visualization: Native support for Class Activation Mapping in AverageEnsemble
  • Memory Efficient: Automatic freezing of extractor weights to optimize memory usage
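The freezing behavior, for instance, amounts to disabling gradients on the extractor parameters, which removes them from gradient computation and storage; a minimal sketch (the `freeze` helper is hypothetical, not part of the package):

```python
import torch.nn as nn

def freeze(module):
    """Exclude a module's parameters from gradient computation,
    so no gradients are computed or stored for them during training."""
    for p in module.parameters():
        p.requires_grad = False

extractor = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
freeze(extractor)
trainable = [p for p in extractor.parameters() if p.requires_grad]
print(len(trainable))  # 0 — only the classifier and ensemble weights would train
```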

Implementation Highlights

The core innovation lies in the AverageEnsemble class, which implements a weighted average of the extractor outputs through a learnable parameter:

# One raw, unconstrained weight per extractor, registered as a trainable parameter
self.proportions = nn.Parameter(torch.randn(len(extractors)))
# Normalized at forward time so contributions are positive and sum to one
proportions = self.proportion_softmax(self.proportions)
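Passing the raw parameter through a softmax keeps the learned proportions positive and summing to one, so the ensemble output stays a proper convex combination of the extractor outputs regardless of what values the raw weights take during training; a quick check in plain PyTorch:

```python
import torch

raw = torch.randn(3)                     # unconstrained trainable parameter
proportions = torch.softmax(raw, dim=0)  # always positive, always sums to 1
print(bool((proportions > 0).all()))     # True
print(float(proportions.sum()))          # 1.0 (up to floating-point rounding)
```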

Applications

Originally developed for medical imaging, the framework is designed to be domain-agnostic and applicable to any computer vision task that benefits from ensemble methods, including:

  • Medical image analysis
  • Object detection
  • Image classification
  • Visual reasoning tasks

Future Development

  • Benchmark comparisons against traditional ensemble methods
  • Integration of additional visualization techniques
  • Support for non-CNN architectures
  • Performance optimization for large-scale deployments

Installation and Dependencies

pip install torch torchvision
pip install pytorch-grad-cam
pip install pyensemblecnn

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.
