catstat

Unified CPU/GPU statistical categorical encoding: leakage-safe target encoding generalized to arbitrary statistics, behind one scikit-learn-compatible API.

Runs on CPU (pandas/NumPy) and on GPU (RAPIDS cuDF/CuPy); CPU and GPU results are parity-validated with allclose. The GPU pays off in device-resident pipelines: pass a cuDF DataFrame and the whole encode (factorize, cross-fitting, gather, output) stays on the device and returns cuDF. On a Colab T4 that is ~2.6× at 100k rows and ~6–13× at 1M–10M rows vs CPU. For pandas-origin data the host↔device copies eat the win (~1.1× at best), so backend="auto" still resolves to CPU for pandas input, while cuDF input routes to the GPU automatically. See docs/roadmap.md and docs/known_issues.md (KI-020).

import cudf
from catstat import TargetEncoder
# X is a pandas DataFrame with a "cat" column; y is the target
enc = TargetEncoder(cols=["cat"], stats=["mean", "var"])   # backend="auto"
out = enc.fit_transform(cudf.from_pandas(X), y)            # stays on device; returns cuDF
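
The routing rule described above (cuDF input goes to the GPU, pandas-origin data stays on the CPU under backend="auto") can be pictured as a small dispatch function. This is only a sketch and not catstat's code: `resolve_backend` and `FakeCudfFrame` are hypothetical names, and detecting cuDF by the module name of the input's type is an assumption about one way to do it without importing cuDF.

```python
def resolve_backend(X, backend="auto"):
    """Hypothetical helper: pick "cpu" or "gpu" for the input X."""
    if backend not in {"auto", "cpu", "gpu"}:
        raise ValueError(f"unknown backend: {backend!r}")
    if backend != "auto":
        return backend  # an explicit choice always wins
    # Look at the module that defines X's type, so cuDF need not be installed.
    root = type(X).__module__.split(".")[0]
    return "gpu" if root == "cudf" else "cpu"


class FakeCudfFrame:
    """Stand-in for a cuDF DataFrame, used only for this demo."""


# Pretend the class lives inside the cudf package (assumption for the demo).
FakeCudfFrame.__module__ = "cudf.core.dataframe"

print(resolve_backend([1, 2, 3]))         # host data
print(resolve_backend(FakeCudfFrame()))   # device-resident data
```

Keying the decision on where the data already lives, rather than on whether a GPU is present, matches the observation that host↔device copies cancel the speedup for pandas input.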

Install

pip install catstat

Optional extras: catstat[gpu] (RAPIDS cuDF/CuPy, CUDA 12), catstat[polars] (output="polars"), catstat[docs] (API-reference build), catstat[dev] (tests + lint + build).

Quickstart

from catstat import TargetEncoder, CountEncoder, FrequencyEncoder

enc = TargetEncoder(cols="auto", stats=["mean"], smooth="auto", cv=5, random_state=42)
X_train_enc = enc.fit_transform(X_train, y_train)   # out-of-fold (leakage-safe)
X_test_enc  = enc.transform(X_test)                 # full-data encodings for new data

Why catstat

scikit-learn's TargetEncoder is CPU-only and mean-only; cuML's is GPU-only (RAPIDS-locked, few stats); category_encoders has no internal cross-fitting (leakage risk). catstat combines the three: one API, CPU today and GPU when it pays off, generalized statistics, and out-of-fold (leakage-safe) encodings from fit_transform.

What it encodes

Three encoders over a shared core: TargetEncoder (supervised, cross-fitted) and the unsupervised CountEncoder / FrequencyEncoder. TargetEncoder(stats=[...]) selects the statistics to emit:

stats= entry | smoothing | target | GPU | column infix
"mean" | m-estimate (fixed) / empirical-Bayes (smooth="auto") | regression / binary / multiclass | | te_mean
"count" | none | unsupervised | | count
"frequency" | none | unsupervised | | freq
"var", "std" | none (global fallback) | regression | | te_var, te_std
"median", "min", "max" | none (global fallback) | regression | | te_median / te_min / te_max
"skew", "kurt" | none (global fallback) | regression | | te_skew, te_kurt
"woe" | inherits the mean's smoothing (logit-derived) | binary | | woe
("name", callable): custom (quantiles, IQR, …) | none (global fallback) | regression | CPU only | name
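
To make the "woe" row concrete: weight of evidence for a category equals logit(category positive rate) minus logit(global positive rate), which is why it can reuse the mean's smoothing. The sketch below applies a fixed m-estimate to the rate first. It illustrates the identity only; `woe_by_category` and the weight `m` are hypothetical, and catstat's exact formula may differ.

```python
import math
from collections import defaultdict


def logit(p):
    return math.log(p / (1.0 - p))


def woe_by_category(cats, y, m=1.0):
    """WoE per category as logit(smoothed rate) - logit(global rate).

    m is an assumed m-estimate weight pulling each rate toward the prior.
    With m > 0 and a prior strictly between 0 and 1 the logit is always defined.
    """
    n = defaultdict(int)      # rows per category
    pos = defaultdict(int)    # positive rows per category
    for c, t in zip(cats, y):
        n[c] += 1
        pos[c] += t
    prior = sum(y) / len(y)   # global positive rate
    out = {}
    for c in n:
        p = (pos[c] + m * prior) / (n[c] + m)  # smoothed category rate
        out[c] = logit(p) - logit(prior)
    return out


cats = ["a", "a", "a", "b", "b", "b", "b", "c"]
y = [1, 1, 0, 0, 0, 1, 0, 1]
woe = woe_by_category(cats, y)
for c in sorted(woe):
    print(c, round(woe[c], 3))
```

Without smoothing (m=0), a category whose rows are all positive or all negative would have an undefined logit, which is one reason to derive WoE from the smoothed mean.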

Smoothing honesty: only mean/probability statistics are smoothed. Count/frequency get none; order/shape statistics never blend — below min_samples_category (or where undefined) they fall back to the global statistic. (stats=["quantile"] raises with a hint to pass a custom callable such as ("q90", lambda v: np.quantile(v, 0.9)).)
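
The two behaviours in this paragraph can be sketched side by side in plain Python: the mean is blended with the global mean by a fixed m-estimate, while the median is never blended and simply falls back to the global median for small categories. The function name `encode_stats` and the default values of `m` and `min_samples_category` are illustrative assumptions, and the empirical-Bayes variant (smooth="auto") is not shown.

```python
import statistics
from collections import defaultdict


def encode_stats(cats, y, m=10.0, min_samples_category=3):
    """Per-category smoothed mean and unsmoothed median with global fallback."""
    groups = defaultdict(list)
    for c, t in zip(cats, y):
        groups[c].append(t)
    global_mean = statistics.fmean(y)
    global_median = statistics.median(y)
    out = {}
    for c, vals in groups.items():
        n = len(vals)
        # m-estimate: blend the category mean with the global mean.
        mean = (sum(vals) + m * global_mean) / (n + m)
        # Order statistics are not blended: category value or global fallback.
        if n >= min_samples_category:
            median = statistics.median(vals)
        else:
            median = global_median
        out[c] = {"mean": mean, "median": median}
    return out


cats = ["a", "a", "a", "a", "b"]
y = [10.0, 12.0, 14.0, 16.0, 100.0]
for c, stats in encode_stats(cats, y).items():
    print(c, stats)
```

Category "b" has a single row, so its mean is pulled strongly toward the global mean and its median is replaced by the global median rather than being a blend of the two.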

Other knobs: scheme ∈ {kfold, loo, ordered} (cross-fitting for the mean; loo/ordered are mean-only), multi_feature_mode ∈ {independent, combination} (joint group-by), handle_unknown / handle_missing ∈ {value, return_nan, error}, backend ∈ {auto, cpu, gpu}, and output ∈ {auto, numpy, pandas, polars}.
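
As an illustration of the loo scheme, here is a minimal leave-one-out mean in plain Python: each row is encoded with the mean of the other rows in its category. `loo_mean_encode` is a hypothetical name, the sketch applies no smoothing, and sending single-row categories to the global mean is an assumption rather than documented catstat behaviour.

```python
from collections import defaultdict


def loo_mean_encode(cats, y, prior=None):
    """Leave-one-out mean: a row's own target is removed from its category sum."""
    total = defaultdict(float)
    count = defaultdict(int)
    for c, t in zip(cats, y):
        total[c] += t
        count[c] += 1
    if prior is None:
        prior = sum(y) / len(y)  # used when a category has a single row
    encoded = []
    for c, t in zip(cats, y):
        if count[c] > 1:
            encoded.append((total[c] - t) / (count[c] - 1))
        else:
            encoded.append(prior)
    return encoded


print(loo_mean_encode(["a", "a", "a", "b"], [1.0, 2.0, 6.0, 5.0]))
# prints [4.0, 3.5, 1.5, 3.5]
```

Because the subtraction trick relies on sums, it carries over to the mean but not directly to order statistics such as the median, which is consistent with loo being mean-only.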

Leakage-safe by design

  • fit_transform(X, y) is out-of-fold: each fold is encoded from its complement, then the encoder refits on the full data for later transform of new rows. fit(X, y).transform(X) on the training set is the leaky path and is documented as such.
  • With smooth="auto", the variance is computed per fold. Fold assignment depends only on random_state: catstat owns it, so CPU and GPU produce the same encodings (asserted with allclose).
  • Deterministic given random_state.
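
The out-of-fold procedure in the list above can be sketched in plain Python for the unsmoothed mean. All names here (`assign_folds`, `category_means`, `oof_fit_transform`) are hypothetical, and the shuffle-and-deal fold assignment is an assumption; the point is only that each training row is encoded from the other folds, followed by a full-data refit for new rows.

```python
import random
from collections import defaultdict


def assign_folds(n_rows, cv, random_state):
    """Shuffle row indices with a seeded RNG and deal them into cv folds."""
    idx = list(range(n_rows))
    random.Random(random_state).shuffle(idx)
    folds = [0] * n_rows
    for pos, i in enumerate(idx):
        folds[i] = pos % cv
    return folds


def category_means(cats, y, rows):
    """Mean target per category, computed over the given row indices only."""
    total, count = defaultdict(float), defaultdict(int)
    for i in rows:
        total[cats[i]] += y[i]
        count[cats[i]] += 1
    return {c: total[c] / count[c] for c in total}


def oof_fit_transform(cats, y, cv=5, random_state=42):
    """Return (out-of-fold encodings for training rows, full-data mapping)."""
    n = len(y)
    folds = assign_folds(n, cv, random_state)
    prior = sum(y) / n
    encoded = [prior] * n
    for k in range(cv):
        complement = [i for i in range(n) if folds[i] != k]
        means = category_means(cats, y, complement)  # never sees fold k
        for i in range(n):
            if folds[i] == k:
                encoded[i] = means.get(cats[i], prior)
    full = category_means(cats, y, range(n))  # refit for later transform
    return encoded, full


cats = ["a", "b"] * 10
y = [float(i) for i in range(20)]
oof, full = oof_fit_transform(cats, y, cv=5, random_state=42)
print(oof[:4])   # training rows, encoded without their own fold's targets
print(full)      # mapping a later transform would apply to new rows
```

Since the folds come from a seeded generator owned by the encoder, two runs with the same random_state give identical encodings, which is the property that makes a CPU/GPU parity check possible.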

scikit-learn compatibility

BaseEstimator / TransformerMixin; works in Pipeline and ColumnTransformer, supports set_output(transform="pandas"|"polars") and get_feature_names_out. The supported subset of sklearn.utils.estimator_checks.check_estimator is documented and tested (see docs/known_issues.md, KI-012).

API reference

Rendered API docs: https://matapanino.github.io/catstat/ (built with pdoc; see scripts/build_docs.sh).

Develop

pip install -e ".[dev]"
bash scripts/check.sh        # ruff + pytest + examples (the green gate)
PYTHONPATH=src python3 -m pytest tests/ -q
PYTHONPATH=src python3 -m benchmarks.run_benchmarks --size small --backend cpu --reps 5 \
    --out benchmarks/results/run.json

See CLAUDE.md for the development rules and docs/ for the design.

License

MIT
