CPH Regression

A generic, config-driven regression training package using PyTorch Lightning. Train regression models on any tabular dataset by simply providing a YAML configuration file.

Features

  • Fully Config-Driven: All settings (features, hyperparameters, paths) controlled via YAML files
  • Generic & Reusable: Use the same codebase for any regression task (gemstone prices, house prices, etc.)
  • Auto-Dimension Detection: Automatically calculates input dimensions from feature lists
  • Production-Ready: Exports models to ONNX format with preprocessors for easy deployment
  • PyTorch Lightning: Checkpointing, logging, mixed precision, and device selection are handled by the Lightning Trainer and set in the same YAML file

Installation

pip install cph-regression

Quick Start

1. Create a Configuration File

Create a config.yaml file:

# Your Project Configuration
seed_everything: true

trainer:
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "{epoch}-{val_loss:.2f}.best"
        monitor: "val_loss"
        mode: "min"
        save_top_k: 1
    - class_path: cph_regression.regression.callbacks.ONNXExportCallback
      init_args:
        output_dir: "models"
        model_name: "my_model"
        input_dim: null  # Auto-detected

  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: "lightning_logs"
      name: "MyProjectTraining"

  max_epochs: 50
  accelerator: auto
  devices: auto
  precision: 16-mixed

model:
  class_path: cph_regression.regression.modelmodule.ModelModuleRGS
  init_args:
    lr: 0.0001
    model:
      class_path: cph_regression.regression.modelfactory.RegressionModel
      init_args:
        input_dim: 0  # Auto-set from datamodule
        hidden_layers: [128, 64, 32]
        dropout_rates: [0.15, 0.1, 0.05]
        activation: "relu"

optimizer: 
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001

data:
  class_path: cph_regression.regression.datamodule.DataModuleRGS
  init_args:
    csv_path: "data/your_data.csv"
    batch_size: 256
    val_split: 0.2
    categorical_cols:
      - column1
      - column2
    numeric_cols:
      - column3
      - column4
    target_col: "target"
    save_preprocessor: true
    preprocessor_path: "models/preprocessor.joblib"

fit:
  ckpt_path: null

test:
  ckpt_path: best

2. Run Training

cph-regression --config config.yaml

This will:

  • Train the model
  • Run validation
  • Export the model to ONNX format
  • Save the preprocessor

3. Alternative Commands

Training only:

cph-regression fit --config config.yaml

Testing only:

cph-regression test --config config.yaml

Configuration Guide

Data Configuration

  • csv_path: Path to your CSV file
  • batch_size: Batch size for training (default: 256)
  • val_split: Validation split ratio (0.0 to 1.0, default: 0.2)
  • categorical_cols: List of categorical feature column names
  • numeric_cols: List of numeric feature column names
  • target_col: Name of the target column to predict
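  • save_preprocessor: Whether to save the fitted preprocessor after training (true in the example above)
  • preprocessor_path: Where to save/load the preprocessor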
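
Most data configuration mistakes are column-name mismatches, so it can help to check the lists against the CSV header before training. The following is a minimal sketch, not part of cph-regression: check_data_config is a hypothetical helper, and it only assumes the keys described above plus a CSV whose first row is the header.

```python
import csv


def check_data_config(lines, categorical_cols, numeric_cols, target_col, val_split=0.2):
    """Return a list of problems; an empty list means the settings look consistent.

    `lines` is an open CSV file or any iterable of text lines (only the header is read).
    """
    problems = []
    header = [name.strip() for name in next(csv.reader(lines), [])]

    features = list(categorical_cols) + list(numeric_cols)
    for col in features + [target_col]:
        if col not in header:
            problems.append(f"column not found in CSV header: {col}")

    for col in sorted(set(categorical_cols) & set(numeric_cols)):
        problems.append(f"column listed as both categorical and numeric: {col}")

    if target_col in features:
        problems.append(f"target column also listed as a feature: {target_col}")

    if not 0.0 <= val_split <= 1.0:
        problems.append(f"val_split outside the 0.0 to 1.0 range: {val_split}")

    return problems
```

With a real file this would be called as check_data_config(open("data/your_data.csv", newline=""), ...), using the same lists as in config.yaml.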

Model Configuration

  • hidden_layers: List of hidden layer sizes, e.g., [128, 64, 32]
  • dropout_rates: List of dropout rates, one per hidden layer, e.g., [0.15, 0.1, 0.05]
  • activation: Activation function ("relu", "tanh", "gelu", "sigmoid", "leaky_relu", "elu")
  • input_dim: Set automatically from the datamodule; leave it at 0 in the config
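
To make the relationship between the feature lists, input_dim, and hidden_layers concrete, here is a small illustrative sketch. It is not the package's implementation. It assumes categorical columns are one-hot encoded and numeric columns pass through as one value each (the actual preprocessor may differ), and that the network ends in a single regression output. The helper names and the gemstone-style columns are hypothetical.

```python
def infer_input_dim(rows, categorical_cols, numeric_cols):
    """Input width, assuming one-hot categoricals and pass-through numerics."""
    dim = len(numeric_cols)
    for col in categorical_cols:
        dim += len({row[col] for row in rows})  # one slot per distinct category
    return dim


def layer_shapes(input_dim, hidden_layers, dropout_rates, output_dim=1):
    """(in_features, out_features) of each linear layer, output layer included."""
    if len(dropout_rates) != len(hidden_layers):
        raise ValueError("dropout_rates needs one entry per hidden layer")
    sizes = [input_dim, *hidden_layers, output_dim]
    return list(zip(sizes[:-1], sizes[1:]))


# Hypothetical sample rows: two categorical columns and one numeric column
rows = [
    {"cut": "ideal", "color": "E", "carat": 0.3},
    {"cut": "good", "color": "E", "carat": 0.9},
    {"cut": "ideal", "color": "J", "carat": 1.2},
]
input_dim = infer_input_dim(rows, ["cut", "color"], ["carat"])
shapes = layer_shapes(input_dim, [128, 64, 32], [0.15, 0.1, 0.05])
print(input_dim, shapes)
# prints: 5 [(5, 128), (128, 64), (64, 32), (32, 1)]
```

Here the width is 1 numeric column plus 2 distinct cuts plus 2 distinct colors, which is why the value depends on the data and not only on the number of listed columns.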

Trainer Configuration

  • max_epochs: Number of training epochs
  • precision: Training precision ("16-mixed", "32", "bf16-mixed")
  • accelerator: Hardware accelerator ("auto", "gpu", "cpu")
  • devices: Number of devices or a list of device indices ("auto", 1, [0, 1])

Output Files

After training, you'll find:

  1. Models Directory (models/):

    • your_model_name.onnx: ONNX model for inference
    • preprocessor.joblib: Fitted preprocessor for data transformation
  2. Checkpoints (lightning_logs/MyProjectTraining/version_X/checkpoints/, where MyProjectTraining is the logger name from the config):

    • epoch=X-val_loss=Y.best.ckpt: Best model checkpoint
    • epoch-X.last.ckpt: Last epoch checkpoint
  3. Training Logs (lightning_logs/):

    • TensorBoard logs for visualization
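
Because each run writes to a new version_X directory, scripts that reuse a checkpoint usually need to locate the latest one. The following is a minimal sketch based only on the directory layout listed above. find_best_checkpoint is a hypothetical helper, not a function provided by the package.

```python
from pathlib import Path


def find_best_checkpoint(log_dir, run_name):
    """Return the '.best.ckpt' file in the highest-numbered version_X directory, or None."""
    run_dir = Path(log_dir) / run_name
    versions = [
        p for p in run_dir.glob("version_*")
        if p.is_dir() and p.name.split("_", 1)[1].isdigit()
    ]
    if not versions:
        return None
    # Compare version numbers numerically so that version_10 sorts after version_2
    latest = max(versions, key=lambda p: int(p.name.split("_", 1)[1]))
    candidates = sorted((latest / "checkpoints").glob("*.best.ckpt"))
    return candidates[-1] if candidates else None
```

For the example configuration this would be called as find_best_checkpoint("lightning_logs", "MyProjectTraining").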

Model Inference

After training, use the exported ONNX model and preprocessor:

import joblib
import onnxruntime as ort
import numpy as np
import pandas as pd

# Load preprocessor
preprocessor = joblib.load("models/preprocessor.joblib")

# Load ONNX model
session = ort.InferenceSession("models/your_model_name.onnx")

# Prepare input data
input_data = pd.DataFrame({
    'categorical_col': ['value1'],
    'numeric_col': [123.45],
    # ... other features
})

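# Transform data (use the same feature columns, in the same order, as in training)
feature_cols = ['categorical_col', 'numeric_col']  # Your feature columns
transformed = preprocessor.transform(input_data[feature_cols])
if hasattr(transformed, "toarray"):
    # If the preprocessor returns a sparse matrix, convert it to a dense array
    transformed = transformed.toarray()

# Predict
input_name = session.get_inputs()[0].name
output = session.run(None, {input_name: np.asarray(transformed, dtype=np.float32)})
prediction = output[0][0][0]  # first output tensor, first row, single value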

print(f"Prediction: {prediction}")

Viewing Training Progress

TensorBoard

tensorboard --logdir lightning_logs

Then open http://localhost:6006 in your browser.

Example Projects

Gemstone Price Prediction

See the GemstonePricePrediction directory for a complete example.

Requirements

  • Python >= 3.8
  • PyTorch >= 2.0.0
  • PyTorch Lightning >= 2.1.0

See requirements.txt for the complete list of dependencies.

License

MIT License - see LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Support

For issues or questions:

  1. Check the configuration file syntax
  2. Verify CSV file format and column names
  3. Check TensorBoard logs for training insights
  4. Open an issue on GitHub
