A generic, config-driven PyTorch Lightning pipeline for image classification tasks

CPH Image Classification

A generic, modular, and reusable PyTorch Lightning pipeline for training CNN-based image classification models. The package is fully config-driven: to train on a different image classification dataset, you edit a YAML configuration file.

Installation

pip install cph-imgclassification

Quick Start

1. Install the Package

pip install cph-imgclassification

2. Organize Your Data

Organize your images into one folder per class:

YourProjectName/data/images/
├── class1/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
├── class2/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
└── class3/
    ├── image1.jpg
    └── ...

3. Create Configuration File

Create config.yaml:

# Image Classification Model Configuration
seed_everything: true

trainer:
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "{epoch}-{val_acc:.2f}.best"
        monitor: "val_acc"
        mode: "max"
        save_top_k: 1
    - class_path: cph_imgclassification.imageclassification.callbacks.ONNXExportCallback
      init_args:
        output_dir: "models"
        model_name: "my_model"
        input_shape: [3, 224, 224]

  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: "lightning_logs"
      name: "MyProjectTraining"

  max_epochs: 50
  accelerator: auto
  devices: auto
  precision: 16-mixed

model:
  class_path: cph_imgclassification.imageclassification.modelmodule.ModelModuleIMG
  init_args:
    lr: 0.001
    weight_decay: 0.0001
    model:
      class_path: cph_imgclassification.imageclassification.modelfactory.ImageClassificationModel
      init_args:
        input_channels: 3
        num_classes: 0  # Auto-set from datamodule
        architecture: "medium"  # "simple", "medium", "deep", or "custom"
        input_size: [224, 224]

optimizer: 
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001
    weight_decay: 0.0001

data:
  class_path: cph_imgclassification.imageclassification.datamodule.DataModuleIMG
  init_args:
    data_dir: "YourProjectName/data/images"
    image_size: [224, 224]
    batch_size: 32
    num_workers: 4
    val_split: 0.2
    random_seed: 42
    augmentation:
      enabled: true
      rotation: 15
      horizontal_flip: true
    normalization:
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    save_preprocessor: true
    preprocessor_path: "models/label_encoder.joblib"

fit:
  ckpt_path: null

test:
  ckpt_path: best
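
Each `class_path` / `init_args` pair in the config names a class by its dotted import path, together with the keyword arguments used to construct it. The following is a simplified, standard-library sketch of that idea. It is not the parser the package actually uses, and `instantiate` is a hypothetical helper; a real config parser also handles nested entries (such as the inner `model`), type checking, and defaults.

```python
import importlib


def instantiate(spec):
    """Build an object from a {'class_path': ..., 'init_args': {...}} dict."""
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))


# Illustration with a standard-library class instead of a package class.
spec = {
    "class_path": "fractions.Fraction",
    "init_args": {"numerator": 3, "denominator": 4},
}
obj = instantiate(spec)
print(obj)  # 3/4
```

Seen this way, swapping the logger, the optimizer, or a callback means changing a dotted path and its arguments in the YAML file.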

4. Train Your Model

cph-imgclassification fit --config config.yaml

To train and then evaluate the trained model, run fit followed by test:

cph-imgclassification fit --config config.yaml
cph-imgclassification test --config config.yaml

Features

  • Fully Config-Driven: All settings controlled via YAML files
  • Generic & Reusable: Use for any image classification task
  • Flexible Data Formats: Supports folder-based and CSV-based datasets
  • Auto-Dimension Detection: Automatically detects number of classes
  • Configurable CNN Architectures: Choose from preset architectures or define custom layers
  • Data Augmentation: Built-in augmentation support
  • Production-Ready: Exports models to ONNX format
  • PyTorch Lightning: Built on PyTorch Lightning for scalable training
  • Comprehensive Metrics: Tracks Accuracy, F1-Score, Precision, and Recall

Usage Examples

Training Only

cph-imgclassification fit --config config.yaml

Testing Only

cph-imgclassification test --config config.yaml

Resume Training

cph-imgclassification fit --config config.yaml --fit.ckpt_path path/to/checkpoint.ckpt
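
To find a checkpoint path to pass on the command line, one option is to pick the most recently written `.ckpt` file under the log directory. This is a minimal sketch, assuming checkpoints live in `checkpoints/` folders below `lightning_logs`, as listed under Output Files; `latest_checkpoint` is a hypothetical helper, not a package function.

```python
from pathlib import Path


def latest_checkpoint(log_dir="lightning_logs"):
    """Return the most recently modified .ckpt file under log_dir, or None."""
    checkpoints = list(Path(log_dir).glob("**/checkpoints/*.ckpt"))
    if not checkpoints:
        return None
    # Modification time is used as a proxy for "latest training state".
    return max(checkpoints, key=lambda p: p.stat().st_mtime)
```

Note that the newest file is not necessarily the best one: with `save_top_k: 1` and `monitor: "val_acc"`, the saved checkpoint is the one with the highest validation accuracy.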

Using CSV Data Format

In your config file, use:

data:
  class_path: cph_imgclassification.imageclassification.datamodule.DataModuleIMG
  init_args:
    csv_path: "data/image_labels.csv"
    image_path_col: "image_path"
    label_col: "label"
    # ... other settings
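
If your images already sit in class folders, a label CSV with the two columns named above can be generated with the standard library. This is a sketch under assumptions: `write_label_csv` is a hypothetical helper, and the source does not say whether the package expects absolute paths or paths relative to some base directory, so adjust the path column to match your setup.

```python
import csv
from pathlib import Path


def write_label_csv(image_root, csv_path):
    """Write one row per image: its path and the name of its parent folder."""
    rows = [
        (str(f), f.parent.name)
        for f in sorted(Path(image_root).glob("*/*"))
        if f.is_file()
    ]
    with open(csv_path, "w", newline="") as fh:
        writer = csv.writer(fh)
        # Header names match image_path_col and label_col in the config above.
        writer.writerow(["image_path", "label"])
        writer.writerows(rows)
    return len(rows)
```

The function returns the number of rows written, which gives a quick check against the number of images you expect.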

Configuration Reference

Model Architectures

  • simple: 2 conv blocks (32, 64 channels) - Good for small datasets
  • medium: 4 conv blocks (64, 128, 256, 512 channels) - Balanced performance
  • deep: 6 conv blocks (64, 128, 256, 256, 512, 512 channels) - For large datasets
  • custom: Define your own conv layers
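
To get a feel for how depth interacts with input size, the sketch below estimates the feature map shape after the last conv block of each preset. The channel lists come from the list above. Everything else is an assumption made for illustration: the source does not say how each block downsamples, so the code assumes every block halves height and width (for example with 2x2 pooling). `feature_map_shape` is a hypothetical helper.

```python
# Channel widths per preset, taken from the list above.
PRESETS = {
    "simple": [32, 64],
    "medium": [64, 128, 256, 512],
    "deep": [64, 128, 256, 256, 512, 512],
}


def feature_map_shape(architecture, input_size=(224, 224)):
    """Return (channels, height, width) after the last conv block.

    ASSUMPTION: every block halves height and width (e.g. 2x2 pooling).
    The package's real blocks may downsample differently.
    """
    channels = PRESETS[architecture]
    h, w = input_size
    for _ in channels:
        h, w = h // 2, w // 2
    return channels[-1], h, w


for name in PRESETS:
    print(name, feature_map_shape(name))
```

Under this assumption a 224x224 input shrinks to 56x56 after two blocks but only 3x3 after six, which is one reason a deeper preset may need larger input images.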

Data Augmentation Options

  • rotation: Random rotation degrees
  • horizontal_flip: Random horizontal flip
  • vertical_flip: Random vertical flip
  • color_jitter: Color jitter intensity
  • random_crop: Crop size for random cropping

Output Files

After training, you'll find:

  • ONNX Model: models/<model_name>.onnx (models/my_model.onnx with the example config above)
  • Label Encoder: models/label_encoder.joblib
  • Checkpoints: lightning_logs/.../checkpoints/
  • TensorBoard Logs: lightning_logs/
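
Once the ONNX model has produced scores for an image, they still have to be turned into a class label. The sketch below shows only that post-processing step in plain Python. It assumes the exported model returns one raw logit per class and that the class order matches the saved label encoder; neither is confirmed by the source. `top_prediction` is a hypothetical helper, and loading the model and the encoder is not shown.

```python
import math


def top_prediction(logits, class_names):
    """Return (label, probability) for the highest-scoring class."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]  # softmax over the logits
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]


# Made-up logits for a three-class model, for illustration only.
label, prob = top_prediction([0.2, 2.5, -1.0], ["class1", "class2", "class3"])
print(label, round(prob, 3))
```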

Viewing Training Progress

tensorboard --logdir lightning_logs

Then open http://localhost:6006 in your browser.

Python API

You can also use the package programmatically:

from cph_imgclassification.imageclassification import (
    DataModuleIMG,
    ImageClassificationModel,
    ModelModuleIMG
)

# Create datamodule
datamodule = DataModuleIMG(
    data_dir="data/images",
    image_size=(224, 224),
    batch_size=32
)

# Create model
model = ImageClassificationModel(
    input_channels=3,
    num_classes=10,
    architecture="medium"
)

# Create Lightning module
lightning_model = ModelModuleIMG(model=model, lr=0.001)

Requirements

  • Python >= 3.8
  • PyTorch >= 2.0.0
  • PyTorch Lightning >= 2.1.0

See requirements.txt for the full dependency list.

License

MIT License

Author

chandra

Repository

https://github.com/imchandra11/cph-imgclassification

Support

For issues, questions, or contributions, please visit the GitHub repository.
