A comprehensive assertion-and-validation toolkit for ML workflows.
ml-assert
A lightweight, chainable assertion toolkit for validating data and models in ML workflows.
ml-assert is a Python library that provides a fluent, expressive API to act as a guardrail in your automated ML pipelines. It doesn't just calculate metrics; it asserts that your data and models meet specific, mission-critical criteria. If an assertion fails, it fails loudly and immediately, stopping the pipeline to prevent bad models or corrupt data from moving downstream.
This is crucial for building robust, production-ready ML systems where data quality, model performance, and artifact integrity are non-negotiable.
Core Features
- DataFrame Assertions: Validate pandas DataFrame properties like schema, null values, column uniqueness, value ranges, and set membership.
- Statistical Drift Detection: Use low-level statistical tests (Kolmogorov-Smirnov, Chi-Squared, Wasserstein) or a high-level assert_no_drift function to detect changes between datasets.
- Model Performance Assertions: Chain assertions for key classification metrics (Accuracy, Precision, Recall, F1, ROC AUC) to ensure your model meets performance targets.
- Extensible Plugin System: Leverage built-in plugins (file_exists, dvc_check) or create your own to add custom checks.
- Declarative CLI: Define your assertion suite in a single config.yaml and run it from the command line, generating JSON and HTML reports.
Installation
pip install ml-assert
How It Works: Assertion vs. Calculation
A typical metrics library might calculate an accuracy of 75% and let the pipeline continue. ml-assert asserts that accuracy must be >= 80%. If it's 75%, it raises an AssertionError, halting execution.
This shift from passive calculation to active assertion is what makes ml-assert useful for MLOps.
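The difference can be illustrated without ml-assert at all. The following is a minimal sketch in plain Python; the helper names accuracy and assert_min_accuracy are made up for this illustration and are not part of the ml-assert API.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the labels."""
    matches = sum(t == p for t, p in zip(y_true, y_pred))
    return matches / len(y_true)


def assert_min_accuracy(y_true, y_pred, min_score):
    """Hypothetical helper: raise AssertionError when accuracy is below min_score."""
    score = accuracy(y_true, y_pred)
    if score < min_score:
        raise AssertionError(
            f"accuracy {score:.2f} is below the required minimum {min_score:.2f}"
        )
    return score


y_true = [1, 0, 1, 1, 0, 1, 0, 0]
y_pred = [1, 0, 1, 0, 0, 1, 1, 0]  # 6 of 8 predictions are correct

# Passive calculation: the number is reported and the pipeline carries on.
print(f"accuracy = {accuracy(y_true, y_pred):.2f}")

# Active assertion: the same number now stops the pipeline.
try:
    assert_min_accuracy(y_true, y_pred, min_score=0.80)
except AssertionError as exc:
    print(f"pipeline halted: {exc}")
```

In a real pipeline the AssertionError would be left uncaught so that the run exits with a non-zero status.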
Usage Examples
1. DataFrameAssertion DSL
Chain assertions to validate a pandas DataFrame. The chain stops at the first failure.
import pandas as pd
import numpy as np
from ml_assert import Assertion, schema
# DataFrame with a column full of nulls and an out-of-range value
data = {
    'user_id': list(range(100, 110)),
    'age': [25, 30, 99, 45, 30, 50, 60, 22, 33, 41],  # 99 is out of range
    'plan_type': ['basic', 'premium', 'basic', 'premium', 'premium', 'basic', 'free', 'free', 'premium', 'basic'],
    'empty_col': [np.nan] * 10
}
df = pd.DataFrame(data)
# This check will FAIL because `age` has a value > 70
try:
    s = schema()
    s.col("user_id").is_unique()
    s.col("age").in_range(18, 70)
    s.col("plan_type").is_type("object")
    Assertion(df).satisfies(s).no_nulls().validate()
except AssertionError as e:
    print(f"As expected, validation failed: {e}")
# This check will PASS because we only check specific columns
s2 = schema()
s2.col("user_id").is_unique()
Assertion(df).satisfies(s2).no_nulls(['user_id', 'age', 'plan_type']).validate()
print("Partial validation passed!")
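For readers curious how a chain that stops at the first failure can be built, here is a minimal sketch of the pattern using only pandas. The class FrameChecks and its methods are hypothetical stand-ins written for this illustration; they are not ml-assert's implementation, and the real library may behave differently in detail.

```python
import pandas as pd


class FrameChecks:
    """Hypothetical stand-in for a chainable DataFrame assertion object."""

    def __init__(self, df):
        self._df = df
        self._checks = []  # (description, callable) pairs, run in order

    def unique(self, column):
        self._checks.append(
            (f"{column} is unique", lambda: self._df[column].is_unique)
        )
        return self  # returning self is what makes chaining possible

    def in_range(self, column, low, high):
        self._checks.append(
            (
                f"{column} in [{low}, {high}]",
                lambda: bool(self._df[column].between(low, high).all()),
            )
        )
        return self

    def no_nulls(self, columns=None):
        cols = list(columns) if columns is not None else list(self._df.columns)
        self._checks.append(
            (f"no nulls in {cols}", lambda: not self._df[cols].isna().any().any())
        )
        return self

    def validate(self):
        # Run the registered checks in order and stop at the first failure.
        for description, check in self._checks:
            if not check():
                raise AssertionError(f"check failed: {description}")
        return self


df_small = pd.DataFrame({"user_id": [1, 2, 3], "age": [25, 99, 40]})

# Passes: ids are unique and there are no nulls.
FrameChecks(df_small).unique("user_id").no_nulls().validate()

# Fails on the range check, so the later no_nulls check never runs.
try:
    FrameChecks(df_small).unique("user_id").in_range("age", 18, 70).no_nulls().validate()
except AssertionError as exc:
    print(exc)
```

Nothing is evaluated until validate() is called, which is why the chain in this sketch can be assembled in one expression and still report exactly one failing check.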
2. High-Level Drift Detection
Detect distributional drift between a reference (training) and current (inference) dataset. assert_no_drift intelligently applies KS tests to numeric columns and Chi-Squared tests to categorical columns.
import pandas as pd
import numpy as np
from ml_assert.stats.drift import assert_no_drift
# Reference dataset
df_ref = pd.DataFrame({
    'temperature': np.random.normal(20, 5, 500),
    'city': np.random.choice(['NY', 'LA', 'SF'], 500, p=[0.5, 0.3, 0.2])
})
# Current dataset with a deliberate drift
df_cur = pd.DataFrame({
    'temperature': np.random.normal(30, 5, 500),  # Mean shifted by +10
    'city': np.random.choice(['NY', 'LA', 'SF'], 500, p=[0.2, 0.3, 0.5])  # Proportions changed
})
# This will FAIL and identify the drifting column ('temperature').
try:
    assert_no_drift(df_ref, df_cur, alpha=0.05)
except AssertionError as e:
    print(f"As expected, drift was detected: {e}")
# This will PASS because the data is identical.
assert_no_drift(df_ref, df_ref.copy(), alpha=0.05)
print("No drift detected in identical datasets.")
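The per-column dispatch described above (a KS test for numeric columns, a Chi-Squared test for categorical ones) can be approximated directly with SciPy. This is a sketch under assumptions: the helper drifting_columns is hypothetical, and ml-assert's own assert_no_drift may choose tests, handle missing values, or report failures differently.

```python
import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency, ks_2samp


def drifting_columns(df_ref, df_cur, alpha=0.05):
    """Hypothetical helper: list the columns whose test p-value falls below alpha."""
    drifted = []
    for col in df_ref.columns:
        ref, cur = df_ref[col], df_cur[col]
        if pd.api.types.is_numeric_dtype(ref):
            # Numeric column: two-sample Kolmogorov-Smirnov test.
            p_value = ks_2samp(ref, cur).pvalue
        else:
            # Categorical column: chi-squared test on a 2 x k table of counts.
            categories = sorted(set(ref) | set(cur))
            table = [
                [int((ref == c).sum()) for c in categories],
                [int((cur == c).sum()) for c in categories],
            ]
            _, p_value, _, _ = chi2_contingency(table)
        if p_value < alpha:
            drifted.append(col)
    return drifted


rng = np.random.default_rng(0)  # seeded so the sketch is reproducible
cities = ["NY", "LA", "SF"]
df_a = pd.DataFrame({
    "temperature": rng.normal(20, 5, 500),
    "city": rng.choice(cities, 500, p=[0.5, 0.3, 0.2]),
})
df_b = pd.DataFrame({
    "temperature": rng.normal(30, 5, 500),
    "city": rng.choice(cities, 500, p=[0.2, 0.3, 0.5]),
})

print(drifting_columns(df_a, df_b))         # both columns were shifted on purpose
print(drifting_columns(df_a, df_a.copy()))  # identical data, nothing to report
```

Turning this into an assertion is then a matter of raising AssertionError whenever the returned list is non-empty.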
3. Model Performance Assertions
Ensure your model's predictions meet your minimum quality bar.
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from ml_assert import assert_model
# Generate data and train a simple model
X, y = make_classification(random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
model = LogisticRegression().fit(X_train, y_train)
y_pred = model.predict(X_test)
y_scores = model.predict_proba(X_test)[:, 1]
# Chain assertions for key metrics
# This will PASS if all metrics meet their thresholds.
assert_model(y_test, y_pred, y_scores) \
    .accuracy(min_score=0.80) \
    .precision(min_score=0.80) \
    .recall(min_score=0.80) \
    .f1(min_score=0.80) \
    .roc_auc(min_score=0.90) \
    .validate()
print("All model performance metrics passed!")
4. CLI for Automated Runs
Define a suite of checks in a YAML file and execute it with the ml_assert CLI. This is perfect for CI/CD pipelines.
config.yaml
steps:
  - type: drift
    train: 'ref.csv'
    test: 'cur.csv'
    alpha: 0.05
    # The CLI run will fail on this step due to drift
  - type: model_performance
    y_true: 'y_true.csv'
    y_pred: 'y_pred.csv'
    y_scores: 'y_scores.csv'
    assertions:
      accuracy: 0.75
      roc_auc: 0.80
  - type: file_exists
    path: 'my_model.pkl'
  - type: dvc_check
    path: 'model_data.csv'
Run from your terminal:
# poetry run ml_assert run config.yaml
# The command will fail because of the drift, and generate reports.
ml_assert run config.yaml
This command generates two reports:
- config.report.json: A machine-readable summary.
- config.report.html: A human-friendly HTML report.
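To make the declarative idea concrete, the sketch below shows one way a list of steps could be dispatched by their type field and turned into a machine-readable summary. Everything here is an assumption made for illustration: the HANDLERS registry, the run_steps function, and the shape of the result records are hypothetical and do not describe ml-assert's CLI internals or its actual report schema.

```python
import json
import tempfile
from pathlib import Path


def check_file_exists(step):
    """Hypothetical handler for a step of type 'file_exists'."""
    path = Path(step["path"])
    if not path.is_file():
        raise AssertionError(f"file not found: {path.name}")


# Hypothetical registry mapping a step's 'type' to the function that runs it.
HANDLERS = {"file_exists": check_file_exists}


def run_steps(steps):
    """Run every step and collect a pass/fail record for each one."""
    results = []
    for step in steps:
        record = {"type": step["type"], "passed": True, "message": ""}
        try:
            HANDLERS[step["type"]](step)
        except AssertionError as exc:
            record["passed"] = False
            record["message"] = str(exc)
        results.append(record)
    return results


with tempfile.TemporaryDirectory() as tmp:
    present = Path(tmp) / "my_model.pkl"
    present.write_bytes(b"")  # an empty placeholder file
    steps = [
        {"type": "file_exists", "path": str(present)},
        {"type": "file_exists", "path": str(Path(tmp) / "missing.pkl")},
    ]
    results = run_steps(steps)

print(json.dumps(results, indent=2))
```

This sketch keeps running after a failed step so that the summary covers every step; whether the real CLI does the same or stops at the first failure is not stated in this README.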
Cross-Validation Support
A Python library for asserting machine learning model performance using cross-validation.
Features
- Cross-validation support for model evaluation
- Multiple cross-validation strategies:
  - K-Fold Cross-Validation
  - Stratified K-Fold Cross-Validation
  - Leave-One-Out Cross-Validation
- Support for various metrics:
  - Accuracy
  - Precision
  - Recall
  - F1 Score
  - ROC AUC Score
- Parallel processing for faster cross-validation
- Comprehensive test suite
Installation
pip install -r requirements.txt
Usage
Basic Usage
from sklearn.linear_model import LogisticRegression
from ml_assert.model.cross_validation import assert_cv_accuracy_score
# Create your model
model = LogisticRegression()
X, y = your_data  # Placeholder: substitute your own feature matrix and target vector
# Assert minimum accuracy across cross-validation folds
assert_cv_accuracy_score(model, X, y, min_score=0.85)
Different Cross-Validation Strategies
# K-Fold Cross-Validation (default)
assert_cv_accuracy_score(model, X, y, min_score=0.85, cv_type='kfold', n_splits=5)
# Stratified K-Fold Cross-Validation
assert_cv_accuracy_score(model, X, y, min_score=0.85, cv_type='stratified', n_splits=5)
# Leave-One-Out Cross-Validation
assert_cv_accuracy_score(model, X, y, min_score=0.85, cv_type='loo')
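The three strategies correspond to standard scikit-learn splitters. The sketch below shows how the cv_type values used above could map onto them; the helpers make_splitter and assert_mean_cv_accuracy are hypothetical, and the choices made here (comparing the mean score across folds to min_score, shuffling with a fixed seed) are assumptions rather than documented ml-assert behavior.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import (
    KFold,
    LeaveOneOut,
    StratifiedKFold,
    cross_val_score,
)


def make_splitter(cv_type="kfold", n_splits=5):
    """Hypothetical mapping from a cv_type string to a scikit-learn splitter."""
    if cv_type == "kfold":
        return KFold(n_splits=n_splits, shuffle=True, random_state=0)
    if cv_type == "stratified":
        # Keeps the class proportions roughly equal in every fold.
        return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
    if cv_type == "loo":
        # One fold per sample; n_splits does not apply.
        return LeaveOneOut()
    raise ValueError(f"unknown cv_type: {cv_type!r}")


def assert_mean_cv_accuracy(model, X, y, min_score, cv_type="kfold", n_splits=5):
    """Hypothetical helper: assert on the mean accuracy across folds."""
    scores = cross_val_score(
        model, X, y, cv=make_splitter(cv_type, n_splits), scoring="accuracy"
    )
    mean_score = float(scores.mean())
    if mean_score < min_score:
        raise AssertionError(
            f"mean CV accuracy {mean_score:.3f} is below {min_score:.3f} ({cv_type})"
        )
    return mean_score


X_demo, y_demo = make_classification(n_samples=60, random_state=0)
demo_model = LogisticRegression(max_iter=1000)

for cv_type in ("kfold", "stratified", "loo"):
    try:
        score = assert_mean_cv_accuracy(demo_model, X_demo, y_demo, 0.85, cv_type=cv_type)
        print(f"{cv_type}: passed with mean accuracy {score:.3f}")
    except AssertionError as exc:
        print(f"{cv_type}: {exc}")
```

Leave-one-out fits the model once per sample, so it is usually practical only for small datasets.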
Multiple Metrics
from ml_assert.model.cross_validation import (
assert_cv_accuracy_score,
assert_cv_precision_score,
assert_cv_recall_score,
assert_cv_f1_score,
assert_cv_roc_auc_score,
)
# Assert multiple metrics
assert_cv_accuracy_score(model, X, y, min_score=0.85)
assert_cv_precision_score(model, X, y, min_score=0.80)
assert_cv_recall_score(model, X, y, min_score=0.80)
assert_cv_f1_score(model, X, y, min_score=0.80)
assert_cv_roc_auc_score(model, X, y, min_score=0.80)
Get Cross-Validation Summary
from ml_assert.model.cross_validation import get_cv_summary
# Get summary of all metrics
summary = get_cv_summary(model, X, y)
print(summary)
Running Tests
pytest src/ml_assert/tests/
Contributing
Contributions are welcome! Please see CONTRIBUTING.md for details on how to get started.
License
This project is licensed under the Apache 2.0 License - see the LICENSE file for details.