StratifiedGroupKFoldRequiresGroups

A small wrapper around scikit-learn's StratifiedGroupKFold that makes the groups argument mandatory when calling split().

What It Is

StratifiedGroupKFoldRequiresGroups is a subclass of sklearn.model_selection.StratifiedGroupKFold. It keeps the underlying cross-validation behavior from scikit-learn, but adds a guardrail: callers must provide a non-None groups argument to split().

This is useful when grouped splitting is essential to the correctness of a model evaluation workflow. If a pipeline, estimator, or helper forgets to pass groups, the split should fail immediately instead of silently behaving like a regular stratified split without group isolation.

Installation

pip install StratifiedGroupKFoldRequiresGroups

The package requires Python 3.8 or newer and depends on scikit-learn.

Usage

import numpy as np
from StratifiedGroupKFoldRequiresGroups import StratifiedGroupKFoldRequiresGroups

X = np.random.randn(9, 5)
y = np.array(["class1", "class2", "class3"] * 3)
groups = np.array([
    "group1",
    "group2",
    "group3",
    "group4",
    "group5",
    "group6",
    "group7",
    "group7",
    "group7",
])

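# Same constructor arguments as scikit-learn's StratifiedGroupKFold.
cv = StratifiedGroupKFoldRequiresGroups(n_splits=3)

# groups is mandatory: omitting it raises an error instead of ignoring grouping.
for train_index, test_index in cv.split(X, y, groups):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]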

The constructor is inherited from scikit-learn's StratifiedGroupKFold, so it accepts the same options, such as n_splits, shuffle, and random_state.

Important Behavior

  • cv.split(X, y, groups) delegates to sklearn.model_selection.StratifiedGroupKFold.split().
  • cv.split(X, y) raises a TypeError because groups is a required positional argument.
  • cv.split(X, y, groups=None) raises a ValueError.
  • The wrapper does not change scikit-learn's splitting algorithm; it only enforces that group labels are supplied.

Development

pip install -r requirements_dev.txt
pip install -e .
pytest

The test suite checks that the class is a StratifiedGroupKFold subclass, that it produces the expected number of splits, and that missing or None groups fail as intended.

Changelog

0.0.1

  • First release on PyPI.

Download files

Source Distribution

stratifiedgroupkfoldrequiresgroups-0.0.2.tar.gz (5.2 kB)

Built Distribution

stratifiedgroupkfoldrequiresgroups-0.0.2-py2.py3-none-any.whl (4.4 kB, Python 2 and Python 3)

File details

Details for the file stratifiedgroupkfoldrequiresgroups-0.0.2.tar.gz.

File hashes

Hashes for stratifiedgroupkfoldrequiresgroups-0.0.2.tar.gz
Algorithm     Hash digest
SHA256        f6e213d5ca37649fa000aca36136d7f602cf9b26533667914129e3a392dde83c
MD5           854038ea2a12c249b7f66d0c28b1533c
BLAKE2b-256   8d4cfa6fea8c664e0df26ebcb3f5a4ac54edc7b03cd9d4ffe1c70b32427f4662

File details

Details for the file stratifiedgroupkfoldrequiresgroups-0.0.2-py2.py3-none-any.whl.

File hashes

Hashes for stratifiedgroupkfoldrequiresgroups-0.0.2-py2.py3-none-any.whl
Algorithm     Hash digest
SHA256        bc8af13ab2948600db7860221c1eba6a598119567b5c7f24050f73c0b4d6b3b6
MD5           8a1b314009c0be3ebff912b12111d219
BLAKE2b-256   cdbb15aadd07555b92e583cee43b69bdc1ea03dbe3d6932b901f8ec219a641b3
