StratifiedGroupKFoldRequiresGroups
A small wrapper around scikit-learn's StratifiedGroupKFold that makes the
groups argument mandatory when calling split().
What It Is
StratifiedGroupKFoldRequiresGroups is a subclass of
sklearn.model_selection.StratifiedGroupKFold. It keeps the underlying
cross-validation behavior from scikit-learn, but adds a guardrail: callers must
provide a non-None groups argument to split().
This is useful when group-aware splitting is required for a model evaluation to
be valid. If a pipeline, estimator, or helper forgets to pass groups, the split
should fail immediately instead of silently behaving like a regular stratified
split without group isolation.
Installation
```shell
pip install StratifiedGroupKFoldRequiresGroups
```
The package requires Python 3.8 or newer and depends on scikit-learn.
Usage
```python
import numpy as np
from StratifiedGroupKFoldRequiresGroups import StratifiedGroupKFoldRequiresGroups

X = np.random.randn(9, 5)
y = np.array(["class1", "class2", "class3"] * 3)
groups = np.array([
    "group1",
    "group2",
    "group3",
    "group4",
    "group5",
    "group6",
    "group7",
    "group7",
    "group7",
])

cv = StratifiedGroupKFoldRequiresGroups(n_splits=3)
for train_index, test_index in cv.split(X, y, groups):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
```
The constructor is inherited from scikit-learn's StratifiedGroupKFold, so it
accepts the same options, such as n_splits, shuffle, and random_state.
Important Behavior
- cv.split(X, y, groups) delegates to sklearn.model_selection.StratifiedGroupKFold.split().
- cv.split(X, y) raises a TypeError because groups is a required positional argument.
- cv.split(X, y, groups=None) raises a ValueError.
- The wrapper does not change scikit-learn's splitting algorithm; it only enforces that group labels are supplied.
Development
```shell
pip install -r requirements_dev.txt
pip install -e .
pytest
```
The test suite checks that the class is a StratifiedGroupKFold subclass, that
it produces the expected number of splits, and that missing or None groups
fail as intended.
Changelog
0.0.1
- First release on PyPI.
File details
Details for the file stratifiedgroupkfoldrequiresgroups-0.0.2.tar.gz.
File metadata
- Download URL: stratifiedgroupkfoldrequiresgroups-0.0.2.tar.gz
- Upload date:
- Size: 5.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f6e213d5ca37649fa000aca36136d7f602cf9b26533667914129e3a392dde83c |
| MD5 | 854038ea2a12c249b7f66d0c28b1533c |
| BLAKE2b-256 | 8d4cfa6fea8c664e0df26ebcb3f5a4ac54edc7b03cd9d4ffe1c70b32427f4662 |
File details
Details for the file stratifiedgroupkfoldrequiresgroups-0.0.2-py2.py3-none-any.whl.
File metadata
- Download URL: stratifiedgroupkfoldrequiresgroups-0.0.2-py2.py3-none-any.whl
- Upload date:
- Size: 4.4 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bc8af13ab2948600db7860221c1eba6a598119567b5c7f24050f73c0b4d6b3b6 |
| MD5 | 8a1b314009c0be3ebff912b12111d219 |
| BLAKE2b-256 | cdbb15aadd07555b92e583cee43b69bdc1ea03dbe3d6932b901f8ec219a641b3 |