catstat

Unified CPU/GPU statistical categorical encoding: leakage-safe target encoding generalized to arbitrary statistics, behind one scikit-learn-compatible API.

Runs on CPU (pandas/NumPy) and on GPU (RAPIDS cuDF/CuPy); CPU and GPU outputs are parity-validated with allclose. The GPU pays off in device-resident pipelines: pass a cuDF DataFrame and the whole encode (factorize, cross-fitting, gather, output) stays on the device and returns cuDF. On a Colab T4 this is ~2.6× faster than CPU at 100k rows and ~6–13× at 1M–10M rows. For pandas-origin data the host↔device copies eat the win (~1.1× at best), so backend="auto" still resolves to CPU for pandas input, while cuDF input routes to the GPU automatically. See docs/roadmap.md and docs/known_issues.md (KI-020).

import cudf
from catstat import TargetEncoder

# X is a pandas DataFrame with a "cat" column; y is the target
enc = TargetEncoder(cols=["cat"], stats=["mean", "var"])   # backend="auto"
out = enc.fit_transform(cudf.from_pandas(X), y)            # stays on device; returns cuDF
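
The routing rule described above (pandas input resolves to CPU, cuDF input to GPU) can be pictured with a small sketch. The helper name resolve_backend and the module-name check are assumptions made for illustration; this is not catstat's internal code.

```python
def resolve_backend(X, backend="auto"):
    """Hypothetical helper: choose "cpu" or "gpu" from the input container."""
    if backend in ("cpu", "gpu"):
        return backend  # an explicit choice always wins
    if backend != "auto":
        raise ValueError(f"unknown backend: {backend!r}")
    # Look at the module that defines the input's type instead of importing
    # cudf, so the check also works on machines without RAPIDS installed.
    root_module = type(X).__module__.split(".")[0]
    return "gpu" if root_module == "cudf" else "cpu"
```

Under this rule a cuDF DataFrame would route to "gpu", while a pandas DataFrame or NumPy array would stay on "cpu" unless backend="gpu" is requested explicitly.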

Install

pip install catstat

Optional extras: catstat[gpu] (RAPIDS cuDF/CuPy, CUDA 12), catstat[polars] (output="polars"), catstat[docs] (API-reference build), catstat[dev] (tests + lint + build).

Quickstart

from catstat import TargetEncoder, CountEncoder, FrequencyEncoder

enc = TargetEncoder(cols="auto", stats=["mean"], smooth="auto", cv=5, random_state=42)
X_train_enc = enc.fit_transform(X_train, y_train)   # out-of-fold (leakage-safe)
X_test_enc  = enc.transform(X_test)                 # full-data encodings for new data

Why catstat

scikit-learn's TargetEncoder is CPU-only and mean-only; cuML's is GPU-only (RAPIDS-locked, few statistics); category_encoders has no internal cross-fitting (leakage risk). catstat is the union: one API, CPU today and GPU when it pays off, generalized statistics, always leakage-safe.

What it encodes

Three encoders over a shared core: TargetEncoder (supervised, cross-fitted) and the unsupervised CountEncoder / FrequencyEncoder. TargetEncoder(stats=[...]) selects the statistics to emit:

| stats= entry | smoothing | target | GPU | column infix |
|---|---|---|---|---|
| "mean" | m-estimate (fixed) / empirical-Bayes (smooth="auto") | regression / binary / multiclass | | te_mean |
| "count" | | unsupervised | | count |
| "frequency" | | unsupervised | | freq |
| "var", "std" | none (global fallback) | regression | | te_var, te_std |
| "median", "min", "max" | none (global fallback) | regression | | te_median / te_min / te_max |
| "skew", "kurt" | none (global fallback) | regression | | te_skew, te_kurt |
| "woe" | inherits the mean's smoothing (logit-derived) | binary | | woe |
| ("name", callable): custom (quantiles, IQR, …) | none (global fallback) | regression | CPU only | name |
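
To make the "woe" row concrete, here is a dependency-free sketch of weight of evidence written as a difference of logits, which is one reading of "logit-derived". The function name woe_encode, the fixed m-estimate with m=1.0, and the exact formula are assumptions; catstat's implementation may differ.

```python
import math
from collections import defaultdict


def _logit(p):
    return math.log(p / (1.0 - p))


def woe_encode(categories, y, m=1.0):
    """Per-category WoE as logit(smoothed event rate) - logit(global event rate).

    Assumes a binary 0/1 target that contains both classes and m > 0.
    """
    prior = sum(y) / len(y)  # global event rate
    n = defaultdict(int)
    events = defaultdict(int)
    for c, t in zip(categories, y):
        n[c] += 1
        events[c] += t
    woe = {}
    for c in n:
        # The m-estimate keeps p strictly inside (0, 1), so the logit is finite
        # even for categories that are all-positive or all-negative.
        p = (events[c] + m * prior) / (n[c] + m)
        woe[c] = _logit(p) - _logit(prior)
    return woe
```

A category whose smoothed event rate is above the global rate gets a positive value, one below it gets a negative value, and heavy smoothing (large m) pulls every category toward 0.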

Smoothing honesty: only mean/probability statistics are smoothed. Count and frequency get no smoothing, and order/shape statistics are never blended: below min_samples_category (or where the statistic is undefined) they fall back to the global statistic. stats=["quantile"] raises an error with a hint to pass a custom callable such as ("q90", lambda v: np.quantile(v, 0.9)).
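
The two behaviours above can be contrasted in a short, dependency-free sketch: a fixed m-estimate blends each category mean with the global mean, while an order statistic (the median here) is either reported as-is or replaced by the global value. The function names and the default m are illustrative assumptions, and the empirical-Bayes smooth="auto" variant is not shown.

```python
from collections import defaultdict
from statistics import mean, median


def _group(categories, y):
    groups = defaultdict(list)
    for c, t in zip(categories, y):
        groups[c].append(t)
    return groups


def smoothed_mean(categories, y, m=10.0):
    """m-estimate: (sum_c + m * global_mean) / (n_c + m)."""
    prior = mean(y)
    return {c: (sum(v) + m * prior) / (len(v) + m)
            for c, v in _group(categories, y).items()}


def median_with_fallback(categories, y, min_samples_category=3):
    """No blending: small categories simply take the global median."""
    global_median = median(y)
    return {c: median(v) if len(v) >= min_samples_category else global_median
            for c, v in _group(categories, y).items()}
```

With categories a, a, a, a, b and targets 1, 2, 3, 4, 100, the single-row category b is pulled toward the global mean by smoothed_mean, whereas median_with_fallback gives it the global median instead of its own lone value.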

Other knobs: scheme ∈ {kfold, loo, ordered} (cross-fitting for the mean; loo/ordered are mean-only), multi_feature_mode ∈ {independent, combination} (joint group-by), handle_unknown / handle_missing ∈ {value, return_nan, error}, backend ∈ {auto, cpu, gpu}, and output ∈ {auto, numpy, pandas, polars}.
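
As an illustration of the loo scheme, the sketch below implements the usual leave-one-out mean: each row is encoded from its category's other rows. It assumes scheme="loo" follows this standard definition; the fallback to the global mean for single-row categories is also an assumption, not documented catstat behaviour.

```python
from collections import defaultdict


def loo_mean_encode(categories, y):
    """Leave-one-out mean: a row's own target never enters its encoding."""
    total = defaultdict(float)
    count = defaultdict(int)
    for c, t in zip(categories, y):
        total[c] += t
        count[c] += 1
    prior = sum(y) / len(y)  # global mean, used when nothing else is left
    encoded = []
    for c, t in zip(categories, y):
        if count[c] > 1:
            encoded.append((total[c] - t) / (count[c] - 1))
        else:
            encoded.append(prior)  # singleton category: fall back to the prior
    return encoded
```

For categories a, a, a, b with targets 1, 2, 3, 10 this yields 2.5, 2.0, 1.5 and 4.0: every "a" row sees the mean of the other two, and the lone "b" row gets the global mean.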

Leakage-safe by design

  • fit_transform(X, y) is out-of-fold: each fold is encoded from its complement, then the encoder refits on the full data for later transform of new rows. fit(X, y).transform(X) on the training set is the leaky path and is documented as such.
  • smooth="auto" variance is computed per fold. Fold assignment depends only on random_state: catstat owns it, so CPU and GPU produce the same encodings (asserted with allclose).
  • Deterministic given random_state.
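
The out-of-fold procedure in the list above can be sketched without any dependencies. This is a minimal illustration for the plain mean, assuming a seeded shuffle for fold assignment and a global-mean fallback for categories missing from a fold's complement; the function names are hypothetical and none of this is catstat's actual code.

```python
import random
from collections import defaultdict


def assign_folds(n_rows, cv=5, random_state=42):
    """Balanced fold ids that depend only on n_rows, cv and random_state."""
    folds = [i % cv for i in range(n_rows)]
    random.Random(random_state).shuffle(folds)
    return folds


def _category_means(categories, y, rows):
    total, count = defaultdict(float), defaultdict(int)
    for i in rows:
        total[categories[i]] += y[i]
        count[categories[i]] += 1
    return {c: total[c] / count[c] for c in total}


def oof_mean_encode(categories, y, cv=5, random_state=42):
    """Return (out-of-fold encodings for the training rows, full-data mapping).

    Requires 2 <= cv <= len(y) so that every complement is non-empty.
    """
    n = len(y)
    folds = assign_folds(n, cv, random_state)
    encoded = [None] * n
    for k in range(cv):
        complement = [i for i in range(n) if folds[i] != k]
        prior = sum(y[i] for i in complement) / len(complement)
        means = _category_means(categories, y, complement)
        for i in range(n):
            if folds[i] == k:
                # Encoded only from rows outside fold k, never from row i itself.
                encoded[i] = means.get(categories[i], prior)
    full_fit = _category_means(categories, y, range(n))  # used for new rows
    return encoded, full_fit
```

Changing one row's target leaves that row's out-of-fold encoding untouched but does change the full-data mapping, which is why applying the full-data mapping back to the training rows is the leaky path.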

scikit-learn compatibility

The encoders are BaseEstimator / TransformerMixin estimators: they work in Pipeline and ColumnTransformer and support set_output(transform="pandas"|"polars") and get_feature_names_out. The supported subset of sklearn.utils.estimator_checks.check_estimator is documented and tested (see docs/known_issues.md, KI-012).

API reference

Rendered API docs: https://matapanino.github.io/catstat/ (built with pdoc; see scripts/build_docs.sh).

Develop

pip install -e ".[dev]"
bash scripts/check.sh        # ruff + pytest + examples (the green gate)
PYTHONPATH=src python3 -m pytest tests/ -q
PYTHONPATH=src python3 -m benchmarks.run_benchmarks --size small --backend cpu --reps 5 \
    --out benchmarks/results/run.json

See CLAUDE.md for the development rules and docs/ for the design.

License

MIT
