A generic, config-driven PyTorch Lightning pipeline for image classification tasks

CPH Image Classification

A generic, modular, and reusable PyTorch Lightning pipeline for training image classification models using CNNs. This package is fully config-driven, allowing you to train models on any image classification dataset by simply modifying a YAML configuration file.

Installation

pip install cph-imgclassification

Quick Start

1. Install the Package

pip install cph-imgclassification

2. Organize Your Data

Put each class's images in its own subfolder of the data directory:

YourProjectName/data/images/
├── class1/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
├── class2/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
└── class3/
    ├── image1.jpg
    └── ...
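
Before training, it can help to confirm that the folder layout is what you expect. The following is a minimal sketch and not part of the package: `count_images_per_class` and `IMAGE_EXTENSIONS` are hypothetical names, and the set of extensions is an assumption rather than the list the package actually accepts.

```python
from pathlib import Path

# Assumed set of image extensions; the package may accept a different list.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_per_class(data_dir):
    """Return {class_name: image_count} for each immediate subfolder of data_dir."""
    counts = {}
    for class_dir in sorted(Path(data_dir).iterdir()):
        if not class_dir.is_dir():
            continue  # ignore stray files sitting next to the class folders
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts
```

Calling `count_images_per_class("YourProjectName/data/images")` returns one entry per class folder, which makes empty or heavily imbalanced classes easy to spot.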

3. Create Configuration File

Create config.yaml:

# Image Classification Model Configuration
seed_everything: true

trainer:
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "{epoch}-{val_acc:.2f}.best"
        monitor: "val_acc"
        mode: "max"
        save_top_k: 1
    - class_path: cph_imgclassification.imageclassification.callbacks.ONNXExportCallback
      init_args:
        output_dir: "models"
        model_name: "my_model"
        input_shape: [3, 224, 224]

  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: "lightning_logs"
      name: "MyProjectTraining"

  max_epochs: 50
  accelerator: auto
  devices: auto
  precision: 16-mixed

model:
  class_path: cph_imgclassification.imageclassification.modelmodule.ModelModuleIMG
  init_args:
    lr: 0.001
    weight_decay: 0.0001
    model:
      class_path: cph_imgclassification.imageclassification.modelfactory.ImageClassificationModel
      init_args:
        input_channels: 3
        num_classes: 0  # Auto-set from datamodule
        architecture: "medium"  # "simple", "medium", "deep", or "custom"
        input_size: [224, 224]

optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001
    weight_decay: 0.0001

data:
  class_path: cph_imgclassification.imageclassification.datamodule.DataModuleIMG
  init_args:
    data_dir: "YourProjectName/data/images"
    image_size: [224, 224]
    batch_size: 32
    num_workers: 4
    val_split: 0.2
    random_seed: 42
    augmentation:
      enabled: true
      rotation: 15
      horizontal_flip: true
    normalization:
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    save_preprocessor: true
    preprocessor_path: "models/label_encoder.joblib"

fit:
  ckpt_path: null

test:
  ckpt_path: best

4. Train Your Model

Run without a subcommand to train and then test in one step (the default behavior):

cph-imgclassification --config config.yaml

Or use standard Lightning CLI subcommands:

# Training only
cph-imgclassification fit --config config.yaml

# Testing only
cph-imgclassification test --config config.yaml

# Validation
cph-imgclassification validate --config config.yaml

# Prediction
cph-imgclassification predict --config config.yaml

Features

  • Fully Config-Driven: All settings controlled via YAML files
  • Generic & Reusable: Use for any image classification task
  • Flexible Data Formats: Supports folder-based and CSV-based datasets
  • Auto-Dimension Detection: Automatically detects number of classes
  • Configurable CNN Architectures: Choose from preset architectures or define custom layers
  • Data Augmentation: Built-in augmentation support
  • Production-Ready: Exports models to ONNX format
  • PyTorch Lightning: Built on PyTorch Lightning for scalable training
  • Comprehensive Metrics: Tracks Accuracy, F1-Score, Precision, and Recall
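
As a reminder of what the tracked metrics measure, here is a small pure-Python sketch. It assumes macro averaging (each class weighted equally); the averaging mode the package uses is not stated above, and `macro_metrics` is a hypothetical helper, not part of the package.

```python
def macro_metrics(y_true, y_pred):
    """Compute accuracy plus macro-averaged precision, recall, and F1."""
    classes = sorted(set(y_true) | set(y_pred))
    pairs = list(zip(y_true, y_pred))
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(1 for t, p in pairs if t == c and p == c)  # true positives
        fp = sum(1 for t, p in pairs if t != c and p == c)  # false positives
        fn = sum(1 for t, p in pairs if t == c and p != c)  # false negatives
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(classes)
    return {
        "accuracy": sum(1 for t, p in pairs if t == p) / len(pairs),
        "precision": sum(precisions) / n,
        "recall": sum(recalls) / n,
        "f1": sum(f1s) / n,
    }
```

Macro averaging is sensitive to poorly predicted minority classes, which is why it is often preferred over plain accuracy on imbalanced datasets.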

Usage Examples

Training Only

cph-imgclassification fit --config config.yaml

Testing Only

cph-imgclassification test --config config.yaml

Resume Training

cph-imgclassification fit --config config.yaml --fit.ckpt_path path/to/checkpoint.ckpt

Using CSV Data Format

To load data from a CSV file instead of class folders, set these fields in the data section of your config file:

data:
  class_path: cph_imgclassification.imageclassification.datamodule.DataModuleIMG
  init_args:
    csv_path: "data/image_labels.csv"
    image_path_col: "image_path"
    label_col: "label"
    # ... other settings
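
If your images already sit in class folders, a CSV with the two columns named above can be generated from them. This is a sketch under assumptions: `write_manifest` is a hypothetical helper, the extension list is illustrative, and whether the package resolves `image_path` relative to the working directory or to the CSV file is not documented here, so check that before relying on relative paths.

```python
import csv
from pathlib import Path


def write_manifest(data_dir, csv_path, extensions=(".jpg", ".jpeg", ".png")):
    """Write one row per image with columns image_path and label.

    The label is taken from the name of the folder containing the image.
    Returns the number of rows written.
    """
    rows = []
    for class_dir in sorted(Path(data_dir).iterdir()):
        if not class_dir.is_dir():
            continue
        for image in sorted(class_dir.iterdir()):
            if image.is_file() and image.suffix.lower() in extensions:
                rows.append({"image_path": image.as_posix(), "label": class_dir.name})
    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["image_path", "label"])
        writer.writeheader()
        writer.writerows(rows)
    return len(rows)
```

For example, `write_manifest("data/images", "data/image_labels.csv")` produces a file that matches the `csv_path`, `image_path_col`, and `label_col` values in the config fragment above.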

Configuration Reference

Model Architectures

  • simple: 2 conv blocks (32, 64 channels) - Good for small datasets
  • medium: 4 conv blocks (64, 128, 256, 512 channels) - Balanced performance
  • deep: 6 conv blocks (64, 128, 256, 256, 512, 512 channels) - For large datasets
  • custom: Define your own conv layers
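
To get a feel for how depth interacts with input size, the sketch below estimates the feature-map shape after the last conv block of each preset. It assumes every block halves the height and width (for example via 2x2 pooling); the package's real blocks may downsample differently, so treat the numbers as illustrative. `PRESETS` and `feature_map_shape` are hypothetical names; only the channel lists come from the descriptions above.

```python
# Channel counts per conv block, copied from the preset descriptions.
PRESETS = {
    "simple": [32, 64],
    "medium": [64, 128, 256, 512],
    "deep": [64, 128, 256, 256, 512, 512],
}


def feature_map_shape(architecture, input_size=(224, 224)):
    """Return (channels, height, width) after the last conv block.

    Assumption: each block halves height and width with floor division.
    """
    channels = PRESETS[architecture]
    height, width = input_size
    for _ in channels:
        height, width = height // 2, width // 2
    return channels[-1], height, width
```

Under that assumption a 224x224 input ends at 56x56 for `simple`, 14x14 for `medium`, and 3x3 for `deep`, which suggests that small input images leave very little spatial resolution for the deeper presets.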

Data Augmentation Options

  • rotation: Random rotation degrees
  • horizontal_flip: Random horizontal flip
  • vertical_flip: Random vertical flip
  • color_jitter: Color jitter intensity
  • random_crop: Crop size for random cropping

Output Files

After training, you'll find:

  • ONNX Model: models/<model_name>.onnx (models/my_model.onnx with the config above)
  • Label Encoder: models/label_encoder.joblib
  • Checkpoints: lightning_logs/.../checkpoints/
  • TensorBoard Logs: lightning_logs/

Viewing Training Progress

tensorboard --logdir lightning_logs

Then open http://localhost:6006 in your browser.

Python API

You can also use the package programmatically:

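import lightning.pytorch as pl

from cph_imgclassification.imageclassification import (
    DataModuleIMG,
    ImageClassificationModel,
    ModelModuleIMG,
)

# Create datamodule
datamodule = DataModuleIMG(
    data_dir="data/images",
    image_size=(224, 224),
    batch_size=32,
)

# Create model (num_classes must match the number of classes in your data)
model = ImageClassificationModel(
    input_channels=3,
    num_classes=10,
    architecture="medium",
)

# Create Lightning module
lightning_model = ModelModuleIMG(model=model, lr=0.001)

# Train with a standard Lightning Trainer. This step assumes ModelModuleIMG is a
# regular LightningModule, as its use in the YAML config suggests.
trainer = pl.Trainer(max_epochs=50)
trainer.fit(lightning_model, datamodule=datamodule)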

Requirements

  • Python >= 3.8
  • PyTorch >= 2.0.0
  • PyTorch Lightning >= 2.1.0

See requirements.txt for the full dependency list.

License

MIT License

Author

chandra

Repository

https://github.com/imchandra11/cph-imgclassification

Support

For issues, questions, or contributions, please visit the GitHub repository.
