Skip to main content

Automatic CNN feature extraction and ML model comparison

Project description

CNN Feature Extractor

A Python package for automatic CNN feature extraction and ML model comparison. Extract features from images using pre-trained CNN models and evaluate multiple ML classifiers in one go.

Installation

pip install cnn_feature_extractor

Quick Start with CIFAR10

import torch
import torchvision
from torchvision import transforms
from cnn_feature_extractor import CNNFeatureExtractor

# Set image size
image_size = 128

# Define transforms
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load CIFAR10 dataset
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=True,
    download=True,
    transform=transform
)

val_dataset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=False,
    download=True,
    transform=transform
)

# Create data loaders
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4
)

# Initialize and run feature extraction + ML comparison
extractor = CNNFeatureExtractor(save_path='cifar10_results.csv')
results = extractor.fit(
    train_loader, 
    val_loader,

    # Example 1: Using specific models
    cnn_models=['resnet18', 'efficientnet_b0'],    

    # Example 2: Using the tiny package (fastest, good for testing)
    # cnn_models='tiny',  # This will use: mobilenet_v2, mobilenet_v3_small, efficientnet_b0, convnext_tiny, resnet18

    # Example 3: Mixing packages
    # cnn_models='tiny + small',  # This will combine models from both packages
    
    ml_models=['RandomForest', 'LogisticRegression']
)

Using Your Custom Dataset

Required Dataset Structure

dataset/
├── train/
│   ├── class1/
│   │   ├── image1.jpg
│   │   └── image2.jpg
│   └── class2/
│       ├── image3.jpg
│       └── image4.jpg
└── val/
    ├── class1/
    │   └── image5.jpg
    └── class2/
        └── image6.jpg

Custom Dataset Example

from cnn_feature_extractor import CNNFeatureExtractor
from cnn_feature_extractor.utils.dataset import load_custom_dataset

# Set image size and other parameters
image_size = 224  # Standard size for most CNN models
batch_size = 32
num_workers = 4

# Load your custom dataset
train_loader, val_loader, num_classes = load_custom_dataset(
    data_dir='path/to/your/dataset',  # Path to your dataset root directory
    batch_size=batch_size,
    num_workers=num_workers,
    image_size=image_size,
    augment=True  # Enable data augmentation (optional)
)

# Initialize feature extractor
extractor = CNNFeatureExtractor(save_path='results.csv')

# Run feature extraction and ML comparison
results = extractor.fit(
    train_loader, 
    val_loader,

    # Example 1: Using specific models
    cnn_models=['resnet18', 'efficientnet_b0'],    

    # Example 2: Using the tiny package (fastest, good for testing)
    # cnn_models='tiny',  # This will use: mobilenet_v2, mobilenet_v3_small, efficientnet_b0, convnext_tiny, resnet18

    # Example 3: Mixing packages
    # cnn_models='tiny + small',  # This will combine models from both packages
    
    ml_models=['RandomForest', 'LogisticRegression']
)

# Results will be saved to 'results.csv'
print(results)

Available Models

CNN Feature Extractors

Tiny Package (Fast, Lower Accuracy)

  • mobilenet_v2
  • mobilenet_v3_small
  • efficientnet_b0
  • convnext_tiny
  • resnet18

Small Package

  • resnet34
  • densenet121
  • mobilenet_v3_large
  • efficientnet_b1
  • convnext_small

Medium Package

  • resnet50
  • densenet169
  • vgg16
  • efficientnet_b2
  • convnext_base

Large Package

  • resnet101
  • densenet201
  • vgg19
  • efficientnet_b3
  • convnext_large

Biggest Package (Slow, Higher Accuracy)

  • resnet152
  • densenet201
  • efficientnet_b7
  • convnext_large
  • vgg19

ML Classifiers

  • RandomForest
  • SVM (with probability estimation)
  • LogisticRegression
  • GradientBoosting
  • XGBoost
  • LightGBM
  • KNN
  • DecisionTree
  • AdaBoost
  • GaussianNB
  • RidgeClassifier
  • SGDClassifier
  • LinearSVC

Package Usage Tips

  1. Choosing CNN Models:

    • Start with 'tiny' package models for quick experiments
    • Use 'biggest' package models for maximum accuracy
    • Mix models from different packages: cnn_models=['resnet18', 'efficientnet_b7']
  2. Choosing ML Models:

    • Start with fast models like LogisticRegression
    • Use RandomForest or XGBoost for better accuracy
    • Try multiple models: ml_models=['LogisticRegression', 'RandomForest', 'XGBoost']
  3. Data Augmentation:

    • Enable with augment=True in load_custom_dataset
    • Helps prevent overfitting
    • Especially useful for small datasets
  4. GPU Usage:

    • GPU is automatically used if available
    • CNN feature extraction is significantly faster on GPU
    • Some ML models (XGBoost, LightGBM) can also use GPU

Saving Models

The package automatically saves all trained models to a directory for future use.

from cnn_feature_extractor import CNNFeatureExtractor

# Models will be saved to 'my_models' directory
extractor = CNNFeatureExtractor(save_path='results.csv', models_dir='my_models')

# Run feature extraction and ML comparison
results = extractor.fit(
    train_loader, 
    val_loader,
    cnn_models=['resnet18', 'efficientnet_b0'],    
    ml_models=['RandomForest', 'LogisticRegression']
)

# Models are saved in the specified directory:
# - CNN models: my_models/cnn_models/
# - ML models: my_models/ml_models/

License

MIT License

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

cnn_feature_extractor-0.2.0.tar.gz (11.6 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

cnn_feature_extractor-0.2.0-py3-none-any.whl (12.2 kB view details)

Uploaded Python 3

File details

Details for the file cnn_feature_extractor-0.2.0.tar.gz.

File metadata

  • Download URL: cnn_feature_extractor-0.2.0.tar.gz
  • Upload date:
  • Size: 11.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.0

File hashes

Hashes for cnn_feature_extractor-0.2.0.tar.gz
Algorithm Hash digest
SHA256 2c5f31394936c688504efa5158d389da1d9b01c7f9e4a99fde3aa09136d4fa42
MD5 360264ad11d76daff056cb70aef39f96
BLAKE2b-256 e0fc62d994c3874e3fa302fa01fe4037e6ec61a79f3579021d7ea399a1aa0a6f

See more details on using hashes here.

File details

Details for the file cnn_feature_extractor-0.2.0-py3-none-any.whl.

File metadata

File hashes

Hashes for cnn_feature_extractor-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 461d5dc7e173cbe792f49bb76745741f1bbb20a6e3721fc67b0ae90510cc6651
MD5 ab7ecdc8d7afbb9e04ab13c42b612e62
BLAKE2b-256 01c3d544e16964e68e7eaa1301fe036b6765258a5f18fbaff97e958a96ba38f2

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page