
catstat

Unified CPU/GPU statistical categorical encoding: leakage-safe target encoding generalized to arbitrary statistics, behind one scikit-learn-compatible API.

Runs on CPU (pandas/numpy) and on GPU (RAPIDS cuDF/CuPy); the two backends are parity-validated (CPU and GPU outputs are asserted allclose). The GPU pays off for device-resident pipelines: pass a cuDF DataFrame and the whole encode (factorize, cross-fitting, gather, output) stays on the device and returns cuDF. On a Colab T4 this measured ~2.6× faster than CPU at 100k rows and ~6–13× at 1M–10M rows. For pandas-origin data the host↔device copies eat most of the gain (~1.1× at best), so backend="auto" resolves to CPU for pandas input and routes cuDF input to the GPU automatically. See docs/roadmap.md and docs/known_issues.md (KI-020).

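import cudf
from catstat import TargetEncoder

# X: pandas DataFrame with a categorical column "cat"; y: the target
enc = TargetEncoder(cols=["cat"], stats=["mean", "var"])   # backend="auto"
out = enc.fit_transform(cudf.from_pandas(X), y)            # cuDF input runs on the GPU; returns cuDF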
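The backend="auto" rule described above can be sketched as a small dispatch function. resolve_backend is hypothetical, not catstat's code: it inspects the input type's top-level package name so it runs without RAPIDS installed, and it assumes an explicitly passed backend is used as given.

```python
def resolve_backend(X, backend="auto"):
    """Hypothetical sketch of the backend="auto" rule (not catstat's implementation).

    cuDF input routes to the GPU; pandas, numpy and everything else stay on
    the CPU. An explicit "cpu" or "gpu" is assumed to be honored as given.
    """
    if backend in ("cpu", "gpu"):
        return backend
    if backend != "auto":
        raise ValueError(f"unknown backend: {backend!r}")
    # Look at the package that defines the input's type instead of importing
    # cudf, so this check also works on machines without RAPIDS.
    top_level = type(X).__module__.split(".")[0]
    return "gpu" if top_level == "cudf" else "cpu"
```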

Install

pip install catstat

Optional extras: catstat[gpu] (RAPIDS cuDF/CuPy, CUDA 12), catstat[polars] (output="polars"), catstat[docs] (API-reference build), catstat[dev] (tests + lint + build).

Quickstart

from catstat import TargetEncoder, CountEncoder, FrequencyEncoder

# X_train / X_test: DataFrames with categorical columns; y_train: the target
enc = TargetEncoder(cols="auto", stats=["mean"], smooth="auto", cv=5, random_state=42)
X_train_enc = enc.fit_transform(X_train, y_train)   # out-of-fold (leakage-safe)
X_test_enc  = enc.transform(X_test)                 # uses encodings refit on the full training data

Why catstat

sklearn's TargetEncoder is CPU-only and mean-only; cuML's is GPU-only (RAPIDS-locked, few statistics); category_encoders has no internal cross-fitting (leakage risk). catstat covers the union: one API, CPU today and GPU when it pays off, generalized statistics, always leakage-safe.

What it encodes

Three encoders over a shared core: TargetEncoder (supervised, cross-fitted) and the unsupervised CountEncoder / FrequencyEncoder. TargetEncoder(stats=[...]) selects the statistics to emit:

| stats= entry | smoothing | target | GPU | column infix |
|---|---|---|---|---|
| "mean" | m-estimate (fixed) / empirical-Bayes ("auto", = sklearn's formula) / "sigmoid" (category_encoders') | regression / binary / multiclass | | te_mean |
| "count" | | unsupervised | | count |
| "frequency" | | unsupervised | | freq |
| "var", "std" | none (global fallback) | regression | | te_var, te_std |
| "median", "min", "max" | none (global fallback) | regression | | te_median / te_min / te_max |
| "skew", "kurt" | none (global fallback) | regression | | te_skew, te_kurt |
| "woe" | inherits the mean's smoothing (logit-derived) | binary | | woe |
| ("name", callable): custom (quantiles, IQR, …) | none (global fallback) | regression | CPU only | name |
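The "(global fallback)" rows can be illustrated with a small pandas sketch: each statistic is computed per category and replaced by the global statistic wherever the category has fewer than min_samples_category rows or the statistic is undefined (NaN). encode_stats is a hypothetical helper, not catstat's API; passing one callable per output column is an assumption modeled on the ("name", callable) row.

```python
import numpy as np
import pandas as pd


def encode_stats(x, y, stats, min_samples_category=1):
    """Per-category statistics with a global fallback and no blending.

    `stats` maps an output column name to a callable on a 1-D float array.
    Hypothetical helper for illustration; not catstat's API.
    """
    x = pd.Series(x).reset_index(drop=True)
    y = pd.Series(y, dtype=float).reset_index(drop=True)
    counts = x.map(x.value_counts())          # size of each row's category
    out = {}
    for name, fn in stats.items():
        global_value = fn(y.to_numpy())
        per_cat = y.groupby(x).agg(lambda v: fn(v.to_numpy()))
        values = x.map(per_cat).astype(float)
        # Too-small categories and undefined (NaN) statistics use the global value.
        fallback = (counts < min_samples_category) | values.isna()
        values.loc[fallback] = global_value
        out[name] = values.to_numpy()
    return pd.DataFrame(out)
```

For example, encode_stats(x, y, {"te_median": np.median, "q90": lambda v: np.quantile(v, 0.9)}, min_samples_category=2) emits one column per statistic, named as in the table.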

Smoothing honesty: only mean/probability statistics are smoothed. smooth accepts a fixed m, "auto" (empirical-Bayes, verified to match sklearn's TargetEncoder(smooth="auto") formula exactly), or "sigmoid" / ("sigmoid", k, f) (category_encoders' blend). Counts are never smoothed. Frequencies optionally take Laplace add-α smoothing via laplace_alpha on CountEncoder / FrequencyEncoder: with n rows and K categories, an unseen category gets α/(n+αK) instead of 0. Order/shape statistics never blend: below min_samples_category (or where the statistic is undefined) they fall back to the global statistic. stats=["quantile"] raises with a hint to pass a custom callable instead, such as ("q90", lambda v: np.quantile(v, 0.9)).
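The smoothing options can be written out as scalar formulas. This is a minimal sketch, not catstat's code, and the function names are hypothetical: m_estimate and laplace_frequency are the textbook forms; empirical_bayes follows the formula in scikit-learn's TargetEncoder documentation (m = within-category variance / global variance); sigmoid_blend follows category_encoders' TargetEncoder, with k and f assumed to play the roles of its min_samples_leaf and smoothing; woe_from_mean assumes "logit-derived" means a difference of logits.

```python
import numpy as np


def m_estimate(cat_mean, n, prior, m):
    """Fixed m-estimate: shrink a category mean toward the prior with weight m."""
    return (n * cat_mean + m * prior) / (n + m)


def empirical_bayes(cat_mean, cat_var, n, prior, global_var):
    """smooth="auto" following scikit-learn's documented formula.

    m = cat_var / global_var, so lambda = n / (n + m)
      = n * global_var / (n * global_var + cat_var).
    """
    denom = n * global_var + cat_var
    if denom == 0:                            # no variance anywhere: use the prior
        return prior
    lam = n * global_var / denom
    return lam * cat_mean + (1 - lam) * prior


def sigmoid_blend(cat_mean, n, prior, k, f):
    """category_encoders-style blend (k ~ min_samples_leaf, f ~ smoothing; assumed)."""
    weight = 1.0 / (1.0 + np.exp(-(n - k) / f))
    return weight * cat_mean + (1 - weight) * prior


def laplace_frequency(count, n_rows, n_categories, alpha):
    """Add-alpha frequency; an unseen category (count=0) gets alpha / (n + alpha*K)."""
    return (count + alpha) / (n_rows + alpha * n_categories)


def woe_from_mean(p_cat, prior):
    """Weight of evidence as logit(category positive rate) - logit(global rate)."""
    return np.log(p_cat / (1 - p_cat)) - np.log(prior / (1 - prior))
```

The empirical-Bayes form is just an m-estimate whose m is chosen per category from the data, which is why noisy (high-variance) categories are pulled harder toward the prior.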

Other knobs:
  • scheme ∈ {kfold, loo, ordered}: the cross-fitting scheme for the mean (loo and ordered are mean-only).
  • multi_feature_mode ∈ {independent, combination} plus interactions=[[...]]: joint group-bys.
  • max_classes: caps the multiclass one-vs-rest expansion to the most frequent classes.
  • handle_unknown / handle_missing ∈ {value, return_nan, error}.
  • backend ∈ {auto, cpu, gpu}.
  • output ∈ {auto, numpy, pandas, polars, cudf}.
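Roughly how the two non-k-fold schemes work for the mean, as a pandas/numpy sketch of the general techniques rather than catstat's implementation: loo removes each row's own target from its category mean, and ordered is the CatBoost-style ordered target statistic, where each row only sees earlier rows of its category in a random permutation. loo_mean and ordered_mean are hypothetical names, and the fallbacks (leave-one-out global mean for singleton categories, pseudo-count a=1 toward the prior) are illustrative assumptions.

```python
import numpy as np
import pandas as pd


def loo_mean(x, y):
    """Leave-one-out mean: each row's category mean excluding the row itself."""
    x = pd.Series(x).reset_index(drop=True)
    y = pd.Series(y, dtype=float).reset_index(drop=True)
    grouped = y.groupby(x)
    sums = grouped.transform("sum")
    counts = grouped.transform("count")
    enc = (sums - y) / (counts - 1)
    # Singleton categories have no other rows; use the leave-one-out global mean.
    global_loo = (y.sum() - y) / (len(y) - 1)
    return enc.where(counts > 1, global_loo).to_numpy()


def ordered_mean(x, y, a=1.0, prior=None, random_state=0):
    """CatBoost-style ordered target statistic over one random permutation.

    Each row sees only the targets of earlier rows of its category in the
    permutation, shrunk toward `prior` with pseudo-count `a`.
    """
    x = np.asarray(x)
    y = np.asarray(y, dtype=float)
    prior = y.mean() if prior is None else prior
    order = np.random.default_rng(random_state).permutation(len(y))
    sums, counts = {}, {}
    enc = np.empty(len(y))
    for i in order:
        s, n = sums.get(x[i], 0.0), counts.get(x[i], 0)
        enc[i] = (s + a * prior) / (n + a)
        sums[x[i]] = s + y[i]
        counts[x[i]] = n + 1
    return enc
```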

Leakage-safe by design

  • fit_transform(X, y) is out-of-fold: each fold is encoded from its complement, then the encoder refits on the full data for later transform of new rows. fit(X, y).transform(X) on the training set is the leaky path and is documented as such.
  • smooth="auto" variance is computed per fold; fold assignment depends only on random_state (catstat owns fold assignment, so CPU and GPU produce the same encodings, asserted allclose).
  • Deterministic given random_state.
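The out-of-fold contract can be sketched for an unsmoothed mean under simple assumptions: a shuffled split into cv folds drawn from numpy's default_rng(random_state), and the complement's global mean for categories the complement never saw. oof_mean_encode is a hypothetical helper, not catstat's implementation, and catstat's actual fold assignment may differ.

```python
import numpy as np
import pandas as pd


def oof_mean_encode(x, y, cv=5, random_state=42):
    """Out-of-fold mean encoding plus the full-data mapping used for new rows."""
    x = pd.Series(x).reset_index(drop=True)
    y = pd.Series(y, dtype=float).reset_index(drop=True)
    # Fold assignment depends only on random_state: shuffle, then split into cv folds.
    rng = np.random.default_rng(random_state)
    folds = np.array_split(rng.permutation(len(x)), cv)
    oof = np.empty(len(x))
    for idx in folds:
        train = np.ones(len(x), dtype=bool)
        train[idx] = False                     # the complement of this fold
        means = y[train].groupby(x[train]).mean()
        # Categories the complement never saw get the complement's global mean.
        oof[idx] = x.iloc[idx].map(means).fillna(y[train].mean()).to_numpy()
    full_mapping = y.groupby(x).mean()         # "refit on the full data" for transform
    return oof, full_mapping
```

With one unique category per row, the full-data mapping reproduces the training targets exactly (the leaky fit(X, y).transform(X) path), while the out-of-fold values carry no per-row target information.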

scikit-learn compatibility

BaseEstimator / TransformerMixin; works in Pipeline and ColumnTransformer, supports set_output(transform="pandas"|"polars") and get_feature_names_out. The supported subset of sklearn.utils.estimator_checks.check_estimator is documented and tested (see docs/known_issues.md, KI-012).

API reference

Rendered API docs: https://matapanino.github.io/catstat/ (built with pdoc; see scripts/build_docs.sh).

Develop

pip install -e ".[dev]"
bash scripts/check.sh        # ruff + pytest + examples (the green gate)
PYTHONPATH=src python3 -m pytest tests/ -q
PYTHONPATH=src python3 -m benchmarks.run_benchmarks --size small --backend cpu --reps 5 \
    --out benchmarks/results/run.json

See CLAUDE.md for the development rules and docs/ for the design.

License

MIT

Download files

Source Distribution

catstat-0.5.2.tar.gz (91.0 kB)

Uploaded Source

Built Distribution

catstat-0.5.2-py3-none-any.whl (62.6 kB)

Uploaded Python 3

File details

Details for the file catstat-0.5.2.tar.gz.

File metadata

  • Download URL: catstat-0.5.2.tar.gz
  • Upload date:
  • Size: 91.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for catstat-0.5.2.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | bfefeba99cfd9f5799cfcd0270e5ba725ce03e64b2667f89c0b487340f783bc4 |
| MD5 | 316b731e2843236e0add13d080544cfc |
| BLAKE2b-256 | 28dcfbcb05ca38eaf6da5f167c232e16009d0bedb86a2f04cec9a08975c0cfa1 |

Provenance

The following attestation bundles were made for catstat-0.5.2.tar.gz:

Publisher: release.yml on Matapanino/catstat

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file catstat-0.5.2-py3-none-any.whl.

File metadata

  • Download URL: catstat-0.5.2-py3-none-any.whl
  • Upload date:
  • Size: 62.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for catstat-0.5.2-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | e89fdc38a7bb2778fa75e828e2c6d79bb06401da6689e2ce50035f68e2c47e9b |
| MD5 | 40b0738c6695f7469955e83e44432521 |
| BLAKE2b-256 | 6ea893f7bfba200f1591eec90c717351116984ae081e50e20935e1322c252990 |

Provenance

The following attestation bundles were made for catstat-0.5.2-py3-none-any.whl:

Publisher: release.yml on Matapanino/catstat

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
