Image Classification Tools
A lightweight PyTorch toolkit for building and training image classification models.
Overview
This package provides utilities for common image classification tasks:
- Data loading: Flexible data loaders for torchvision datasets and custom image folders
- Model training: Training loops with progress tracking and validation
- Evaluation: Accuracy metrics, confusion matrices, and performance analysis
- Visualization: Learning curves, probability distributions, and evaluation plots
- Hyperparameter optimization: Optuna integration for automated model tuning
Installation
pip install image-classification-tools
Quick start
Basic usage
import torch
from pathlib import Path
from torchvision import datasets, transforms
from image_classification_tools.pytorch.data import (
    load_datasets, prepare_splits, create_dataloaders
)
from image_classification_tools.pytorch.training import train_model
from image_classification_tools.pytorch.evaluation import evaluate_model
# Define transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Load datasets
train_dataset, test_dataset = load_datasets(
    data_source=datasets.MNIST,
    train_transform=transform,
    eval_transform=transform,
    download=True,
    root=Path('./data/mnist')
)
# Prepare splits
train_dataset, val_dataset, test_dataset = prepare_splits(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    train_val_split=0.8
)
# Create dataloaders
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader, val_loader, test_loader = create_dataloaders(
    train_dataset, val_dataset, test_dataset,
    batch_size=64,
    preload_to_memory=True,
    device=device
)
# Define model, criterion, optimizer (reuses the device selected above)
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Train with lazy loading (moves batches to device during training)
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    lazy_loading=True,  # Set False if data already on device
    epochs=10
)
# Evaluate
accuracy, predictions, labels = evaluate_model(model, test_loader)
print(f'Test accuracy: {accuracy:.2f}%')
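The overview lists confusion matrices among the evaluation utilities, and `evaluate_model` returns predictions and labels alongside the accuracy. As a rough illustration of how those two sequences relate to an accuracy percentage, here is a minimal pure-Python sketch. It is not the package's implementation: the helper names `confusion_matrix` and `accuracy_percent` are hypothetical, and it assumes predictions and labels are plain sequences of integer class indices (tensors or arrays would first need converting, for example with `.tolist()`).

```python
def confusion_matrix(labels, predictions, num_classes):
    """Count (true class, predicted class) pairs; rows are true classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_class, predicted_class in zip(labels, predictions):
        matrix[true_class][predicted_class] += 1
    return matrix


def accuracy_percent(matrix):
    """Percentage of samples on the diagonal (correct predictions)."""
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return 100.0 * correct / total if total else 0.0


# Toy data standing in for the outputs of an evaluation run (hypothetical)
example_labels = [0, 0, 1, 1, 2, 2]
example_predictions = [0, 1, 1, 1, 2, 0]

matrix = confusion_matrix(example_labels, example_predictions, num_classes=3)
for row in matrix:
    print(row)
print(f'Accuracy: {accuracy_percent(matrix):.2f}%')
```

Off-diagonal entries show which classes are confused with each other, which a single accuracy figure hides.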
Hyperparameter optimization
from image_classification_tools.pytorch.hyperparameter_optimization import create_objective
import optuna
# Define search space
search_space = {
    'batch_size': [32, 64, 128],
    'n_conv_blocks': (1, 3),
    'initial_filters': [16, 32, 64],
    'n_fc_layers': (1, 3),
    'conv_dropout_rate': (0.1, 0.5),
    'fc_dropout_rate': (0.3, 0.7),
    'learning_rate': (1e-4, 1e-2, 'log'),
    'optimizer': ['Adam', 'SGD'],
    'weight_decay': (1e-6, 1e-3, 'log')
}
# Create objective function (reuses `transform` and `device` from the basic usage example)
objective = create_objective(
    data_dir='./data',
    train_transform=transform,
    eval_transform=transform,
    n_epochs=20,
    device=device,
    num_classes=10,
    in_channels=1,
    search_space=search_space
)
# Run optimization
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)
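The search space above mixes lists, two-element tuples, and tuples ending in `'log'`. One plausible reading is that lists are categorical choices, tuples are ranges, and `'log'` requests log-scale sampling. The sketch below shows how such a dictionary could be mapped onto Optuna's `suggest_categorical`, `suggest_int`, and `suggest_float` trial methods. This is an assumption about the convention, not the package's confirmed behaviour, and `suggest_from_space` is a hypothetical helper written for illustration.

```python
def suggest_from_space(trial, search_space):
    """Draw one value per entry of search_space from an Optuna-style trial.

    Assumed conventions (not confirmed against the package source):
      list                  -> categorical choice
      (low, high) of ints   -> integer range
      (low, high) of floats -> float range
      (low, high, 'log')    -> float range sampled on a log scale
    """
    params = {}
    for name, spec in search_space.items():
        if isinstance(spec, list):
            params[name] = trial.suggest_categorical(name, spec)
            continue
        low, high = spec[0], spec[1]
        use_log = len(spec) == 3 and spec[2] == 'log'
        if isinstance(low, int) and isinstance(high, int) and not use_log:
            params[name] = trial.suggest_int(name, low, high)
        else:
            params[name] = trial.suggest_float(name, low, high, log=use_log)
    return params
```

Inside an Optuna objective, `suggest_from_space(trial, search_space)` would return one sampled configuration per trial. Once `study.optimize` has finished, `study.best_params` and `study.best_value` hold the best configuration found and its score.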
Requirements
- Python ≥ 3.10
- PyTorch ≥ 2.0.0
- torchvision ≥ 0.15.0
- numpy
- matplotlib
- optuna (optional, for hyperparameter optimization)
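To check an environment against this list, the installed versions can be read with the standard library's `importlib.metadata`. The sketch below only reports what is installed and does not compare version numbers. The helper name `installed_versions` is hypothetical, and it assumes PyTorch is installed under the distribution name `torch`.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == '__main__':
    print(f'Python {sys.version_info.major}.{sys.version_info.minor}')
    required = ['torch', 'torchvision', 'numpy', 'matplotlib', 'optuna']
    for name, installed in installed_versions(required).items():
        print(f"{name}: {installed or 'not installed'}")
```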
Documentation
Full documentation is available at: https://gperdrizet.github.io/CIFAR10/
Demo project
See a complete example of using this package for CIFAR-10 classification: https://github.com/gperdrizet/CIFAR10
License
GPLv3