A generic, config-driven PyTorch Lightning pipeline for image classification tasks
CPH Image Classification
A generic, modular, and reusable PyTorch Lightning pipeline for training image classification models using CNNs. This package is fully config-driven, allowing you to train models on any image classification dataset by simply modifying a YAML configuration file.
Installation
pip install cph-imgclassification
Quick Start
1. Install the Package
pip install cph-imgclassification
2. Organize Your Data
Organize your images in class folders:
YourProjectName/data/images/
├── class1/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
├── class2/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
└── class3/
    ├── image1.jpg
    └── ...
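This folder layout is what lets the pipeline infer the number of classes. Here is a minimal plain-Python sketch of the idea. It assumes that each subfolder name is a class and that classes are indexed in sorted order, the way torchvision's `ImageFolder` does it. The helper `scan_class_folders` and the extension set are hypothetical and are not part of the package.

```python
from pathlib import Path

# Assumed set of accepted image extensions; the package may accept others.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def scan_class_folders(root):
    """Return (samples, class_to_idx) for a folder-per-class dataset.

    Hypothetical helper: every subfolder of `root` is one class, classes are
    indexed in sorted order, and samples are (image_path, class_index) pairs.
    """
    root = Path(root)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    samples = []
    for name in class_names:
        for path in sorted((root / name).rglob("*")):
            # Keep only files whose extension looks like an image.
            if path.is_file() and path.suffix.lower() in IMAGE_EXTS:
                samples.append((str(path), class_to_idx[name]))
    return samples, class_to_idx
```

For the tree shown above, `scan_class_folders("YourProjectName/data/images")` would report three classes. That count is the value that `num_classes: 0` in the config is meant to be replaced with.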
3. Create Configuration File
Create config.yaml:
# Image Classification Model Configuration
seed_everything: true
trainer:
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "{epoch}-{val_acc:.2f}.best"
        monitor: "val_acc"
        mode: "max"
        save_top_k: 1
    - class_path: cph_imgclassification.imageclassification.callbacks.ONNXExportCallback
      init_args:
        output_dir: "models"
        model_name: "my_model"
        input_shape: [3, 224, 224]
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: "lightning_logs"
      name: "MyProjectTraining"
  max_epochs: 50
  accelerator: auto
  devices: auto
  precision: 16-mixed
model:
  class_path: cph_imgclassification.imageclassification.modelmodule.ModelModuleIMG
  init_args:
    lr: 0.001
    weight_decay: 0.0001
    model:
      class_path: cph_imgclassification.imageclassification.modelfactory.ImageClassificationModel
      init_args:
        input_channels: 3
        num_classes: 0  # Auto-set from datamodule
        architecture: "medium"  # "simple", "medium", "deep", or "custom"
        input_size: [224, 224]
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001
    weight_decay: 0.0001
data:
  class_path: cph_imgclassification.imageclassification.datamodule.DataModuleIMG
  init_args:
    data_dir: "YourProjectName/data/images"
    image_size: [224, 224]
    batch_size: 32
    num_workers: 4
    val_split: 0.2
    random_seed: 42
    augmentation:
      enabled: true
      rotation: 15
      horizontal_flip: true
    normalization:
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    save_preprocessor: true
    preprocessor_path: "models/label_encoder.joblib"
fit:
  ckpt_path: null
test:
  ckpt_path: best
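The `val_split` and `random_seed` settings together decide which images go into the validation set. Below is a minimal sketch of a reproducible split. It assumes a simple shuffle-then-slice strategy. The real `DataModuleIMG` may instead use `torch.utils.data.random_split` or a stratified split, and `train_val_split` is a hypothetical name.

```python
import math
import random


def train_val_split(samples, val_split=0.2, random_seed=42):
    """Deterministically shuffle `samples` and hold out a validation fraction.

    Hypothetical helper mirroring the `val_split` / `random_seed` settings.
    Returns (train_samples, val_samples).
    """
    if not 0.0 <= val_split < 1.0:
        raise ValueError("val_split must be in [0, 1)")
    indices = list(range(len(samples)))
    # A dedicated Random instance keeps the split independent of global RNG state.
    random.Random(random_seed).shuffle(indices)
    n_val = math.floor(len(samples) * val_split)
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    return [samples[i] for i in train_idx], [samples[i] for i in val_idx]
```

The shuffle is seeded, so rerunning with the same `random_seed` reproduces the same split. This matters when you compare validation accuracy across runs.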
4. Train Your Model
Run training followed by testing (the default behavior when no subcommand is given):
cph-imgclassification --config config.yaml
Or use standard Lightning CLI subcommands:
# Training only
cph-imgclassification fit --config config.yaml
# Testing only
cph-imgclassification test --config config.yaml
# Validation
cph-imgclassification validate --config config.yaml
# Prediction
cph-imgclassification predict --config config.yaml
Features
- Fully Config-Driven: All settings controlled via YAML files
- Generic & Reusable: Use for any image classification task
- Flexible Data Formats: Supports folder-based and CSV-based datasets
- Auto-Dimension Detection: Automatically detects the number of classes from the dataset
- Configurable CNN Architectures: Choose from preset architectures or define custom layers
- Data Augmentation: Built-in augmentation support
- Production-Ready: Exports models to ONNX format
- PyTorch Lightning: Built on PyTorch Lightning for scalable training
- Comprehensive Metrics: Tracks Accuracy, F1-Score, Precision, and Recall
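To make the four tracked metrics concrete, here is a plain-Python sketch that computes them from true and predicted class indices. The package most likely relies on a metrics library. Macro averaging (an unweighted mean over classes) is an assumption, because the averaging mode is not stated here.

```python
def classification_metrics(y_true, y_pred, num_classes):
    """Accuracy plus macro-averaged precision, recall, and F1.

    Macro averaging is an assumption; the package's averaging mode may differ.
    """
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equal length")
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in pairs)
        fp = sum(t != c and p == c for t, p in pairs)
        fn = sum(t == c and p != c for t, p in pairs)
        # Treat 0/0 as 0 so classes that are never predicted do not crash.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    return {
        "accuracy": accuracy,
        "precision": sum(precisions) / num_classes,
        "recall": sum(recalls) / num_classes,
        "f1": sum(f1s) / num_classes,
    }
```

Macro averaging gives each class equal weight. On imbalanced datasets, this can make F1 noticeably lower than accuracy.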
Usage Examples
Training Only
cph-imgclassification fit --config config.yaml
Testing Only
cph-imgclassification test --config config.yaml
Resume Training
cph-imgclassification fit --config config.yaml --fit.ckpt_path path/to/checkpoint.ckpt
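Resuming requires a checkpoint path. With the `ModelCheckpoint` settings in the example config, Lightning inserts metric names into the filename by default (`auto_insert_metric_name=True`). Files therefore look like `epoch=12-val_acc=0.93.best.ckpt`. The sketch below picks the highest-`val_acc` file anywhere under `lightning_logs`, assuming that naming. `find_best_checkpoint` is a hypothetical helper.

```python
import re
from pathlib import Path

# Assumed name format: "epoch=12-val_acc=0.93.best.ckpt". Lightning may add a
# "-v1", "-v2", ... suffix when a filename would otherwise collide.
CKPT_PATTERN = re.compile(r"epoch=(\d+)-val_acc=(\d+\.\d+)\.best(?:-v\d+)?\.ckpt")


def find_best_checkpoint(log_dir="lightning_logs"):
    """Return the Path of the checkpoint with the highest val_acc, or None."""
    best_path, best_key = None, None
    for path in Path(log_dir).rglob("*.ckpt"):
        match = CKPT_PATTERN.fullmatch(path.name)
        if match is None:
            continue  # skip files such as last.ckpt that do not follow the pattern
        epoch, val_acc = int(match.group(1)), float(match.group(2))
        key = (val_acc, epoch)  # on equal accuracy, prefer the later epoch
        if best_key is None or key > best_key:
            best_path, best_key = path, key
    return best_path
```

You can then pass the returned path as the checkpoint argument in the resume command above.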
Using CSV Data Format
In your config file, use:
data:
  class_path: cph_imgclassification.imageclassification.datamodule.DataModuleIMG
  init_args:
    csv_path: "data/image_labels.csv"
    image_path_col: "image_path"
    label_col: "label"
    # ... other settings
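In the CSV format, each row pairs an image path with a label, and the `image_path_col` and `label_col` settings name the two columns. Below is a minimal sketch that reads such a file and encodes labels as integers. It makes three assumptions:

- the file has a header row;
- relative image paths are resolved against the CSV's folder;
- labels are indexed in sorted order, which is how scikit-learn's `LabelEncoder` orders `classes_`.

`load_csv_samples` is hypothetical.

```python
import csv
from pathlib import Path


def load_csv_samples(csv_path, image_path_col="image_path", label_col="label"):
    """Return (samples, classes); samples are (image_path, label_index) pairs."""
    csv_path = Path(csv_path)
    with csv_path.open(newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames or []
        missing = [c for c in (image_path_col, label_col) if c not in fieldnames]
        if missing:
            raise KeyError(f"missing CSV column(s): {missing}")
        rows = list(reader)
    # Sorted unique labels, matching LabelEncoder's ordering of classes_.
    classes = sorted({row[label_col] for row in rows})
    label_to_idx = {label: idx for idx, label in enumerate(classes)}
    samples = []
    for row in rows:
        image_path = Path(row[image_path_col])
        if not image_path.is_absolute():
            # Assumption: relative paths are relative to the CSV file's folder.
            image_path = csv_path.parent / image_path
        samples.append((str(image_path), label_to_idx[row[label_col]]))
    return samples, classes
```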
Configuration Reference
Model Architectures
- simple: 2 conv blocks (32, 64 channels) - Good for small datasets
- medium: 4 conv blocks (64, 128, 256, 512 channels) - Balanced performance
- deep: 6 conv blocks (64, 128, 256, 256, 512, 512 channels) - For large datasets
- custom: Define your own conv layers
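One rough way to compare the presets is to look at the spatial size of the last feature map for a 224x224 input. The sketch uses the channel counts listed above. It assumes each conv block preserves spatial size and ends with a 2x2 max-pool that halves height and width (floor division). The package's actual layers, including padding, pooling, and any global pooling before the classifier, may differ.

```python
# Channel widths per preset, as listed above.
PRESETS = {
    "simple": [32, 64],
    "medium": [64, 128, 256, 512],
    "deep": [64, 128, 256, 256, 512, 512],
}


def feature_map_shape(architecture, input_size=(224, 224)):
    """Return (channels, height, width) after the last conv block.

    Assumes 'same'-padded convolutions followed by one 2x2 max-pool per block.
    """
    channels = PRESETS[architecture]
    height, width = input_size
    for _ in channels:
        height, width = height // 2, width // 2
    return channels[-1], height, width


for name in PRESETS:
    c, h, w = feature_map_shape(name)
    print(f"{name:>6}: {c} x {h} x {w} = {c * h * w} values if flattened")
```

Under these assumptions, a 224x224 input ends at 64x56x56 for simple, 512x14x14 for medium, and 512x3x3 for deep. This is presumably why the model takes an `input_size` argument that should agree with the data module's `image_size`.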
Data Augmentation Options
- rotation: Random rotation range in degrees
- horizontal_flip: Random horizontal flip
- vertical_flip: Random vertical flip
- color_jitter: Color jitter intensity
- random_crop: Crop size for random cropping
Output Files
After training, you'll find:
- ONNX Model: models/your_model_name.onnx
- Label Encoder: models/label_encoder.joblib
- Checkpoints: lightning_logs/.../checkpoints/
- TensorBoard Logs: lightning_logs/
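To use the exported ONNX model outside Lightning, you must preprocess inputs the same way as during training and map the outputs back to class names. With onnxruntime, the model call itself is roughly `session.run(None, {input_name: batch})`, where `input_name` comes from `session.get_inputs()[0].name`. The sketch below covers only the pure-Python steps around that call. It makes three assumptions:

- pixels are scaled to [0, 1] before normalization with the config's `mean`/`std`;
- the model outputs raw logits, so softmax is applied here;
- the class order matches the saved label encoder's `classes_`.

All helper names are hypothetical.

```python
import math

MEAN = [0.485, 0.456, 0.406]  # normalization.mean from config.yaml
STD = [0.229, 0.224, 0.225]   # normalization.std from config.yaml


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return [(value / 255.0 - m) / s for value, m, s in zip(rgb, MEAN, STD)]


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits, class_names):
    """Return (class_name, probability) for the highest-scoring class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]
```

In a real pipeline you would apply `normalize_pixel` across the whole image (typically vectorized with NumPy) and arrange the values as `[1, 3, 224, 224]` to match the exported `input_shape`.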
Viewing Training Progress
tensorboard --logdir lightning_logs
Then open http://localhost:6006 in your browser.
Python API
You can also use the package programmatically:
from cph_imgclassification.imageclassification import (
    DataModuleIMG,
    ImageClassificationModel,
    ModelModuleIMG,
)

# Create datamodule
datamodule = DataModuleIMG(
    data_dir="data/images",
    image_size=(224, 224),
    batch_size=32,
)

# Create model
model = ImageClassificationModel(
    input_channels=3,
    num_classes=10,
    architecture="medium",
)

# Create Lightning module
lightning_model = ModelModuleIMG(model=model, lr=0.001)
Requirements
- Python >= 3.8
- PyTorch >= 2.0.0
- PyTorch Lightning >= 2.1.0
See requirements.txt for the full dependency list.
License
MIT License
Author
chandra
Repository
https://github.com/imchandra11/cph-imgclassification
Support
For issues, questions, or contributions, please visit the GitHub repository.
File details
Details for the file cph_imgclassification-0.1.3.tar.gz.
File metadata
- Download URL: cph_imgclassification-0.1.3.tar.gz
- Size: 27.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e0a6b6a4da8460620537b53f9f5c90e35e5b41abb53c725669fbd9db6aab1376 |
| MD5 | 7f20ab397b01fe67fdd45be5131e52c7 |
| BLAKE2b-256 | 4c55b8aae430ba7e1110b4321b90130d24371f6fcfc1df26f039e8fd20f7699a |
File details
Details for the file cph_imgclassification-0.1.3-py3-none-any.whl.
File metadata
- Download URL: cph_imgclassification-0.1.3-py3-none-any.whl
- Size: 24.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3a8f17aece4269bf7e640bc02cd74bbf91a8c19e03a59c4aede3de2933666666 |
| MD5 | a8edd2fbfadfabcbba6bcc7a256be7ca |
| BLAKE2b-256 | c498ff20abcd3c8bc9bc5fbb1a3bf1014a50fc69d66e3b9065b8de3c1932ec36 |