
End-to-end interpretable binary-classification pipeline


crystal-ml

An end-to-end interpretable binary‐classification pipeline.
crystal-ml provides configurable data ingestion, model training (SVM, Balanced Random Forest, XGBoost, AutoGluon),
an SVM-based downsampling algorithm, supervised discretization (FCCA), and optimal decision-tree induction (GOSDT).


🚀 Features

  • Data ingestion from CSV/XLSX, either splitting a single dataset into train/test sets or using pre-split datasets
  • Balanced Random Forest, SVM, XGBoost, and AutoGluon model training with hyperparameter search
  • SVM-based undersampling: identifies “free” support vectors to downsample the training set, then validates the reduced set by retraining the base models
  • FCCA supervised discretization of continuous features
  • GOSDT (Generalized and Scalable Optimal Sparse Decision Trees) for interpretable optimal decision trees
  • Fully YAML-driven configuration through a single config.yaml

🔗 Official Documentation


🛠️ Prerequisites

  • Python 3.7 to 3.12 (3.10 recommended)
  • git, pip, and optionally conda
  • An active Gurobi license, required to run the FCCA discretization procedure

📦 Installation

From PyPI

# (Optional) Create & activate a fresh conda env with Python 3.10
conda create -n crystal_ml python=3.10 -y
conda activate crystal_ml

# Install
pip install crystal_ml_pipeline

From source

git clone https://github.com/yourusername/crystal-ml.git
cd crystal-ml
pip install .

🎯 Quickstart

1. Create a script, e.g. run.py:

from crystal_ml.pipeline import run_pipeline

if __name__ == "__main__":
    run_pipeline("config.yaml")

2. Prepare config.yaml and place your data files (either a complete dataset or pre-split train/test files) alongside it.

All pipeline options live in a single config.yaml at your project root. Copy the template from the repository and adjust its sections as needed (see "Configuration of Pipeline’s Parameters" for details).

3. Execute:

python run.py

(Alternatively, run run.py from your favourite IDE.)

4. Inspect the logs/ folder for:

  • Excel reports (*_Performance.xlsx, *_Results.xlsx)
  • Pickled objects (.pkl)
  • PNG charts (*.png)
  • Optimal decision-tree diagrams
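To get an overview of what a run produced, a small standard-library sketch can group the files under logs/ by the patterns listed above. It assumes outputs may sit in per-model subfolders (hence the recursive search) and that tree diagrams are saved as PNG; `summarize_outputs` and `load_pickle` are hypothetical helpers, not part of crystal-ml.

```python
import pickle
from pathlib import Path
from typing import Any, Dict, List, Union

# Glob patterns for the output types listed above
OUTPUT_PATTERNS = {
    "performance_reports": "*_Performance.xlsx",
    "results_reports": "*_Results.xlsx",
    "pickles": "*.pkl",
    "charts": "*.png",  # assumes tree diagrams are PNG files too
}


def summarize_outputs(log_dir: Union[str, Path] = "logs") -> Dict[str, List[str]]:
    """Group files under log_dir, searched recursively, by output type."""
    root = Path(log_dir)
    return {
        kind: sorted(p.relative_to(root).as_posix() for p in root.rglob(pattern))
        for kind, pattern in OUTPUT_PATTERNS.items()
    }


def load_pickle(path: Union[str, Path]) -> Any:
    """Load a pickled object; only do this for files you produced yourself."""
    with open(path, "rb") as fh:
        return pickle.load(fh)  # unpickling untrusted data can execute arbitrary code
```

For example, `for kind, files in summarize_outputs().items(): print(kind, len(files))` prints how many files of each type the run wrote.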

Configuration of Pipeline’s Parameters

All pipeline parameters are configured through a single YAML file, config.yaml, whose sections mirror the stages of the pipeline. Not every parameter is detailed here: many of them, particularly those of the base models and external algorithms, are already thoroughly described in their official documentation.
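Each stage section carries an enable flag (`enable` for Data_Ingestion, `enabled` elsewhere). As a pre-flight check before a long run, a minimal sketch like the following lists which stages are switched on. It assumes only that config.yaml parses to nested mappings; `enabled_stages` and `load_config` are hypothetical helpers, and loading the file requires PyYAML.

```python
from typing import Any, Dict


def enabled_stages(config: Dict[str, Any], prefix: str = "") -> Dict[str, bool]:
    """Walk a parsed config and report every section that has an enable flag.

    Both spellings used in this README (`enable` and `enabled`) are accepted.
    """
    flags: Dict[str, bool] = {}
    for name, section in config.items():
        if not isinstance(section, dict):
            continue  # skip scalar options such as target_column
        path = f"{prefix}{name}"
        for key in ("enable", "enabled"):
            if key in section:
                flags[path] = bool(section[key])
        # Recurse so nested groups (e.g. one block per base model) are found too
        flags.update(enabled_stages(section, prefix=f"{path}."))
    return flags


def load_config(path: str = "config.yaml") -> Dict[str, Any]:
    """Parse config.yaml; requires PyYAML (pip install pyyaml)."""
    import yaml  # imported lazily so enabled_stages works without PyYAML

    with open(path, encoding="utf-8") as fh:
        return yaml.safe_load(fh)
```

Calling `enabled_stages(load_config())` returns a mapping from section path to flag; nested groups appear as dotted paths (for a hypothetical layout, something like `Base_Models.BRF`).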

Below is a concise overview of the main configuration options, following the structure of the YAML file:

Starting Dataset (Data_Ingestion)

  • enable: enables or disables this phase. Must be enabled if pre-processed data (already discretized for GOSDT) is not provided.
  • input data paths: file paths to either the complete dataset or pre-split training and testing datasets.
  • target_column: name of the binary target variable to predict (e.g., y720).
  • train/test split params: parameters used for splitting the dataset into training and testing subsets (see the official scikit-learn docs for details).
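For a single (not pre-split) dataset, the split parameters map onto scikit-learn's `train_test_split`. The sketch below shows a stratified split, which keeps the class ratio of the binary target equal in both subsets up to rounding; whether crystal-ml stratifies by default is not stated here, so treat `split_dataset`, its defaults, and the placeholder file name as assumptions.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def split_dataset(df: pd.DataFrame, target_column: str,
                  test_size: float = 0.2, random_state: int = 42):
    """Return X_train, X_test, y_train, y_test with the class ratio preserved."""
    X = df.drop(columns=[target_column])
    y = df[target_column]
    # stratify=y keeps the positive/negative ratio equal in both subsets,
    # which matters for imbalanced binary targets
    return train_test_split(X, y, test_size=test_size, stratify=y,
                            random_state=random_state)


if __name__ == "__main__":
    # "dataset.csv" is a placeholder; use pd.read_excel for XLSX input
    data = pd.read_csv("dataset.csv")
    X_train, X_test, y_train, y_test = split_dataset(data, target_column="y720")
```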

Base Models

This section includes the four base models (BRF, XGBoost, SVM, AutoGluon), each configurable through:

  • enabled: enables or disables the execution of the specific model.
  • output_dir: directory where the model’s performance metrics and results are saved.
  • search params: parameters used in hyperparameter optimization via cross-validation (BRF, XGB, SVM), or more generally for selecting the optimal model configuration (see the official docs).
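As an illustration of cross-validated hyperparameter search for one base model, the sketch below tunes an RBF SVM with scikit-learn's `GridSearchCV`, scored with F2 (one of the metrics named in this README). The grid values, `class_weight="balanced"`, and the `tune_svm` helper are assumptions; the pipeline's actual search space comes from the search params in config.yaml.

```python
from sklearn.metrics import fbeta_score, make_scorer
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.svm import SVC

# Hypothetical search space; in crystal-ml it comes from the model's search params
PARAM_GRID = {"C": [0.1, 1.0, 10.0], "gamma": ["scale", 0.1]}


def tune_svm(X, y, param_grid=None, n_splits=5, seed=0):
    """Grid-search an RBF SVM with stratified k-fold CV, scored by F2."""
    search = GridSearchCV(
        SVC(kernel="rbf", class_weight="balanced"),
        param_grid=param_grid or PARAM_GRID,
        # F2 weights recall more heavily than precision
        scoring=make_scorer(fbeta_score, beta=2),
        cv=StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed),
    )
    return search.fit(X, y)  # fit returns the fitted GridSearchCV


if __name__ == "__main__":
    from sklearn.datasets import make_classification

    # Synthetic, imbalanced toy data (about 80% negatives)
    X, y = make_classification(n_samples=300, weights=[0.8], random_state=0)
    search = tune_svm(X, y)
    print(search.best_params_, round(search.best_score_, 3))
```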

SVM-based Undersampling Algorithm

This section contains the parameters to configure the SVM-based downsampling procedure, aimed at reducing the size of the training dataset:

  • SVM_Downsampling

    • enabled: enables or disables the downsampling algorithm.
    • output_dir: directory for results, including the undersampled dataset (saved with pickle).
    • CV search params: parameters for SVM hyperparameter search (see official scikit-learn docs).
    • n_free_models: number of SVM models used to select support vectors (lower values yield smaller datasets).
    • save_output / load_saved_output: whether to save/load undersampled datasets (using pickle), preventing repeated downsampling runs.
    • percentage_performance_drop_threshold: percentage drop in model performance above which the user is warned.
    • percentage_performance_drop_metric: metric used to measure that drop (Accuracy, Recall, Precision, f1, or f2). BRF serves as the reference model: its training metrics before and after downsampling are compared.
  • Undersampling Performance Assessment (BRF, XGB, AutoGluon)
    Parameters analogous to those in the Base Models section, used to assess the undersampled dataset by retraining the base models (excluding SVM) on it and measuring their performance.
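In SVM terms, a "free" support vector has a dual coefficient strictly between 0 and C, so it lies exactly on the margin. The sketch below shows one way to find such points with scikit-learn's `SVC`, whose `dual_coef_` stores y_i * alpha_i for each support vector. Taking the union over several SVMs (one per C value) is a hypothetical stand-in for `n_free_models`, chosen to match the note that fewer models yield a smaller dataset; the relative-drop formula behind the warning is likewise an assumption, not necessarily crystal-ml's exact rule.

```python
import warnings

import numpy as np
from sklearn.svm import SVC


def free_support_vectors(X, y, C=1.0, tol=1e-8, **svc_kwargs):
    """Indices of training points with 0 < alpha_i < C (on-margin support vectors).

    Avoid class_weight here: it rescales the upper bound C per class.
    """
    model = SVC(C=C, **svc_kwargs).fit(X, y)
    # For a binary SVC, dual_coef_ has shape (1, n_SV) and holds y_i * alpha_i
    alphas = np.abs(model.dual_coef_[0])
    # Bounded support vectors sit at alpha_i == C; keep only the free ones
    return model.support_[alphas < C - tol]


def downsample(X, y, C_values=(0.1, 1.0, 10.0), **svc_kwargs):
    """Keep the union of free support vectors from one SVM per C value (hypothetical scheme)."""
    keep = np.unique(np.concatenate(
        [free_support_vectors(X, y, C=c, **svc_kwargs) for c in C_values]
    ))
    return X[keep], y[keep], keep


def check_performance_drop(before, after, threshold_pct, metric="f2"):
    """Return the relative drop in percent and warn if it exceeds threshold_pct."""
    drop_pct = 100.0 * (before - after) / before if before else 0.0
    if drop_pct > threshold_pct:
        warnings.warn(
            f"{metric} dropped by {drop_pct:.1f}% after downsampling "
            f"(threshold: {threshold_pct}%)"
        )
    return drop_pct
```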

Data Discretization

This section handles the discretization of continuous features, required for GOSDT:

  • BRF_FCCA
    Parameters (same structure as BRF in Base Models) configuring two sets of Balanced Random Forest models: the one FCCA uses to identify discretization thresholds, and those trained at each FCCA iteration to evaluate predictive performance on the dataset discretized with each parameter combination. BRF results are saved into subfolders named after their parameter settings.
  • FCCA
    • enabled: enables or disables the discretization step.
    • output_dir: directory where FCCA saves its results, with one subfolder per tested configuration (named after the parameter combination) containing the discretized datasets.
    • Additional FCCA-specific parameters (e.g., lambda0_values, p0_values, tao_q_values), detailed in the official FCCA documentation and paper.
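Each FCCA configuration is one point in the Cartesian product of the `*_values` lists. A minimal sketch of that sweep, assuming the lists are combined exhaustively; the values in `FCCA_GRID`, the default output folder, and the `key=value` subfolder naming are illustrative, not crystal-ml's actual scheme.

```python
from itertools import product
from pathlib import Path

# Hypothetical values for the FCCA grid parameters named above
FCCA_GRID = {
    "lambda0_values": [0.1, 1.0],
    "p0_values": [0.5, 0.7],
    "tao_q_values": [0.75, 0.9],
}


def fcca_runs(grid, output_dir="FCCA_output"):
    """Yield (params, folder) for every combination in the grid."""
    # Strip the "_values" suffix to get one parameter name per list
    names = [n[: -len("_values")] if n.endswith("_values") else n for n in grid]
    for values in product(*grid.values()):
        params = dict(zip(names, values))
        # Hypothetical naming scheme: one subfolder per parameter combination
        folder = Path(output_dir) / "_".join(f"{k}={v}" for k, v in params.items())
        yield params, folder
```

With the grid above, `list(fcca_runs(FCCA_GRID))` contains 2 × 2 × 2 = 8 runs, each with its own folder.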

This stage also produces two visual plots to help users select the optimal trade-off between data compression and information loss:

  • Compression Rate vs. Inconsistency Rate across all parameter combinations
  • Balanced RF performance on each discretized dataset
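The two axes of the first plot can be computed directly from a discretized dataset. The sketch below maps each value to the index of its interval between thresholds, then uses the classic definitions from consistency-based feature selection: compression rate as 1 minus the ratio of distinct discretized rows to total rows, and inconsistency rate as the fraction of samples that disagree with the majority label of their discretized cell. FCCA may define these quantities differently, so treat the formulas and helper names as assumptions.

```python
from bisect import bisect_right
from collections import Counter, defaultdict
from typing import Dict, Hashable, List, Sequence, Tuple

Code = Tuple[int, ...]


def discretize(rows: Sequence[Sequence[float]],
               thresholds: Sequence[Sequence[float]]) -> List[Code]:
    """Replace each value by the index of its interval (thresholds sorted per feature)."""
    return [tuple(bisect_right(thresholds[j], v) for j, v in enumerate(row))
            for row in rows]


def compression_rate(codes: Sequence[Code]) -> float:
    """Fraction of rows removed by merging identical discretized rows."""
    return 1.0 - len(set(codes)) / len(codes)


def inconsistency_rate(codes: Sequence[Code], labels: Sequence[Hashable]) -> float:
    """Fraction of samples whose label differs from the majority label of their cell."""
    by_cell: Dict[Code, Counter] = defaultdict(Counter)
    for code, label in zip(codes, labels):
        by_cell[code][label] += 1
    errors = sum(sum(c.values()) - max(c.values()) for c in by_cell.values())
    return errors / len(labels)
```

A higher compression rate means a smaller, simpler dataset for GOSDT, while a higher inconsistency rate means more information about the target was lost.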

Interpretable Models

This final stage generates interpretable optimal decision trees using GOSDT on the FCCA-discretized data:

  • enabled: enables or disables this step.
  • input_dir: path to the directory containing the FCCA output files (x_train_discr.xlsx, y_train_discr.xlsx, x_test_discr.xlsx, y_test_discr.xlsx).
  • output_dir: directory where GOSDT saves model performance metrics and the optimal tree plot.
  • Additional GOSDT-specific parameters are described in the official GOSDT documentation.
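Since this stage reads the four FCCA output files from input_dir, a quick check before enabling it can save a failed run. A minimal sketch; the helper names are hypothetical, and the optional loader assumes pandas with an Excel engine such as openpyxl is installed.

```python
from pathlib import Path
from typing import Any, Dict, List, Union

# File names expected in input_dir, as listed above
REQUIRED_FILES = (
    "x_train_discr.xlsx",
    "y_train_discr.xlsx",
    "x_test_discr.xlsx",
    "y_test_discr.xlsx",
)


def missing_gosdt_inputs(input_dir: Union[str, Path]) -> List[str]:
    """Return the expected FCCA output files that are absent from input_dir."""
    root = Path(input_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


def load_gosdt_inputs(input_dir: Union[str, Path]) -> Dict[str, Any]:
    """Load the four files into DataFrames keyed by file stem (needs pandas + openpyxl)."""
    import pandas as pd  # imported lazily so the check above has no dependencies

    missing = missing_gosdt_inputs(input_dir)
    if missing:
        raise FileNotFoundError(f"Missing in {input_dir}: {', '.join(missing)}")
    root = Path(input_dir)
    return {Path(name).stem: pd.read_excel(root / name) for name in REQUIRED_FILES}
```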

📄 License

crystal_ml_pipeline is released under the MIT License. See LICENSE for details.

Built with ❤️ by Raffaele Mariosa. PyPI: https://pypi.org/project/crystal-ml-pipeline/

For bug reports or feature suggestions, feel free to drop me a line at raffaele.mariosa@uniroma1.it.


Download files


Source Distribution

crystal_ml_pipeline-0.3.tar.gz (24.6 kB)


Built Distribution


crystal_ml_pipeline-0.3-py3-none-any.whl (22.2 kB)


File details

Details for the file crystal_ml_pipeline-0.3.tar.gz.

File metadata

  • Download URL: crystal_ml_pipeline-0.3.tar.gz
  • Size: 24.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.16

File hashes

Hashes for crystal_ml_pipeline-0.3.tar.gz
Algorithm Hash digest
SHA256 057917ec3198fa7d9b06592c6018c6a02fa9d250712a81bec63dcb5a3a3895c0
MD5 5253bb62bcf34233d73db371a38e89b5
BLAKE2b-256 3c6e2554b7cf656383e52ce7141f3098cfe777b18cff80276575a4d23510c1e1


File details

Details for the file crystal_ml_pipeline-0.3-py3-none-any.whl.

File hashes

Hashes for crystal_ml_pipeline-0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 6b3db6df152b9c6b2cbfe25a44fa69b212799ec922c00780e3d6f5940207e6d3
MD5 2f378eb6034100886b158c8a789d5cb6
BLAKE2b-256 4c295a248ee6794341102b6a7aecbc6a94976a77b53aa0097c0579f9d12e2e31

