Generalized Riesz Regression (GRR) utilities, including nearest-neighbor matching as LSIF/Riesz regression.

genriesz — Generalized Riesz Regression (GRR)

A Python library for Generalized Riesz Regression (GRR) under Bregman divergences — a unified way to fit Riesz representers with automatic regressor balancing (ARB) and then report RA / RW / ARW estimates with inference (optionally via cross-fitting).


Installation

Python >= 3.10.

From PyPI:

pip install genriesz

Optional extras:

# scikit-learn integrations (tree-based feature maps)
pip install "genriesz[sklearn]"

# PyTorch integrations (neural-network feature maps)
pip install "genriesz[torch]"

From a local checkout (editable install):

python -m pip install -U pip
pip install -e .

Core idea

You specify:

  • an estimand / linear functional m(X, γ),
  • a feature map / basis φ(X),
  • a Bregman generator g(X, α) (or one of the built-in generator classes),

and the library will:

  1. build the ARB link function induced by g,
  2. fit a Riesz representer α̂(X) via GRR,
  3. optionally fit an outcome model γ̂(X) (for RA / ARW / TMLE),
  4. return RA / RW / ARW / TMLE point estimates and inference (SE / CI / p-value), optionally with cross-fitting.

Notation in this library: the regressor is X (shape (n, d)) and the outcome is Y (shape (n,)). If you prefer the paper’s notation, you can think of X as the full regressor vector (often X = [D, Z]).
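
To make steps 2 and 4 concrete, here is a minimal NumPy sketch of the squared-loss special case for the ATE, written from scratch rather than with genriesz. It assumes a linear model α(x) = φ(x)ᵀβ, a ridge penalty, and the standard squared-loss Riesz objective E[α(X)² − 2 m(X, α)], whose minimizer has a closed form. The helper `phi` and all variable names are illustrative assumptions; the library's internals may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
Z = rng.normal(size=(n, 2))
p = 1.0 / (1.0 + np.exp(-0.5 * Z[:, 0]))        # true propensity score
D = (rng.uniform(size=n) < p).astype(float)
Y = 2.0 * D + Z[:, 0] + rng.normal(size=n)      # true ATE is 2.0


def phi(D, Z):
    """Illustrative treatment-interacted basis: [D * (1, Z), (1 - D) * (1, Z)]."""
    base = np.column_stack([np.ones(len(Z)), Z])
    return np.column_stack([D[:, None] * base, (1.0 - D)[:, None] * base])


Phi = phi(D, Z)

# For the ATE, m(X, f) = f(1, Z) - f(0, Z); apply it to every basis function.
m_phi = phi(np.ones(n), Z) - phi(np.zeros(n), Z)

# First-order condition of the ridge-penalized squared-loss objective:
#   (Phi'Phi / n + lam * I) beta = mean_i m(X_i, phi)
lam = 1e-3
A = Phi.T @ Phi / n + lam * np.eye(Phi.shape[1])
b = m_phi.mean(axis=0)
beta = np.linalg.solve(A, b)
alpha_hat = Phi @ beta                          # fitted Riesz representer

# Weighting-type estimate: average of alpha_hat * Y.
rw = np.mean(alpha_hat * Y)
print(f"RW estimate: {rw:.3f}")

# Balancing check: weighted basis means match the target moments up to the ridge term.
print(np.abs(Phi.T @ alpha_hat / n - b).max())
```

The balancing check is where the "automatic regressor balancing" reading comes from in this special case: the fitted weights reproduce the moments of m(X, φ) on the chosen basis, up to the small distortion introduced by the ridge penalty.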


Quickstart: ATE (Average Treatment Effect)

The ATE is available through the convenience wrapper grr_ate.

import numpy as np
from genriesz import (
    grr_ate,
    UKLGenerator,
    PolynomialBasis,
    TreatmentInteractionBasis,
)

# Example layout: X = [D, Z]
#   D: treatment (0/1)
#   Z: covariates
n, d_z = 1000, 5
rng = np.random.default_rng(0)
Z = rng.normal(size=(n, d_z))
D = (rng.normal(size=n) > 0).astype(float)
Y = 2.0 * D + Z[:, 0] + rng.normal(size=n)

X = np.column_stack([D, Z])

# Base basis on Z (or on all of X if you prefer).
psi = PolynomialBasis(degree=2)

# ATE-friendly basis: interact the base basis with treatment.
phi = TreatmentInteractionBasis(base_basis=psi)

# Unnormalized KL generator with a branch function:
#   + branch for treated (D=1), - branch for control (D=0)
gen = UKLGenerator(C=1.0, branch_fn=lambda x: int(x[0] == 1.0)).as_generator()

res = grr_ate(
    X=X,
    Y=Y,
    basis=phi,
    generator=gen,
    cross_fit=True,
    folds=5,
    riesz_penalty="l2",
    riesz_lam=1e-3,
    estimators=("ra", "rw", "arw", "tmle"),
)

print(res.summary_text())
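
For intuition about what cross_fit=True and the reported SE / CI refer to, the following is a from-scratch sketch of a cross-fitted augmented estimate with a plug-in standard error. It is not the library's code: it assumes squared-loss Riesz regression and a ridge outcome model on the same illustrative basis, and the names `phi` and `cross_fit_arw` are hypothetical.

```python
import numpy as np


def phi(D, Z):
    """Illustrative treatment-interacted basis: [D * (1, Z), (1 - D) * (1, Z)]."""
    base = np.column_stack([np.ones(len(Z)), Z])
    return np.column_stack([D[:, None] * base, (1.0 - D)[:, None] * base])


def cross_fit_arw(D, Z, Y, folds=5, lam=1e-3, seed=0):
    """Fit nuisances on the training folds, evaluate the score on the held-out fold."""
    n = len(Y)
    fold_id = np.random.default_rng(seed).permutation(n) % folds
    psi = np.empty(n)
    for k in range(folds):
        tr, te = fold_id != k, fold_id == k
        n_tr, n_te = int(tr.sum()), int(te.sum())
        Phi_tr = phi(D[tr], Z[tr])
        gram = Phi_tr.T @ Phi_tr / n_tr + lam * np.eye(Phi_tr.shape[1])

        # Riesz representer coefficients (squared loss, closed form).
        m_phi_tr = phi(np.ones(n_tr), Z[tr]) - phi(np.zeros(n_tr), Z[tr])
        beta = np.linalg.solve(gram, m_phi_tr.mean(axis=0))

        # Outcome model coefficients (ridge regression of Y on the basis).
        theta = np.linalg.solve(gram, Phi_tr.T @ Y[tr] / n_tr)

        # Held-out score: m(X, gamma_hat) + alpha_hat(X) * (Y - gamma_hat(X)).
        Phi_te = phi(D[te], Z[te])
        m_gamma = (phi(np.ones(n_te), Z[te]) - phi(np.zeros(n_te), Z[te])) @ theta
        psi[te] = m_gamma + (Phi_te @ beta) * (Y[te] - Phi_te @ theta)

    est = psi.mean()
    se = psi.std(ddof=1) / np.sqrt(n)
    return est, se, (est - 1.96 * se, est + 1.96 * se)


rng = np.random.default_rng(0)
n = 2000
Z = rng.normal(size=(n, 2))
D = (rng.uniform(size=n) < 1.0 / (1.0 + np.exp(-0.5 * Z[:, 0]))).astype(float)
Y = 2.0 * D + Z[:, 0] + rng.normal(size=n)      # true ATE is 2.0

est, se, ci = cross_fit_arw(D, Z, Y)
print(f"ARW-style estimate {est:.3f}, SE {se:.3f}, 95% CI ({ci[0]:.3f}, {ci[1]:.3f})")
```

Each observation's score uses nuisance fits that never saw that observation, which is the point of cross-fitting; the standard error is the sample standard deviation of the scores divided by √n.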

Choosing a Bregman generator (Table 1 from the paper)

The generator g determines the GRR objective, and (through the induced link) the shape of the fitted representer / weights. The paper’s Table 1 summarizes how common choices relate to well-known density-ratio estimation and Riesz representer / balancing-weight methods.

Note on citations: GitHub README rendering does not resolve LaTeX bibliography commands like \citep{...}. The table below uses clickable author–year links. For full bibliography entries (author lists, venues), see CITATIONS.md.

Convention: C ∈ ℝ is a problem-dependent constant; ω ∈ (0, ∞).

Table 1 — Correspondence among Bregman generators, density-ratio estimation, and Riesz representer estimation

| Bregman generator $g(\alpha)$ | Density-ratio (DR) estimation view | Riesz representer (RR) estimation view |
| --- | --- | --- |
| $(\alpha - C)^2$ | LSIF (Kanamori et al., 2009) / KuLSIF (Kanamori et al., 2012) | SQ-Riesz regression (this library); RieszNet / ForestRiesz (Chernozhukov et al., 2022); RieszBoost (Lee & Schuler, 2025); KRRR (Singh, 2021); nearest-neighbor matching (Lin et al., 2023); causal tree / causal forest (Wager & Athey, 2018) |
| Dual solution (linear link) | Kernel mean matching (Gretton et al., 2009) | Sieve Riesz representer (Chen & Christensen, 2015); stable balancing weights (Zubizarreta, 2015; Bruns-Smith et al., 2025); approximate residual balancing (Athey et al., 2018); covariate balancing by SVM (Tarr & Imai, 2025) |
| $(\lvert\alpha\rvert - C)\log(\lvert\alpha\rvert - C) - \lvert\alpha\rvert$ | UKL divergence minimization (Nguyen et al., 2010) | UKL-Riesz regression (this library); tailored loss minimization ($\alpha=\beta=-1$; Zhao, 2019); calibrated estimation (Tan, 2020) |
| Dual solution (logistic / log link) | KLIEP (Sugiyama et al., 2008) | Entropy balancing weights (Hainmueller, 2012) |
| $(\lvert\alpha\rvert - C)\log(\lvert\alpha\rvert - C) - (\lvert\alpha\rvert + C)\log(\lvert\alpha\rvert + C)$ | BKL divergence minimization (Qin, 1998); TRE (Rhodes et al., 2020) | BKL-Riesz regression (this library); logistic MLE propensity-score fit (standard approach); tailored loss minimization ($\alpha=\beta=0$; Zhao, 2019) |
| $\frac{(\lvert\alpha\rvert - C)^{1+\omega} - (\lvert\alpha\rvert - C)}{\omega} - (\lvert\alpha\rvert - C)$, $\omega>0$ | Basu's Power (BP) divergence minimization (Sugiyama et al., 2012) | BP-Riesz regression (this library) |
| $C\log(1-\lvert\alpha\rvert) + C\lvert\alpha\rvert\bigl(\log\lvert\alpha\rvert - \log(1-\lvert\alpha\rvert)\bigr)$, $\alpha\in(0,1)$ | PU learning / nonnegative PU learning (du Plessis et al., 2015; Kiryo et al., 2017) | PU-Riesz regression (this library) |
| General Bregman divergence minimization | Density-ratio matching (Sugiyama et al., 2012); D3RE (Kato & Teshima, 2021) | Generalized Riesz regression (this library via custom BregmanGenerator) |
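
The "Dual solution" rows name the link that a generator induces. As a rough illustration, and assuming the link is the inverse derivative $(g')^{-1}$ applied to a linear index $v = \phi(x)^\top\beta$, the sketch below works out the squared and UKL cases by hand on the positive branch ($\alpha > C$) only. The function names are illustrative and are not genriesz API.

```python
import math

C = 1.0


def sq_grad(alpha):
    # g(a) = (a - C)^2  =>  g'(a) = 2 (a - C)
    return 2.0 * (alpha - C)


def sq_link(v):
    # (g')^{-1}(v) = C + v / 2: a linear link, so fitted values can be negative
    return C + 0.5 * v


def ukl_grad(alpha):
    # g(a) = (a - C) log(a - C) - a for a > C  =>  g'(a) = log(a - C)
    return math.log(alpha - C)


def ukl_link(v):
    # (g')^{-1}(v) = C + exp(v): a log link, so fitted values always exceed C
    return C + math.exp(v)


for v in (-3.0, 0.0, 2.0):
    print(f"v={v:+.1f}  squared link -> {sq_link(v):+.3f}   UKL link -> {ukl_link(v):+.3f}")
```

With C = 1 the UKL link keeps every fitted value above 1, which matches the range of inverse-propensity weights 1/p for treated units, whereas the linear link places no sign or range restriction on the fitted representer.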

Built-in generator classes

For most use cases you can start from one of the built-ins:

  • SquaredGenerator → squared distance / "SQ-Riesz"
  • UKLGenerator → unnormalized KL divergence / "UKL-Riesz"
  • BKLGenerator → binary KL divergence / "BKL-Riesz"
  • BPGenerator → Basu's power divergence / "BP-Riesz"
  • PUGenerator → bounded-weights generator / "PU-Riesz"
  • BregmanGenerator → bring your own g, optionally with grad and inv_grad

General API: grr_functional

grr_functional is the most general entry point.

You provide:

  • m(x_row, gamma) — the estimand (a linear functional),
  • a basis basis(X) — feature map returning an (n, p) design matrix,
  • a Bregman generator.

m can be either:

  • a built-in LinearFunctional (recommended), or
  • a plain Python callable (wrapped as CallableFunctional).

Example skeleton:

import numpy as np
from genriesz import grr_functional, BregmanGenerator

def m(x, gamma):
    # x is a single row (1D array)
    # gamma is a callable gamma(x_row)
    return gamma(x)

def g(x, alpha):
    # x is a single row; alpha is a scalar
    return 0.5 * alpha**2

def basis(X):
    # X is (n,d) -> (n,p)
    return np.c_[np.ones(len(X)), X]

X = np.random.randn(200, 3)
Y = np.random.randn(200)

generator = BregmanGenerator(g=g)  # grad/inv_grad can be derived numerically if omitted

res = grr_functional(
    X=X,
    Y=Y,
    m=m,
    basis=basis,
    generator=generator,
    estimators=("rw",),
)

print(res.summary_text())

Notes:

  • In the example above, m is a plain callable. Internally, grr_functional wraps it as CallableFunctional.
  • The callable must be linear in the function argument gamma. If you need performance or advanced control, implement a custom subclass of LinearFunctional instead.
  • If you want a custom name in the summary output, wrap explicitly: m = CallableFunctional(m, name="MyEstimand").
  • Bernoulli TMLE (outcome_link="logit") is implemented for the built-in treatment-type functionals (ATE/ATT/DID). If you represent those estimands via a custom callable m, prefer the built-in wrappers (e.g. grr_ate).
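
Because the linearity requirement is easy to violate by accident, a quick numerical check can help before passing a callable to the library. The sketch below is a standalone helper, not part of genriesz: `looks_linear` and the two example functionals are hypothetical names, and rows are assumed to be laid out as x = [D, Z...].

```python
import numpy as np


def m_ate(x, gamma):
    """ATE-style functional: gamma evaluated at D=1 minus gamma evaluated at D=0."""
    x1, x0 = x.copy(), x.copy()
    x1[0], x0[0] = 1.0, 0.0
    return gamma(x1) - gamma(x0)


def m_square(x, gamma):
    """Deliberately NOT linear in gamma, for contrast."""
    return gamma(x) ** 2


def looks_linear(m, x, g1, g2, a=1.7, b=-0.4, tol=1e-10):
    """Check m(x, a*g1 + b*g2) == a*m(x, g1) + b*m(x, g2) at a single row x."""
    def combo(row):
        return a * g1(row) + b * g2(row)
    return bool(abs(m(x, combo) - (a * m(x, g1) + b * m(x, g2))) < tol)


x = np.array([1.0, 0.3, -1.2])
g1 = lambda row: row[0] * (1.0 + row[1])
g2 = lambda row: np.sin(row[2]) + row[0]

print(looks_linear(m_ate, x, g1, g2))      # True
print(looks_linear(m_square, x, g1, g2))   # False
```

Passing this check at a few rows and for a few function pairs does not prove linearity, but a failure is a reliable sign that the callable is not a valid linear functional.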

Providing $g'$ and $(g')^{-1}$

If you can implement the derivative grad(W_i, alpha) and its inverse inv_grad(W_i, v) analytically (W_i is a single regressor row, written x in the examples above), pass them to:

BregmanGenerator(g=..., grad=..., inv_grad=...)

If you omit them, the library falls back to:

  • finite differences for $g'$, and
  • scalar root-finding for $(g')^{-1}$.
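
The text does not say which finite-difference scheme or root-finder is used, so the following is only a generic sketch of the idea, assuming central differences and bisection on a bracket where $g'$ is increasing (that is, $g$ is convex). The helper names are illustrative, and the generator is the UKL form from Table 1 on its positive branch.

```python
import math


def num_grad(g, alpha, h=1e-6):
    """Central-difference approximation to g'(alpha)."""
    return (g(alpha + h) - g(alpha - h)) / (2.0 * h)


def num_inv_grad(g, v, lo, hi, tol=1e-10, max_iter=200):
    """Solve g'(alpha) = v by bisection on [lo, hi], assuming g' is increasing there."""
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        if num_grad(g, mid) < v:
            lo = mid
        else:
            hi = mid
        if hi - lo < tol:
            break
    return 0.5 * (lo + hi)


C = 1.0


def g_ukl(alpha):
    # UKL generator on the branch alpha > C; analytically g'(alpha) = log(alpha - C)
    return (alpha - C) * math.log(alpha - C) - alpha


alpha = 2.5
v = num_grad(g_ukl, alpha)
alpha_back = num_inv_grad(g_ukl, v, lo=C + 1e-3, hi=50.0)

print(v, math.log(alpha - C))   # numerical vs analytic derivative
print(alpha_back)               # should recover alpha up to numerical error
```

Each call to the numerical inverse costs dozens of generator evaluations, which is one practical reason to supply analytic grad and inv_grad when they are available.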

Built-in estimands

The following convenience wrappers are included:

  • ATE (average treatment effect): grr_ate / ATEFunctional(...)
  • ATT (average treatment effect on the treated): grr_att / ATTFunctional(...)
  • DID (panel DID as ATT on ΔY): grr_did / DIDFunctional(...)
  • AME (average marginal effect / average derivative): grr_ame / AMEFunctional(...)

For covariate-shift density ratio estimation via generalized Bregman divergences, see fit_density_ratio.


Basis functions

Polynomial basis

from genriesz import PolynomialBasis

psi = PolynomialBasis(degree=3)
Phi = psi(X)  # (n,p)

RKHS-style bases

Approximate an RBF kernel with either random Fourier features or a Nyström basis:

from genriesz import RBFRandomFourierBasis, RBFNystromBasis

rff = RBFRandomFourierBasis(n_features=500, sigma=1.0, standardize=True, random_state=0)
Phi_rff = rff(X)

nys = RBFNystromBasis(n_centers=500, sigma=1.0, standardize=True, random_state=0)
Phi_nys = nys(X)
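
As background on what a random Fourier basis computes, here is a standalone NumPy sketch of the classical construction. It assumes the kernel is parametrized as exp(−‖x − y‖² / (2σ²)) and ignores the standardize option; whether RBFRandomFourierBasis uses exactly this convention is an assumption, and `rff_features` is an illustrative name.

```python
import numpy as np


def rff_features(X, n_features, sigma, rng):
    """Random Fourier features z(x) = sqrt(2 / m) * cos(W'x + b)."""
    d = X.shape[1]
    W = rng.normal(scale=1.0 / sigma, size=(d, n_features))   # frequencies ~ N(0, I / sigma^2)
    b = rng.uniform(0.0, 2.0 * np.pi, size=n_features)        # random phases
    return np.sqrt(2.0 / n_features) * np.cos(X @ W + b)


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
sigma = 1.0
Phi = rff_features(X, n_features=5000, sigma=sigma, rng=rng)

# Inner products of the features approximate the exact RBF kernel matrix.
sq_dist = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
K_exact = np.exp(-sq_dist / (2.0 * sigma**2))
K_approx = Phi @ Phi.T
max_err = np.abs(K_exact - K_approx).max()
print(f"max abs kernel error: {max_err:.3f}")
```

The approximation error shrinks at roughly the rate 1/√n_features, so the number of features trades accuracy against the width p of the design matrix.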

Nearest-neighbor matching (kNN catchment-area basis)

Nearest-neighbor matching can be expressed using a catchment-area indicator basis

$\phi_j(z) = \mathbf{1}\{c_j \in \mathrm{NN}_k(z)\}$,

where $\{c_j\}$ are centers and $\mathrm{NN}_k(z)$ is the set of the $k$ nearest centers of $z$.

from genriesz import KNNCatchmentBasis

centers = X[:200]                  # example
queries = X[200:]                  # example

basis = KNNCatchmentBasis(n_neighbors=3).fit(centers)
Phi = basis(queries)               # dense (n_queries, n_centers)

See examples/ate_synthetic_nn_matching.py for an end-to-end matching-style ATE estimate.
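
The indicator basis defined above can be reproduced with a brute-force NumPy computation. This is a sketch for understanding the definition, assuming Euclidean distance and no ties; it is not the implementation behind KNNCatchmentBasis, and `knn_catchment` is an illustrative name.

```python
import numpy as np


def knn_catchment(queries, centers, k):
    """Phi[i, j] = 1 if center j is among the k nearest centers of query i, else 0."""
    d2 = ((queries[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    nearest = np.argpartition(d2, k - 1, axis=1)[:, :k]   # indices of the k smallest distances
    Phi = np.zeros((len(queries), len(centers)))
    np.put_along_axis(Phi, nearest, 1.0, axis=1)
    return Phi


rng = np.random.default_rng(0)
centers = rng.normal(size=(20, 2))
queries = rng.normal(size=(5, 2))

Phi = knn_catchment(queries, centers, k=3)
print(Phi.sum(axis=1))   # every query selects exactly k centers
print(Phi.sum(axis=0))   # how many queries fall in each center's catchment area
```

The column sums count how often each center is used as a match, which is the quantity that matching estimators turn into weights.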

Random forest leaves (scikit-learn)

You can use a random forest as a feature map by encoding leaf indices:

from sklearn.ensemble import RandomForestRegressor
from genriesz.sklearn_basis import RandomForestLeafBasis

rf = RandomForestRegressor(n_estimators=200, random_state=0)
leaf_basis = RandomForestLeafBasis(rf)
Phi_rf = leaf_basis(X)

Neural network features (PyTorch)

If you have PyTorch installed, you can use a neural network as a fixed feature map.

See src/genriesz/torch_basis.py for a minimal wrapper.


Jupyter notebooks

End-to-end notebooks with runnable examples are provided at:

  • notebooks/ATE_example.ipynb
  • notebooks/ATT_example.ipynb
  • notebooks/AME_example.ipynb
  • notebooks/DID_example.ipynb

References

If you use genriesz in academic work, please cite:

For full bibliography entries (author lists, venues), see CITATIONS.md.


License

GNU General Public License v3.0 (GPL-3.0).
