PyEnsembleCNN: A dynamic CNN ensembling framework for PyTorch

A PyTorch Implementation of Self-Learning CNN Ensembles with Integrated Visualization

Key Features

  • Flexible ensemble architecture supporting any CNN backbone (ResNet, DenseNet, VGG, etc.)
  • Novel trainable weighted averaging system that automatically learns optimal ensemble proportions
  • Built-in Class Activation Mapping (CAM) visualization support for model interpretability
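
For readers unfamiliar with CAM, the classic formulation weights the last convolutional feature maps by the classifier weights of a target class. The sketch below illustrates that idea for a single backbone that ends in global average pooling followed by a linear layer. The helper name class_activation_map, the tensor shapes, and the min-max scaling are assumptions made for illustration; this is not the PyEnsembleCNN API.

```python
import torch
import torch.nn.functional as F


def class_activation_map(feature_maps, fc_weight, class_idx, out_size):
    """Classic CAM for one image (hypothetical helper, not the package API).

    feature_maps: (C, H, W) activations of the last convolutional layer
    fc_weight:    (num_classes, C) weight of the linear layer after global average pooling
    """
    # Weighted sum over channels using the target class's weights -> (H, W)
    cam = torch.einsum("c,chw->hw", fc_weight[class_idx], feature_maps)
    # Upsample to image resolution; interpolate expects (N, C, H, W)
    cam = F.interpolate(cam[None, None], size=out_size, mode="bilinear", align_corners=False)[0, 0]
    # Min-max scale to [0, 1] for display; the clamp guards against a constant map
    cam = cam - cam.min()
    return cam / cam.max().clamp_min(1e-8)


torch.manual_seed(0)
features = torch.rand(8, 7, 7)   # random stand-in for real activations
weights = torch.randn(3, 8)      # random stand-in for a trained linear layer
heatmap = class_activation_map(features, weights, class_idx=1, out_size=(224, 224))
print(heatmap.shape)
```

The resulting heatmap can be overlaid on the input image to see which regions drove the prediction for the chosen class.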

Two distinct ensemble approaches:

  • AverageEnsemble: Combines extractor outputs using a learned weight distribution across models
  • StackEnsemble: Concatenates features for enhanced representation power
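
The practical difference between the two approaches shows up in the feature size the classifier receives. The sketch below uses random tensors in place of real extractor outputs and assumes each extractor yields a 2048-dimensional vector and that stacking concatenates along the feature axis; the uniform weights are placeholders, not learned values.

```python
import torch

batch, feat_dim, n_models = 4, 2048, 3
# Stand-ins for the per-extractor feature vectors (random data, illustration only)
features = [torch.randn(batch, feat_dim) for _ in range(n_models)]

# Averaging: combine along a new "model" axis, so the feature size is unchanged
weights = torch.full((n_models,), 1.0 / n_models)  # uniform placeholder weights
averaged = (torch.stack(features, dim=0) * weights[:, None, None]).sum(dim=0)

# Stacking: concatenate along the feature axis, so the feature size grows with the ensemble
stacked = torch.cat(features, dim=1)

print(averaged.shape)  # torch.Size([4, 2048])
print(stacked.shape)   # torch.Size([4, 6144])
```

Under these assumptions, a classifier used with stacking needs an input layer sized to the sum of the extractor output sizes, while one used with averaging keeps the single-extractor size.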

Technical Innovation

The framework weights the ensemble by making each model's ensemble weight (its share of the combined output) a trainable parameter. Instead of fixing static weights based on individual model performance, the ensemble learns the contribution of each model during training, alongside the classifier.
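
One way to write this down, using notation that is assumed here rather than taken from the package: let p be the vector of trainable values (one per model), f_i the i-th extractor, M the number of extractors, and g the classifier. A softmax turns p into weights that are positive and sum to 1, and the classifier sees the weighted average of the extractor outputs:

```latex
\[
w_i = \frac{\exp(p_i)}{\sum_{j=1}^{M} \exp(p_j)}, \qquad
z(x) = \sum_{i=1}^{M} w_i \, f_i(x), \qquad
\hat{y}(x) = g\bigl(z(x)\bigr)
\]
```

Because every step is differentiable, gradients of the loss reach p through the same backward pass that trains g.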

Technical Details

Usage

# Example usage with common CNN architectures
import torch.nn as nn
from torchvision.models import resnet50, densenet121, vgg16
# replace_classifier, AverageEnsemble, and train are provided by PyEnsembleCNN

extractors = [resnet50(weights='IMAGENET1K_V1'),
              densenet121(weights='IMAGENET1K_V1'),
              vgg16(weights='IMAGENET1K_V1')]  # Pre-trained models

# Replace the classification head of each extractor with pooling or interpolation
# (downscaling vs upscaling); 2048 matches the classifier's input size below
for extractor in extractors:
    replace_classifier(extractor, 2048, pooling='avg')  # Supports avg and max pooling

# Simple MLP for classification
# (num_classes, train_loader, val_loader, optimizer, criterion, and epochs are defined by the caller)
classifier = nn.Sequential(
    nn.Linear(2048, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes)
)

# Initialize ensemble with automatic weight learning and train
model = AverageEnsemble(extractors, classifier, CAM=True)
best_epoch, best_model = train(model, train_loader, val_loader, optimizer, criterion, epochs, device='cuda', verbose=True)
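
The train call above returns a best epoch and a best model. As a rough illustration of what such a loop can look like, here is a minimal sketch that keeps the weights from the epoch with the lowest validation loss. The function name train_sketch and the selection criterion are assumptions; the package's own train function may track a different metric or behave differently.

```python
import copy

import torch
import torch.nn as nn


def train_sketch(model, train_loader, val_loader, optimizer, criterion, epochs, device="cpu"):
    """Hypothetical loop: returns (best_epoch, model) restored to the lowest validation loss."""
    model.to(device)
    best_loss, best_epoch, best_state = float("inf"), -1, None
    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

        # Evaluate without tracking gradients
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                val_loss += criterion(model(x), y).item()

        # Snapshot the weights whenever validation loss improves
        if val_loss < best_loss:
            best_loss, best_epoch = val_loss, epoch
            best_state = copy.deepcopy(model.state_dict())

    if best_state is not None:
        model.load_state_dict(best_state)
    return best_epoch, model


# Tiny synthetic run: lists of (inputs, labels) batches stand in for DataLoaders
torch.manual_seed(0)
data = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(4)]
net = nn.Linear(10, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
best_epoch, best_model = train_sketch(net, data, data, opt, nn.CrossEntropyLoss(), epochs=3)
```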

Advanced Features

  • Automatic Architecture Validation: Built-in dimension checking ensures compatibility between extractors and classifiers
  • GPU Support: Seamless device transition with comprehensive to(device) implementation
  • Flexible Feature Handling: Supports both weighted averaging and feature stacking approaches
  • Integrated Visualization: Native support for Class Activation Mapping in AverageEnsemble
  • Memory Efficient: Extractor weights are frozen automatically, so no gradients are stored for them during training
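
Two of the features above, dimension checking and extractor freezing, can be approximated in plain PyTorch. The sketch below is an illustration under assumptions, not the package's code: freeze and check_dims are hypothetical helpers, and the small convolutional stack is a stand-in for a real backbone.

```python
import torch
import torch.nn as nn


def freeze(module):
    """Disable gradients so autograd stores nothing for these weights."""
    for param in module.parameters():
        param.requires_grad = False
    return module


def check_dims(extractor, classifier, input_shape=(1, 3, 64, 64)):
    """Pass a dummy batch through the extractor and compare with the classifier's input size."""
    first_linear = next(m for m in classifier.modules() if isinstance(m, nn.Linear))
    extractor.eval()  # note: leaves the extractor in eval mode
    with torch.no_grad():
        out = extractor(torch.zeros(input_shape))
    if out.shape[-1] != first_linear.in_features:
        raise ValueError(
            f"extractor outputs {out.shape[-1]} features, "
            f"classifier expects {first_linear.in_features}"
        )


# Tiny stand-in extractor (not a real backbone): conv -> global average pool -> flatten
extractor = freeze(nn.Sequential(nn.Conv2d(3, 16, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten()))
classifier = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))
check_dims(extractor, classifier)  # passes: 16 features match Linear(16, ...)

trainable = sum(p.numel() for p in extractor.parameters() if p.requires_grad)
print(trainable)  # 0
```

Running the check once at construction time surfaces a size mismatch as a clear error instead of a shape failure deep inside the first forward pass.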

Implementation Highlights

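The core innovation lies in the AverageEnsemble class, which computes a weighted average of the extractor outputs, with the weights derived from a learnable parameter:

# One learnable value per extractor, randomly initialized
self.proportions = nn.Parameter(torch.randn(len(extractors)))
# Normalize the values into proportions before averaging
proportions = self.proportion_softmax(self.proportions)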
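
To show how these two lines might fit into a full module, here is a minimal sketch. WeightedAverageSketch is a hypothetical stand-in written for this example, not the actual AverageEnsemble code, and it assumes that proportion_softmax is a softmax over the extractor axis and that every extractor returns a tensor of the same shape.

```python
import torch
import torch.nn as nn


class WeightedAverageSketch(nn.Module):
    """Hypothetical illustration of a learnable weighted average over extractors."""

    def __init__(self, extractors, classifier):
        super().__init__()
        self.extractors = nn.ModuleList(extractors)
        self.classifier = classifier
        # One unconstrained value per extractor, trained with the rest of the model
        self.proportions = nn.Parameter(torch.randn(len(extractors)))
        self.proportion_softmax = nn.Softmax(dim=0)

    def forward(self, x):
        # Softmax keeps the weights positive and summing to 1
        proportions = self.proportion_softmax(self.proportions)
        # Stack extractor outputs: (num_extractors, batch, features)
        features = torch.stack([extractor(x) for extractor in self.extractors], dim=0)
        # Weighted sum over the extractor axis: (batch, features)
        mixed = (proportions[:, None, None] * features).sum(dim=0)
        return self.classifier(mixed)


# Linear layers stand in for CNN backbones to keep the example small
torch.manual_seed(0)
model = WeightedAverageSketch([nn.Linear(10, 16) for _ in range(3)], nn.Linear(16, 2))
logits = model(torch.randn(4, 10))
logits.sum().backward()  # gradients flow into model.proportions
```

Storing unconstrained values and normalizing in forward means the optimizer can update them freely while the effective weights always remain a valid distribution.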

Applications

PyEnsembleCNN was originally developed for medical imaging, but it is designed to be domain-agnostic and applies to any computer vision task that benefits from ensemble methods, including:

  • Medical image analysis
  • Object detection
  • Image classification
  • Visual reasoning tasks

Future Development

  • Benchmark comparisons against traditional ensemble methods
  • Integration of additional visualization techniques
  • Support for non-CNN architectures
  • Performance optimization for large-scale deployments

Installation and Dependencies

pip install PyEnsembleCNN

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.
