cph-classification
A generic, reusable PyTorch Lightning pipeline for training classification models on tabular data. This package provides a fully config-driven framework that can be used for any classification task by simply providing a YAML configuration file.
Features
- 🚀 Fully Config-Driven: All settings (features, hyperparameters, paths) controlled via YAML files
- 🔄 Generic & Reusable: Use the same codebase for any classification task (stress levels, sentiment, quality ratings, etc.)
- 🤖 Auto-Dimension Detection: Automatically calculates input dimensions and number of classes from feature lists and target column
- 📊 Categorical Target Support: Automatically handles both integer and categorical string targets (e.g., "good", "better", "best" or "yes", "no")
- 🎯 Production-Ready: Exports models to ONNX format with preprocessors and label encoders for easy deployment
- ⚡ PyTorch Lightning: Built on PyTorch Lightning for scalable, professional ML training
- 📈 Comprehensive Metrics: Tracks Accuracy, F1-Score, Precision, and Recall (macro-averaged)
Installation
Install from PyPI:
pip install cph-classification
Or install from source:
git clone https://github.com/imchandra11/cph-classification.git
cd cph-classification
pip install .
Quick Start
1. Install the Package
pip install cph-classification
2. Prepare Your Data
Create a CSV file with your features and target column. For example, data/myproject.csv:
```csv
feature1,feature2,target
value1,123.45,class_a
value2,234.56,class_b
...
```
3. Create Configuration File
Create a YAML configuration file, e.g., configs/myproject.yaml:
```yaml
# My Classification Project Configuration
seed_everything: true

trainer:
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "{epoch}-{val_loss:.2f}.best"
        monitor: "val_loss"
        mode: "min"
        save_top_k: 1
    - class_path: cph_classification.classification.callbacks.ONNXExportCallback
      init_args:
        output_dir: "models"
        model_name: "my_model"
        input_dim: null  # Auto-detected
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: "lightning_logs"
      name: "MyProjectTraining"
  max_epochs: 30
  accelerator: auto
  devices: auto
  precision: 16-mixed

model:
  class_path: cph_classification.classification.modelmodule.ModelModuleCLS
  init_args:
    lr: 0.0001
    model:
      class_path: cph_classification.classification.modelfactory.ClassificationModel
      init_args:
        input_dim: 0     # Auto-set from datamodule
        num_classes: 0   # Auto-set from datamodule
        hidden_layers: [128, 64, 32]
        dropout_rates: [0.15, 0.1, 0.05]
        activation: "relu"
    optimizer:
      class_path: torch.optim.Adam
      init_args:
        lr: 0.001
        weight_decay: 0.00001

data:
  class_path: cph_classification.classification.datamodule.DataModuleCLS
  init_args:
    csv_path: "data/myproject.csv"
    batch_size: 256
    num_workers: 0
    val_split: 0.2
    random_seed: 42
    categorical_cols:
      - feature1
    numeric_cols:
      - feature2
    target_col: "target"  # Can be integers or categorical strings
    save_preprocessor: true
    preprocessor_path: "models/preprocessor.joblib"

fit:
  ckpt_path: null  # Set to a checkpoint path to resume training

test:
  ckpt_path: best  # Use "best" or "last" checkpoint
```
4. Run Training
Train your model with a single command:
```bash
# Train and test (fit+test workflow)
cph-classification --config configs/myproject.yaml

# Or use standard Lightning CLI subcommands
cph-classification fit --config configs/myproject.yaml
cph-classification test --config configs/myproject.yaml
```
That's it! The model will be trained and saved to the path specified in your config file.
Configuration Guide
Data Configuration
Key Parameters:
- `csv_path`: Path to your CSV file
- `batch_size`: Batch size for training (default: 256)
- `val_split`: Validation split ratio (0.0 to 1.0, default: 0.2)
- `categorical_cols`: List of categorical feature column names
- `numeric_cols`: List of numeric feature column names
- `target_col`: Name of the target column to predict (can be integers or strings)
- `preprocessor_path`: Where to save/load the preprocessor
Preprocessing:
- Categorical columns: Automatically one-hot encoded (with `drop='first'`)
- Numeric columns: Automatically standardized using StandardScaler
- Target column:
  - If integers: Used as-is (converted to 0-indexed if needed)
  - If strings: Automatically encoded to 0-indexed integers using LabelEncoder
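The feature preprocessing described above can be sketched with standard scikit-learn building blocks. This is a minimal illustration of the transformations named here, not the package's actual internals; the column names are the placeholder ones from the Quick Start:

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

df = pd.DataFrame({
    "feature1": ["a", "b", "a", "c"],  # categorical
    "feature2": [1.0, 2.0, 3.0, 4.0],  # numeric
})

# One-hot encode categoricals (dropping the first level) and standardize numerics
preprocessor = ColumnTransformer([
    ("cat", OneHotEncoder(drop="first"), ["feature1"]),
    ("num", StandardScaler(), ["feature2"]),
])

X = preprocessor.fit_transform(df)
# 3 categories -> 2 one-hot columns (first dropped) + 1 scaled numeric column
print(X.shape)  # (4, 3)
```

Note that dropping the first category means the one-hot block contributes `n_categories - 1` columns per categorical feature, which is what the auto-detected `input_dim` has to account for.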
Model Configuration
Key Parameters:
- `hidden_layers`: List of hidden layer sizes, e.g., `[128, 64, 32]`
- `dropout_rates`: List of dropout rates matching hidden layers, e.g., `[0.15, 0.1, 0.05]`
- `activation`: Activation function (`"relu"`, `"tanh"`, `"gelu"`, `"sigmoid"`, `"leaky_relu"`, `"elu"`)
- `input_dim`: Automatically set from datamodule (set to `0` in config)
- `num_classes`: Automatically set from datamodule (set to `0` in config)
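To illustrate how these parameters define the network's shape, here is a hedged NumPy sketch of a forward pass through an MLP whose layer sizes follow `input_dim -> hidden_layers -> num_classes`. This is not the package's actual `ClassificationModel` (which is a PyTorch module with dropout), just a shape walkthrough with ReLU hidden layers:

```python
import numpy as np

def mlp_forward(x, input_dim, hidden_layers, num_classes, seed=0):
    """Forward pass through a plain MLP: input_dim -> hidden_layers... -> num_classes."""
    rng = np.random.default_rng(seed)
    dims = [input_dim] + hidden_layers + [num_classes]
    for i in range(len(dims) - 1):
        w = rng.standard_normal((dims[i], dims[i + 1]))
        x = x @ w
        if i < len(dims) - 2:  # hidden layers use ReLU; the output layer stays linear
            x = np.maximum(x, 0.0)
    return x

batch = np.ones((4, 10))  # 4 samples, input_dim = 10
logits = mlp_forward(batch, input_dim=10, hidden_layers=[128, 64, 32], num_classes=3)
print(logits.shape)  # (4, 3) -- one logit per class, per sample
```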
Output Files
After training, you'll find:
- Models Directory (`models/`):
  - `my_model.onnx`: ONNX model for inference
  - `preprocessor.joblib`: Fitted preprocessor for data transformation
  - `label_encoder.joblib`: Label encoder (only if target was categorical strings)
- Checkpoints (`lightning_logs/MyProjectTraining/version_X/checkpoints/`):
  - `epoch-X-val_loss=Y.best.ckpt`: Best model checkpoint (based on validation loss)
  - `epoch-X.last.ckpt`: Last epoch checkpoint
- Training Logs (`lightning_logs/`):
  - TensorBoard logs for visualization
Model Inference
After training, use the exported ONNX model for predictions:
```python
import joblib
import numpy as np
import onnxruntime as ort
import pandas as pd

# Load the fitted preprocessor
preprocessor = joblib.load("models/preprocessor.joblib")

# Load the label encoder if the target was categorical strings;
# the file only exists for string targets, so fall back to None
try:
    label_encoder = joblib.load("models/label_encoder.joblib")
except FileNotFoundError:
    label_encoder = None

# Load the ONNX model
session = ort.InferenceSession("models/my_model.onnx")

# Prepare input data (same columns as training)
input_data = pd.DataFrame({
    'feature1': ['value1'],
    'feature2': [123.45],
})

# Transform the features exactly as during training
feature_cols = ['feature1', 'feature2']
transformed = preprocessor.transform(input_data[feature_cols])

# Predict
input_name = session.get_inputs()[0].name
output = session.run(None, {input_name: transformed.astype(np.float32)})
predicted_class_idx = int(np.argmax(output[0][0]))

# Decode back to the original label (if a label encoder exists)
if label_encoder is not None:
    predicted_class = label_encoder.inverse_transform([predicted_class_idx])[0]
    print(f"Predicted class: {predicted_class}")
else:
    print(f"Predicted class index: {predicted_class_idx}")
```
Viewing Training Progress
TensorBoard
```bash
tensorboard --logdir lightning_logs
```
Then open http://localhost:6006 in your browser.
Metrics Tracked:
- `train_loss`, `val_loss`, `test_loss`: CrossEntropyLoss
- `train_acc`, `val_acc`, `test_acc`: Accuracy (macro-averaged)
- `train_f1`, `val_f1`, `test_f1`: F1-Score (macro-averaged)
- `train_precision`, `val_precision`, `test_precision`: Precision (macro-averaged)
- `train_recall`, `val_recall`, `test_recall`: Recall (macro-averaged)
Examples
Example 1: Integer Target Labels
If your target column contains integers (e.g., 1, 2, 3, 4, 5):
```yaml
data:
  init_args:
    target_col: "stress_level"  # Contains: 1, 2, 3, 4, 5
```
The pipeline will automatically convert to 0-indexed labels if needed (0, 1, 2, 3, 4).
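One plausible way to picture the 0-indexing (a hypothetical sketch using pandas, not necessarily the package's exact logic) is shifting the labels so the smallest becomes 0:

```python
import pandas as pd

y = pd.Series([1, 3, 2, 5, 4, 1])  # integer labels in 1..5
y_zero = y - y.min()               # shift so the smallest label becomes 0
print(sorted(y_zero.unique()))     # [0, 1, 2, 3, 4]
```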
Example 2: Categorical String Targets
If your target column contains categorical strings (e.g., "low", "medium", "high"):
```yaml
data:
  init_args:
    target_col: "quality"  # Contains: "low", "medium", "high"
```
The pipeline will automatically encode the strings to 0-indexed integers and save the label encoder so predictions can be decoded back to the original strings at inference time.
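The encode/decode roundtrip can be sketched with scikit-learn's standard `LabelEncoder` (shown here as an illustration; note it assigns indices in sorted order of the class names):

```python
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
y = ["low", "high", "medium", "low", "high"]
y_enc = le.fit_transform(y)               # classes are indexed in sorted order
print(list(le.classes_))                  # ['high', 'low', 'medium']
print(list(le.inverse_transform(y_enc)))  # back to the original strings
```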
Example 3: Multiple Configuration Files
You can use multiple config files for different environments:
```bash
# Main config + local overrides
cph-classification --config configs/myproject.yaml --config configs/myproject.local.yaml
```
The local config will override values from the main config.
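Conceptually, the override behaves like a recursive dictionary merge, where the later config wins on conflicting keys and everything else is kept. The actual merging is handled by Lightning CLI; this is just a hedged sketch of the idea:

```python
def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge override into base; override wins on conflicts."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)  # recurse into nested sections
        else:
            merged[key] = value                           # leaf values are replaced
    return merged

main = {"model": {"init_args": {"lr": 0.0001}}, "trainer": {"max_epochs": 30}}
local = {"model": {"init_args": {"lr": 0.0005}}}
print(deep_merge(main, local))
# {'model': {'init_args': {'lr': 0.0005}}, 'trainer': {'max_epochs': 30}}
```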
Advanced Usage
Resume Training
```bash
cph-classification fit \
  --config configs/myproject.yaml \
  --fit.ckpt_path "lightning_logs/MyProjectTraining/version_0/checkpoints/epoch-10.last.ckpt"
```
Hyperparameter Tuning
Override hyperparameters via command line or config files:
```yaml
# myproject.local.yaml
model:
  init_args:
    lr: 0.0005

data:
  init_args:
    batch_size: 512
```
Custom Model Architecture
```yaml
model:
  init_args:
    model:
      init_args:
        hidden_layers: [256, 128, 64, 32]  # Deeper network
        dropout_rates: [0.2, 0.15, 0.1, 0.05]
        activation: "gelu"
```
Requirements
- Python >= 3.8
- PyTorch >= 2.0.0
- PyTorch Lightning >= 2.1.0
- scikit-learn >= 1.3.0
- Other dependencies are automatically installed with the package
License
MIT License - see LICENSE file for details.
Author
chandra
- Email: chandra385123@gmail.com
- GitHub: @imchandra11
Repository
- GitHub: https://github.com/imchandra11/cph-classification
- PyPI: https://pypi.org/project/cph-classification/
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
Support
For issues or questions:
- Check the configuration file syntax
- Verify CSV file format and column names
- Check target column type (integers or categorical strings)
- Review TensorBoard logs for training insights
- Open an issue on GitHub
Citation
If you use this package in your research, please cite:
```bibtex
@software{cph_classification,
  title  = {cph-classification: A Generic PyTorch Lightning Pipeline for Classification},
  author = {chandra},
  year   = {2025},
  url    = {https://github.com/imchandra11/cph-classification}
}
```