Image Classification Tools

A lightweight PyTorch toolkit for building and training image classification models.

Overview

This package provides utilities for common image classification tasks:

  • Data loading: Flexible data loaders for torchvision datasets and custom image folders
  • Model training: Training loops with progress tracking and validation
  • Evaluation: Accuracy metrics, confusion matrices, and performance analysis
  • Visualization: Learning curves, probability distributions, and evaluation plots
  • Hyperparameter optimization: Optuna integration for automated model tuning

Installation

pip install image-classification-tools

Quick start

Basic usage

import torch
from pathlib import Path
from torchvision import datasets, transforms
from image_classification_tools.pytorch.data import (
    load_datasets, prepare_splits, create_dataloaders
)
from image_classification_tools.pytorch.training import train_model
from image_classification_tools.pytorch.evaluation import evaluate_model

# Define transforms: convert images to tensors in [0, 1], then normalize the single channel to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load datasets
train_dataset, test_dataset = load_datasets(
    data_source=datasets.MNIST,
    train_transform=transform,
    eval_transform=transform,
    download=True,
    root=Path('./data/mnist')
)

# Prepare splits
train_dataset, val_dataset, test_dataset = prepare_splits(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    train_val_split=0.8
)

# Create dataloaders
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader, val_loader, test_loader = create_dataloaders(
    train_dataset, val_dataset, test_dataset,
    batch_size=64,
    preload_to_memory=True,
    device=device
)

# Define model, criterion, optimizer (reuses the device selected above)

model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
).to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train with lazy loading (moves batches to device during training)
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    lazy_loading=True,  # Set False if data already on device
    epochs=10
)

# Evaluate
accuracy, predictions, labels = evaluate_model(model, test_loader)
print(f'Test accuracy: {accuracy:.2f}%')
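
The evaluation step above returns an overall accuracy together with the predictions and labels. As a rough illustration of how those two sequences could feed the confusion matrices mentioned in the overview, here is a minimal plain-Python sketch. The helper names `accuracy_percent` and `confusion_matrix` are hypothetical and are not part of this package's API, and the sketch assumes plain sequences of integer class indices; the package's own evaluation utilities may compute these differently.

```python
from collections import Counter


def accuracy_percent(predictions, labels):
    """Return the share of correct predictions as a percentage."""
    if len(predictions) != len(labels):
        raise ValueError('predictions and labels must have the same length')
    if len(labels) == 0:
        raise ValueError('cannot compute accuracy for an empty set')
    correct = sum(p == t for p, t in zip(predictions, labels))
    return 100.0 * correct / len(labels)


def confusion_matrix(predictions, labels, num_classes):
    """Return a square grid: rows are true classes, columns are predicted classes."""
    counts = Counter(zip(labels, predictions))
    return [
        [counts[(true, pred)] for pred in range(num_classes)]
        for true in range(num_classes)
    ]


# Toy data standing in for the outputs of an evaluation run (made up for illustration)
toy_labels = [0, 0, 1, 1, 2, 2]
toy_predictions = [0, 1, 1, 1, 2, 0]

print(f'Accuracy: {accuracy_percent(toy_predictions, toy_labels):.2f}%')
for row in confusion_matrix(toy_predictions, toy_labels, num_classes=3):
    print(row)
```

Off-diagonal entries in the grid show which classes are confused with each other, which a single accuracy figure hides.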

Hyperparameter optimization

# Reuses `transform` and `device` from the basic usage example
import optuna
from image_classification_tools.pytorch.hyperparameter_optimization import create_objective

# Define search space
search_space = {
    'batch_size': [32, 64, 128],
    'n_conv_blocks': (1, 3),
    'initial_filters': [16, 32, 64],
    'n_fc_layers': (1, 3),
    'conv_dropout_rate': (0.1, 0.5),
    'fc_dropout_rate': (0.3, 0.7),
    'learning_rate': (1e-4, 1e-2, 'log'),
    'optimizer': ['Adam', 'SGD'],
    'weight_decay': (1e-6, 1e-3, 'log')
}

# Create objective function
objective = create_objective(
    data_dir='./data',
    train_transform=transform,
    eval_transform=transform,
    n_epochs=20,
    device=device,
    num_classes=10,
    in_channels=1,
    search_space=search_space
)

# Run optimization
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)
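
The search space above mixes lists and tuples. The sketch below shows one plausible reading of that convention, mapped onto Optuna's standard `suggest_categorical`, `suggest_int`, and `suggest_float` trial methods: lists as categorical choices, two-element tuples as numeric ranges, and a trailing `'log'` as log-scale sampling. This interpretation is an assumption inferred from the example, not a description of how `create_objective` works internally. `suggest_from_space` and `LowerBoundTrial` are hypothetical names used only for illustration.

```python
def suggest_from_space(trial, search_space):
    """Draw one value per hyperparameter using Optuna-style trial methods."""
    params = {}
    for name, spec in search_space.items():
        if isinstance(spec, list):
            # Assumed convention: a list is a set of categorical choices
            params[name] = trial.suggest_categorical(name, spec)
        elif len(spec) == 3 and spec[2] == 'log':
            # Assumed convention: (low, high, 'log') is a log-scale float range
            low, high, _ = spec
            params[name] = trial.suggest_float(name, low, high, log=True)
        elif all(isinstance(bound, int) for bound in spec):
            # Assumed convention: two integer bounds form an integer range
            low, high = spec
            params[name] = trial.suggest_int(name, low, high)
        else:
            # Assumed convention: any other pair is a linear float range
            low, high = spec
            params[name] = trial.suggest_float(name, low, high)
    return params


class LowerBoundTrial:
    """Stand-in for an Optuna trial that always picks the first choice or lower bound."""

    def suggest_categorical(self, name, choices):
        return choices[0]

    def suggest_int(self, name, low, high):
        return low

    def suggest_float(self, name, low, high, log=False):
        return low


example_space = {
    'batch_size': [32, 64, 128],
    'n_conv_blocks': (1, 3),
    'conv_dropout_rate': (0.1, 0.5),
    'learning_rate': (1e-4, 1e-2, 'log'),
}
print(suggest_from_space(LowerBoundTrial(), example_space))
```

Because the function only relies on the three `suggest_*` methods, the `trial` object that Optuna passes to an objective function should be usable in place of the stand-in class.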

Requirements

  • Python ≥ 3.10
  • PyTorch ≥ 2.0.0
  • torchvision ≥ 0.15.0
  • numpy
  • matplotlib
  • optuna (optional, for hyperparameter optimization)
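
To check an environment against the list above, a short standard-library script along these lines may help. It is only a sketch: it assumes the PyPI distribution names `torch`, `torchvision`, `numpy`, `matplotlib`, and `optuna`, it reports installed versions without comparing them to the minimums, and the helper names `installed_version` and `python_ok` are made up for this example.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

MIN_PYTHON = (3, 10)
# Assumed distribution names; PyTorch is published on PyPI as 'torch'
PACKAGES = ['torch', 'torchvision', 'numpy', 'matplotlib', 'optuna']


def installed_version(distribution):
    """Return the installed version string, or None if the package is missing."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        return None


def python_ok(version_info=sys.version_info):
    """Check the interpreter's major.minor version against the minimum."""
    return tuple(version_info[:2]) >= MIN_PYTHON


status = 'ok' if python_ok() else 'too old'
print(f'Python {sys.version_info.major}.{sys.version_info.minor}: {status}')
for name in PACKAGES:
    print(f'{name}: {installed_version(name) or "not installed"}')
```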

Documentation

Full documentation is available at: https://gperdrizet.github.io/CIFAR10/

Demo project

See a complete example of using this package for CIFAR-10 classification: https://github.com/gperdrizet/CIFAR10

License

GPLv3

