Generalized Riesz Regression (GRR) utilities, including nearest-neighbor matching as LSIF/Riesz regression.
genriesz — Generalized Riesz Regression (GRR)
A Python library for Generalized Riesz Regression (GRR) under Bregman divergences — a unified way to fit Riesz representers with automatic regressor balancing (ARB) and then report RA / RW / ARW estimates with inference (optionally via cross-fitting).
- Docs: https://genriesz.readthedocs.io/en/latest/
- Paper: A Unified Framework for Debiased Machine Learning: Riesz Representer Fitting under Bregman Divergence (arXiv:2601.07752)
Contents
- Installation
- Core idea
- Quickstart: ATE (Average Treatment Effect)
- Choosing a Bregman generator (Table 1 from the paper)
- General API: grr_functional
- Built-in estimands
- Basis functions
- Jupyter notebook
- References
- License
Installation
Python >= 3.10.
From PyPI:
```bash
pip install genriesz
```
Optional extras:
```bash
# scikit-learn integrations (tree-based feature maps)
pip install "genriesz[sklearn]"
# PyTorch integrations (neural-network feature maps)
pip install "genriesz[torch]"
```
From a local checkout (editable install):
```bash
python -m pip install -U pip
pip install -e .
```
Core idea
You specify:
- an estimand / linear functional `m(X, γ)`,
- a feature map / basis `φ(X)`,
- a Bregman generator `g(X, α)` (or one of the built-in generator classes),

and the library will:
- build the ARB link function induced by `g`,
- fit a Riesz representer `α̂(X)` via GRR,
- optionally fit an outcome model `γ̂(X)` (for RA / ARW / TMLE),
- return RA / RW / ARW / TMLE point estimates and inference (SE / CI / p-value), optionally with cross-fitting.
Notation in this library: the regressor is `X` (shape `(n, d)`) and the outcome is `Y` (shape `(n,)`). If you prefer the paper's notation, you can think of `X` as the full regressor vector (often `X = [D, Z]`).
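To make the three estimator types concrete, here is a minimal NumPy sketch for the ATE that does not use genriesz at all. It assumes that RA, RW, and ARW stand for regression adjustment, Riesz weighting, and augmented Riesz weighting, and it plugs in the true (oracle) outcome model and Riesz representer instead of fitted ones. The data-generating process and all variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000
Z = rng.normal(size=n)
e = 1.0 / (1.0 + np.exp(-Z))           # true propensity P(D = 1 | Z)
D = (rng.uniform(size=n) < e).astype(float)
Y = 2.0 * D + Z + rng.normal(size=n)   # true ATE is 2

def gamma(d, z):
    # Oracle outcome regression E[Y | D = d, Z = z] (assumed known here).
    return 2.0 * d + z

# Oracle Riesz representer of the ATE functional: inverse-propensity weights.
alpha = D / e - (1.0 - D) / (1.0 - e)

# m(X, gamma) for the ATE: gamma with treatment on minus gamma with treatment off.
m_gamma = gamma(1.0, Z) - gamma(0.0, Z)

ra = np.mean(m_gamma)                                 # regression adjustment
rw = np.mean(alpha * Y)                               # Riesz weighting
arw = np.mean(m_gamma + alpha * (Y - gamma(D, Z)))    # augmented Riesz weighting

print(f"RA={ra:.3f}  RW={rw:.3f}  ARW={arw:.3f}")
```

In practice both `gamma` and `alpha` are estimated; the library's role, per the description above, is to fit `α̂(X)` via GRR rather than by inverting an estimated propensity score.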
Quickstart: ATE (Average Treatment Effect)
The ATE is available as a convenient wrapper grr_ate.
```python
import numpy as np
from genriesz import (
    grr_ate,
    UKLGenerator,
    PolynomialBasis,
    TreatmentInteractionBasis,
)

# Example layout: X = [D, Z]
#   D: treatment (0/1)
#   Z: covariates
n, d_z = 1000, 5
rng = np.random.default_rng(0)
Z = rng.normal(size=(n, d_z))
D = (rng.normal(size=n) > 0).astype(float)
Y = 2.0 * D + Z[:, 0] + rng.normal(size=n)
X = np.column_stack([D, Z])

# Base basis on Z (or on all of X if you prefer).
psi = PolynomialBasis(degree=2)

# ATE-friendly basis: interact the base basis with treatment.
phi = TreatmentInteractionBasis(base_basis=psi)

# Unnormalized KL generator with a branch function:
#   + branch for treated (D=1), - branch for control (D=0)
gen = UKLGenerator(C=1.0, branch_fn=lambda x: int(x[0] == 1.0)).as_generator()

res = grr_ate(
    X=X,
    Y=Y,
    basis=phi,
    generator=gen,
    cross_fit=True,
    folds=5,
    riesz_penalty="l2",
    riesz_lam=1e-3,
    estimators=("ra", "rw", "arw", "tmle"),
)
print(res.summary_text())
```
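The `cross_fit=True, folds=5` arguments request cross-fitting. As a rough illustration of the general idea only (this is not the library's implementation), the NumPy sketch below fits a linear outcome model on four folds, evaluates the augmented score on the held-out fold, and averages. It assumes the same simulated design as the quickstart and uses the known propensity of 0.5 for the representer; the `design` helper is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_z, K = 1000, 5, 5
Z = rng.normal(size=(n, d_z))
D = (rng.normal(size=n) > 0).astype(float)
Y = 2.0 * D + Z[:, 0] + rng.normal(size=n)

def design(d, z):
    # Hypothetical helper: linear design matrix [1, D, Z].
    return np.column_stack([np.ones(len(d)), d, z])

# Representer under the known propensity e = 0.5 (true by construction here).
alpha = D / 0.5 - (1.0 - D) / 0.5

folds = np.array_split(rng.permutation(n), K)
scores = np.empty(n)
for test_idx in folds:
    train_idx = np.setdiff1d(np.arange(n), test_idx)
    # Fit the outcome model on the training folds only.
    beta, *_ = np.linalg.lstsq(
        design(D[train_idx], Z[train_idx]), Y[train_idx], rcond=None
    )
    # Evaluate the augmented score on the held-out fold.
    g1 = design(np.ones(len(test_idx)), Z[test_idx]) @ beta
    g0 = design(np.zeros(len(test_idx)), Z[test_idx]) @ beta
    gd = design(D[test_idx], Z[test_idx]) @ beta
    scores[test_idx] = (g1 - g0) + alpha[test_idx] * (Y[test_idx] - gd)

ate_hat = scores.mean()
se_hat = scores.std(ddof=1) / np.sqrt(n)
print(f"ATE estimate {ate_hat:.3f} (SE {se_hat:.3f})")
```

Each observation's score uses nuisance estimates that were fitted without that observation, which is the property cross-fitting is meant to provide.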
Choosing a Bregman generator (Table 1 from the paper)
The generator g determines the GRR objective, and (through the induced link) the shape of the fitted representer / weights.
The paper’s Table 1 summarizes how common choices relate to well-known density-ratio estimation and Riesz representer / balancing-weight methods.
Note on citations: GitHub README rendering does not resolve LaTeX bibliography commands like `\citep{...}`. The table below uses clickable author–year links. For full bibliography entries (author lists, venues), see CITATIONS.md.
Convention: `C ∈ ℝ` is a problem-dependent constant; `ω ∈ (0, ∞)`.
Table 1 — Correspondence among Bregman generators, density-ratio estimation, and Riesz representer estimation
| Bregman generator $g(\alpha)$ | Density-ratio (DR) estimation view | Riesz representer (RR) estimation view |
|---|---|---|
| $(\alpha - C)^2$ | LSIF (Kanamori et al., 2009) / KuLSIF (Kanamori et al., 2012) | SQ-Riesz regression (this library); RieszNet / ForestRiesz (Chernozhukov et al., 2022); RieszBoost (Lee & Schuler, 2025); KRRR (Singh, 2021); nearest-neighbor matching (Lin et al., 2023); causal tree / causal forest (Wager & Athey, 2018) |
| Dual solution (linear link) | Kernel mean matching (Gretton et al., 2009) | Sieve Riesz representer (Chen & Christensen, 2015); stable balancing weights (Zubizarreta, 2015; Bruns-Smith et al., 2025); approximate residual balancing (Athey et al., 2018); covariate balancing by SVM (Tarr & Imai, 2025) |
| $(\lvert\alpha\rvert - C)\log(\lvert\alpha\rvert - C) - \lvert\alpha\rvert$ | UKL divergence minimization (Nguyen et al., 2010) | UKL-Riesz regression (this library); tailored loss minimization ($\alpha=\beta=-1$; Zhao, 2019); calibrated estimation (Tan, 2020) |
| Dual solution (logistic / log link) | KLIEP (Sugiyama et al., 2008) | Entropy balancing weights (Hainmueller, 2012) |
| $(\lvert\alpha\rvert - C)\log(\lvert\alpha\rvert - C) - (\lvert\alpha\rvert + C)\log(\lvert\alpha\rvert + C)$ | BKL divergence minimization (Qin, 1998); TRE (Rhodes et al., 2020) | BKL-Riesz regression (this library); logistic MLE propensity-score fit (standard approach); tailored loss minimization ($\alpha=\beta=0$; Zhao, 2019) |
| $\frac{(\lvert\alpha\rvert - C)^{1+\omega} - (\lvert\alpha\rvert - C)}{\omega} - (\lvert\alpha\rvert - C)$, $\omega>0$ | Basu's Power (BP) divergence minimization (Sugiyama et al., 2012) | BP-Riesz regression (this library) |
| $C\log(1-\lvert\alpha\rvert) + C\lvert\alpha\rvert\bigl(\log\lvert\alpha\rvert - \log(1-\lvert\alpha\rvert)\bigr)$, $\alpha\in(0,1)$ | PU learning / nonnegative PU learning (du Plessis et al., 2015; Kiryo et al., 2017) | PU-Riesz regression (this library) |
| General Bregman divergence minimization | Density-ratio matching (Sugiyama et al., 2012); D3RE (Kato & Teshima, 2021) | Generalized Riesz regression (this library via custom BregmanGenerator) |
Full bibliography: see CITATIONS.md.
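As a small illustration of how a generator induces a divergence, the sketch below evaluates the pointwise Bregman divergence `g(a) - g(b) - g'(b)(a - b)` for two simplified generators: the squared generator and the UKL-type generator, both with `C = 0` and positive arguments. These are special cases chosen for readability, not the library's generator classes, and the function names are hypothetical.

```python
import math

def bregman(g, grad, a, b):
    """Pointwise Bregman divergence g(a) - g(b) - g'(b) * (a - b)."""
    return g(a) - g(b) - grad(b) * (a - b)

# Squared generator with C = 0: g(alpha) = alpha^2, so g'(alpha) = 2 * alpha.
def sq(a):
    return a ** 2

def sq_grad(a):
    return 2.0 * a

# UKL-type generator with C = 0 and alpha > 0:
# g(alpha) = alpha * log(alpha) - alpha, so g'(alpha) = log(alpha).
def ukl(a):
    return a * math.log(a) - a

def ukl_grad(a):
    return math.log(a)

a, b = 2.0, 0.5
d_sq = bregman(sq, sq_grad, a, b)      # reduces to (a - b)^2
d_ukl = bregman(ukl, ukl_grad, a, b)   # reduces to a*log(a/b) - a + b

print(d_sq, d_ukl)
```

The squared generator has a linear derivative and the UKL-type generator a logarithmic one, which matches the "linear link" and "log link" rows of the table.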
Built-in generator classes
For most use-cases you can start from one of the built-ins:
- `SquaredGenerator` → squared distance / "SQ-Riesz"
- `UKLGenerator` → unnormalized KL divergence / "UKL-Riesz"
- `BKLGenerator` → binary KL divergence / "BKL-Riesz"
- `BPGenerator` → Basu's power divergence / "BP-Riesz"
- `PUGenerator` → bounded-weights generator / "PU-Riesz"
- `BregmanGenerator` → bring your own `g`, optionally with `grad` and `inv_grad`
General API: grr_functional
grr_functional is the most general entry point.
You provide:
- `m(x_row, gamma)`: the estimand (a linear functional),
- a basis `basis(X)`: a feature map returning an `(n, p)` design matrix,
- a Bregman `generator`.

`m` can be either:
- a built-in `LinearFunctional` (recommended), or
- a plain Python callable (wrapped as `CallableFunctional`).
Example skeleton:
```python
import numpy as np
from genriesz import grr_functional, BregmanGenerator

def m(x, gamma):
    # x is a single row (1D array)
    # gamma is a callable gamma(x_row)
    return gamma(x)

def g(x, alpha):
    # x is a single row; alpha is a scalar
    return 0.5 * alpha**2

def basis(X):
    # X is (n, d) -> (n, p)
    return np.c_[np.ones(len(X)), X]

X = np.random.randn(200, 3)
Y = np.random.randn(200)

generator = BregmanGenerator(g=g)  # grad/inv_grad can be derived numerically if omitted

res = grr_functional(
    X=X,
    Y=Y,
    m=m,
    basis=basis,
    generator=generator,
    estimators=("rw",),
)
print(res.summary_text())
```
Notes:
- In the example above, `m` is a plain callable. Internally, `grr_functional` wraps it as `CallableFunctional`.
- The callable must be linear in the function argument `gamma`. If you need performance or advanced control, implement a custom subclass of `LinearFunctional` instead.
- If you want a custom name in the summary output, wrap explicitly: `m = CallableFunctional(m, name="MyEstimand")`.
- Bernoulli TMLE (`outcome_link="logit"`) is implemented for the built-in treatment-type functionals (ATE/ATT/DID). If you represent those estimands via a custom callable `m`, prefer the built-in wrappers (e.g. `grr_ate`).
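The linearity requirement can be checked numerically before handing a callable to the library. The sketch below writes the ATE functional as a plain callable, assuming the `X = [D, Z]` layout with treatment in column 0, and verifies that it commutes with a linear combination of two test functions. A squared functional is included as a counter-example. All function names here are hypothetical and nothing in this block calls genriesz.

```python
import numpy as np

def m_ate(x, gamma):
    # x = [D, Z...]: evaluate gamma with treatment switched on and off.
    x1 = np.array(x, dtype=float)
    x1[0] = 1.0
    x0 = np.array(x, dtype=float)
    x0[0] = 0.0
    return gamma(x1) - gamma(x0)

def m_squared(x, gamma):
    # Counter-example: NOT linear in gamma.
    return gamma(x) ** 2

def gamma_a(x):
    return 2.0 * x[0] + x[1]

def gamma_b(x):
    return x[0] * x[1] ** 2 - 3.0

def combo(x):
    # A fixed linear combination of the two test functions.
    return 1.5 * gamma_a(x) - 0.5 * gamma_b(x)

x = np.array([0.0, 0.7])

# Linear functional: m(x, 1.5*a - 0.5*b) == 1.5*m(x, a) - 0.5*m(x, b).
lhs = m_ate(x, combo)
rhs = 1.5 * m_ate(x, gamma_a) - 0.5 * m_ate(x, gamma_b)
is_linear_ate = bool(np.isclose(lhs, rhs))

# The squared functional fails the same check.
lhs_bad = m_squared(x, combo)
rhs_bad = 1.5 * m_squared(x, gamma_a) - 0.5 * m_squared(x, gamma_b)
is_linear_squared = bool(np.isclose(lhs_bad, rhs_bad))

print(is_linear_ate, is_linear_squared)
```

Passing the check on a few test functions does not prove linearity, but failing it is a reliable sign that the callable is not a valid estimand.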
Providing $g'$ and $(g')^{-1}$
If you can implement the derivative `grad(W_i, alpha)` and inverse-derivative `inv_grad(W_i, v)` analytically, pass them to:
```python
BregmanGenerator(g=..., grad=..., inv_grad=...)
```
If you omit them, the library falls back to:
- finite differences for $g'$, and
- scalar root-finding for $(g')^{-1}$.
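To show what the analytic and numerical routes compute, the sketch below uses the simplified generator `g(x, alpha) = alpha*log(alpha) - alpha` (for `alpha > 0`), whose derivative is `log(alpha)` and whose inverse derivative is `exp(v)`. The numerical versions use a central difference and a bisection search. The library's actual numerical scheme is not described here, so `numeric_grad` and `numeric_inv_grad` are hypothetical stand-ins, not genriesz functions.

```python
import math

def g(x, alpha):
    # Simplified UKL-type generator; x is unused in this example.
    return alpha * math.log(alpha) - alpha

def grad(x, alpha):
    # Analytic derivative g'(alpha) = log(alpha).
    return math.log(alpha)

def inv_grad(x, v):
    # Analytic inverse derivative (g')^{-1}(v) = exp(v).
    return math.exp(v)

def numeric_grad(g, x, alpha, rel_step=1e-6):
    # Central finite difference with a step relative to alpha (alpha > 0).
    h = rel_step * abs(alpha)
    return (g(x, alpha + h) - g(x, alpha - h)) / (2.0 * h)

def numeric_inv_grad(g, x, v, lo=1e-6, hi=1e6, iters=100):
    # Bisection on alpha -> g'(alpha) - v, which is increasing for convex g.
    # The bracket [lo, hi] is an assumption that suits this generator.
    for _ in range(iters):
        mid = math.sqrt(lo * hi)  # geometric midpoint keeps alpha positive
        if numeric_grad(g, x, mid) < v:
            lo = mid
        else:
            hi = mid
    return math.sqrt(lo * hi)

alpha0, v0 = 2.0, 0.3
print(grad(None, alpha0), numeric_grad(g, None, alpha0))
print(inv_grad(None, v0), numeric_inv_grad(g, None, v0))
```

The numerical route needs many evaluations of `g` per inversion, which is why supplying analytic `grad` and `inv_grad` is worthwhile when they are available.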
Built-in estimands
The following convenience wrappers are included:
- ATE (average treatment effect): `grr_ate` / `ATEFunctional(...)`
- ATT (average treatment effect on the treated): `grr_att` / `ATTFunctional(...)`
- DID (panel DID as ATT on ΔY): `grr_did` / `DIDFunctional(...)`
- AME (average marginal effect / average derivative): `grr_ame` / `AMEFunctional(...)`
For covariate-shift density ratio estimation via generalized Bregman divergences, see fit_density_ratio.
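For intuition about the squared-loss case of density-ratio estimation, here is a minimal NumPy sketch of least-squares importance fitting with a ridge penalty (the unconstrained LSIF form). It does not call `fit_density_ratio`, whose signature is not shown in this README. The quadratic basis, sample sizes, and penalty value are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
x_nu = rng.normal(loc=0.5, scale=1.0, size=5000)  # numerator (target) sample
x_de = rng.normal(loc=0.0, scale=1.0, size=5000)  # denominator (source) sample

def phi(x):
    # Assumed basis for this example: [1, x, x^2].
    return np.column_stack([np.ones_like(x), x, x ** 2])

# Minimize 0.5 * theta' H theta - h' theta + 0.5 * lam * theta' theta, where
# H averages phi phi' over the denominator sample and h averages phi over
# the numerator sample. The minimizer has a closed form.
H = phi(x_de).T @ phi(x_de) / len(x_de)
h = phi(x_nu).mean(axis=0)
lam = 1e-6
theta = np.linalg.solve(H + lam * np.eye(H.shape[0]), h)

r_hat = phi(x_de) @ theta
# True ratio N(0.5, 1) / N(0, 1), used only to check the fit.
r_true = np.exp(0.5 * x_de - 0.125)
corr = np.corrcoef(r_hat, r_true)[0, 1]

print(f"mean ratio {r_hat.mean():.3f}, correlation with truth {corr:.3f}")
```

Because the basis contains a constant, the fitted ratio averages to approximately 1 over the denominator sample, as a density ratio should.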
Basis functions
Polynomial basis
```python
from genriesz import PolynomialBasis

psi = PolynomialBasis(degree=3)
Phi = psi(X)  # (n, p)
```
RKHS-style bases
Approximate an RBF kernel with either random Fourier features or a Nyström basis:
```python
from genriesz import RBFRandomFourierBasis, RBFNystromBasis

rff = RBFRandomFourierBasis(n_features=500, sigma=1.0, standardize=True, random_state=0)
Phi_rff = rff(X)

nys = RBFNystromBasis(n_centers=500, sigma=1.0, standardize=True, random_state=0)
Phi_nys = nys(X)
```
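For readers unfamiliar with random Fourier features, the sketch below shows the standard construction in plain NumPy and checks that inner products of the features approximate the RBF kernel `exp(-||x - y||^2 / (2 * sigma^2))`. It is not necessarily how `RBFRandomFourierBasis` is implemented; in particular, the kernel parameterization by `sigma` and the helper name `rff_features` are assumptions.

```python
import numpy as np

def rff_features(X, n_features, sigma, rng):
    # Frequencies drawn from N(0, 1/sigma^2), phases from U[0, 2*pi).
    d = X.shape[1]
    W = rng.normal(scale=1.0 / sigma, size=(d, n_features))
    b = rng.uniform(0.0, 2.0 * np.pi, size=n_features)
    return np.sqrt(2.0 / n_features) * np.cos(X @ W + b)

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
sigma = 1.0
Phi = rff_features(X, n_features=4000, sigma=sigma, rng=rng)

# Exact RBF Gram matrix for comparison.
sq_dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
K_exact = np.exp(-sq_dists / (2.0 * sigma ** 2))

max_err = np.abs(Phi @ Phi.T - K_exact).max()
print(f"max abs kernel approximation error: {max_err:.3f}")
```

The approximation error shrinks at roughly the inverse square root of `n_features`, so the feature count trades accuracy against the width of the design matrix.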
Nearest-neighbor matching (kNN catchment-area basis)
Nearest-neighbor matching can be expressed using a catchment-area indicator basis
$\phi_j(z) = \mathbf{1}\{c_j \in \mathrm{NN}_k(z)\}$,
where $\{c_j\}$ are centers and $\mathrm{NN}_k(z)$ is the set of $k$ nearest centers of $z$.
```python
from genriesz import KNNCatchmentBasis

centers = X[:200]  # example
queries = X[200:]  # example
basis = KNNCatchmentBasis(n_neighbors=3).fit(centers)
Phi = basis(queries)  # dense (n_queries, n_centers)
```
See examples/ate_synthetic_nn_matching.py for an end-to-end matching-style ATE estimate.
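The indicator basis defined above can be written in a few lines of NumPy, which may help when reasoning about its shape and sparsity. This is an independent sketch using Euclidean distance, not the `KNNCatchmentBasis` implementation; the function name `knn_catchment` is hypothetical.

```python
import numpy as np

def knn_catchment(queries, centers, k):
    # Squared Euclidean distances, shape (n_queries, n_centers).
    d2 = ((queries[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    # Indices of the k nearest centers for each query (unordered within the k).
    nn_idx = np.argpartition(d2, k - 1, axis=1)[:, :k]
    # phi_j(z) = 1 if center j is among the k nearest centers of z, else 0.
    Phi = np.zeros(d2.shape)
    np.put_along_axis(Phi, nn_idx, 1.0, axis=1)
    return Phi

rng = np.random.default_rng(0)
centers = rng.normal(size=(50, 2))
queries = rng.normal(size=(8, 2))
Phi = knn_catchment(queries, centers, k=3)

print(Phi.shape)        # (n_queries, n_centers)
print(Phi.sum(axis=1))  # every row has exactly k ones
print(Phi.sum(axis=0))  # per-center count of queries that selected it
```

Each row contains exactly `k` ones, and each column sum counts how many queries fall in that center's catchment area.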
Random forest leaves (scikit-learn)
You can use a random forest as a feature map by encoding leaf indices:
```python
from sklearn.ensemble import RandomForestRegressor
from genriesz.sklearn_basis import RandomForestLeafBasis

rf = RandomForestRegressor(n_estimators=200, random_state=0)
leaf_basis = RandomForestLeafBasis(rf)
Phi_rf = leaf_basis(X)
```
Neural network features (PyTorch)
If you have PyTorch installed, you can use a neural network as a fixed feature map.
See src/genriesz/torch_basis.py for a minimal wrapper.
Jupyter notebook
End-to-end notebooks with runnable examples are provided at:
- notebooks/ATE_example.ipynb
- notebooks/ATT_example.ipynb
- notebooks/AME_example.ipynb
- notebooks/DID_example.ipynb
References
If you use genriesz in academic work, please cite:
- A Unified Framework for Debiased Machine Learning: Riesz Representer Fitting under Bregman Divergence (arXiv:2601.07752)
    - Consolidates earlier related drafts: arXiv:2509.22122, arXiv:2510.26783, arXiv:2510.23534.
    - BibTeX entry:

    ```bibtex
    @misc{Kato2026unifiedframework,
      title={A Unified Framework for Debiased Machine Learning: Riesz Representer Fitting under Bregman Divergence},
      author={Masahiro Kato},
      year={2026},
      note={{a}rXiv: 2601.07752},
    }
    ```
- Direct Bias-Correction Term Estimation for Propensity Scores and Average Treatment Effect Estimation (arXiv:2509.22122)
- Nearest Neighbor Matching as Least Squares Density Ratio Estimation and Riesz Regression (arXiv:2510.24433)
For full bibliography entries (author lists, venues), see CITATIONS.md.
License
GNU General Public License v3.0 (GPL-3.0).