Generalized Riesz Regression under Bregman divergences
genriesz — Generalized Riesz Regression (GRR)
This repository packages a Python library for Generalized Riesz Regression under Bregman divergences.
The key idea is:
- you specify a linear functional m(X, γ) (the estimand),
- you specify a basis φ(X),
- you specify a Bregman generator g(X, α),
and the library:
- builds the automatic-covariate-balancing (ACB) link function from g,
- fits a Riesz representer α̂(X) via GRR,
- optionally fits an outcome model γ̂(X),
- returns DM / IPW / AIPW estimates with standard errors, confidence intervals, and p-values (optionally with cross-fitting).
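The three estimators listed above can be written with the standard textbook formulas, given fitted nuisance values. This is a minimal NumPy sketch. It assumes you already have arrays for m(X_i, γ̂), α̂(X_i) and γ̂(X_i). The function name dm_ipw_aipw is hypothetical, it does no cross-fitting, and it is not the library's internal code. The usage part plugs in oracle nuisances for the ATE design used in the Quickstart (P(D=1|Z) = 0.5 and E[Y|D,Z] = 2D + Z_0), so all three estimates should land near 2.

```python
import numpy as np


def dm_ipw_aipw(m_hat, alpha_hat, gamma_hat, Y, z=1.96):
    """Textbook DM / IPW / AIPW estimates from fitted nuisance values.

    m_hat[i]     : m(X_i, gamma_hat), the functional applied to the outcome model
    alpha_hat[i] : alpha_hat(X_i), the fitted Riesz representer
    gamma_hat[i] : gamma_hat(X_i), the outcome model at the observed X_i
    """
    n = len(Y)
    dm = np.mean(m_hat)                         # direct method (plug-in)
    ipw = np.mean(alpha_hat * Y)                # Riesz-weighted outcome
    psi = m_hat + alpha_hat * (Y - gamma_hat)   # AIPW (debiased) score per unit
    aipw = np.mean(psi)
    se = np.std(psi, ddof=1) / np.sqrt(n)       # standard error of the mean score
    return {"dm": dm, "ipw": ipw, "aipw": aipw, "aipw_se": se,
            "aipw_ci": (aipw - z * se, aipw + z * se)}


# Usage with oracle nuisances for a Quickstart-style ATE design.
rng = np.random.default_rng(0)
n = 2000
Z = rng.normal(size=(n, 5))
D = (rng.normal(size=n) > 0).astype(float)      # P(D = 1 | Z) = 0.5
Y = 2.0 * D + Z[:, 0] + rng.normal(size=n)
alpha = D / 0.5 - (1.0 - D) / 0.5               # ATE Riesz representer
gamma = 2.0 * D + Z[:, 0]                       # E[Y | D, Z]
m_val = (2.0 + Z[:, 0]) - Z[:, 0]               # gamma(1, Z) - gamma(0, Z)
res = dm_ipw_aipw(m_val, alpha, gamma, Y)
print(res["dm"], res["ipw"], res["aipw"], res["aipw_ci"])
```

The AIPW interval uses the empirical variance of the per-unit score. This is a common choice, but it is shown here only for intuition.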
Notation: in this library the regressor is called X and the outcome is called Y. If you prefer the paper's notation, you can think of X as the full regressor vector (often X = [D, Z]).
Installation
From PyPI:
pip install genriesz
Optional extras:
# scikit-learn integrations (random forest leaf basis)
pip install "genriesz[sklearn]"
# PyTorch integrations (neural-network feature maps)
pip install "genriesz[torch]"
From a local checkout (editable install):
python -m pip install -U pip
pip install -e .
Quickstart: ATE (Average Treatment Effect)
The ATE can be estimated as a special case of grr_functional.
import numpy as np
from genriesz import (
grr_ate,
UKLGenerator,
PolynomialBasis,
TreatmentInteractionBasis,
)
# Example layout: X = [D, Z]
# D: treatment (0/1)
# Z: covariates
n, d_z = 1000, 5
rng = np.random.default_rng(0)
Z = rng.normal(size=(n, d_z))
D = (rng.normal(size=n) > 0).astype(float)
Y = 2.0 * D + Z[:, 0] + rng.normal(size=n)
X = np.column_stack([D, Z])
# Base basis on Z (or on all of X if you prefer).
psi = PolynomialBasis(degree=2)
# ATE-friendly basis: interact the base basis with treatment.
phi = TreatmentInteractionBasis(base_basis=psi)
# A common generator choice for ATE-style balancing.
# The branch function chooses the sign of alpha depending on the treatment.
# Here: positive for treated (D=1), negative for control (D=0).
gen = UKLGenerator(C=1.0, branch_fn=lambda x: int(x[0] == 1.0)).as_generator()
res = grr_ate(
X=X,
Y=Y,
basis=phi,
generator=gen,
cross_fit=True,
folds=5,
riesz_penalty="l2",
riesz_lam=1e-3,
estimators=("dm", "ipw", "aipw"),
)
print(res.summary_text())
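A common way to build an ATE-friendly basis is to give treated and control units separate copies of the base features, so that balancing is enforced within each arm. The sketch below assumes TreatmentInteractionBasis works roughly like this: it stacks D·ψ(Z) and (1 - D)·ψ(Z) with X = [D, Z]. The exact column layout in genriesz may differ. treatment_interaction and linear_with_bias are hypothetical helpers, not library functions.

```python
import numpy as np


def linear_with_bias(Z):
    # Hypothetical base basis psi(Z) = [1, Z].
    return np.column_stack([np.ones(len(Z)), Z])


def treatment_interaction(X, base):
    # Assumed layout X = [D, Z]; returns [D * psi(Z), (1 - D) * psi(Z)].
    D = X[:, [0]]            # keep shape (n, 1) so it broadcasts over columns
    P = base(X[:, 1:])       # base features on the covariates only
    return np.hstack([D * P, (1.0 - D) * P])


X = np.array([[1.0, 2.0, 3.0],
              [0.0, 4.0, 5.0]])
Phi = treatment_interaction(X, linear_with_bias)
print(Phi)
```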
General API: grr_functional
grr_functional is the most general entry point.
You provide:
- m(X, gamma): the estimand,
- basis(X): a feature map,
- a Bregman generator g(X, alpha) (or a pre-built generator).
Example skeleton:
import numpy as np
from genriesz import grr_functional, BregmanGenerator
def m(x, gamma):
    # x is a single row (1D array)
    # gamma is a callable gamma(w)
    # return a scalar
    return gamma(x)

def g(x, alpha):
    # x is a single row; alpha is a scalar
    # return g(x, alpha)
    return 0.5 * alpha**2

def basis(X):
    # X is (n, d); return (n, p)
    return np.c_[np.ones(len(X)), X]
X = np.random.randn(200, 3)
Y = np.random.randn(200)
generator = BregmanGenerator(g=g) # gradients/inverse-grad can be auto-derived numerically
res = grr_functional(
X=X,
Y=Y,
m=m,
basis=basis,
generator=generator,
estimators=("ipw",),
)
print(res.summary_text())
Providing g' and (g')^{-1}
If you can implement the derivative g_grad(X_i, alpha) and the inverse of the derivative, g_inv_grad(X_i, v), analytically, pass them to BregmanGenerator(g=..., grad=..., inv_grad=...).
If you omit them, the library falls back to:
- finite differences for g', and
- scalar root-finding for (g')^{-1}.
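The fallback described above can be illustrated with a minimal sketch: a central finite difference for g' and bisection for (g')^{-1}. The library's actual step size, bracketing and solver are not documented here, so h, lo, hi and the use of bisection are assumptions. The softplus generator is a hypothetical toy. It is chosen because it is strictly convex and has a known inverse derivative, (g')^{-1}(v) = log(v / (1 - v)), which makes the result easy to check.

```python
import math


def g(x, alpha):
    # Hypothetical toy generator: softplus in alpha, strictly convex; x is unused.
    return math.log1p(math.exp(alpha))


def grad_fd(g, x, alpha, h=1e-6):
    # Central finite difference for g'(x, alpha) with respect to alpha.
    return (g(x, alpha + h) - g(x, alpha - h)) / (2.0 * h)


def inv_grad_bisect(g, x, v, lo=-50.0, hi=50.0, tol=1e-10, max_iter=200):
    # Solve g'(x, alpha) = v for alpha by bisection. Assumes g is strictly
    # convex in alpha (so g' is increasing) and that [lo, hi] brackets the root.
    if grad_fd(g, x, lo) - v > 0 or grad_fd(g, x, hi) - v < 0:
        raise ValueError("v is not bracketed by g'(lo) and g'(hi)")
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        if grad_fd(g, x, mid) - v > 0:
            hi = mid   # g'(mid) too large: the root lies to the left
        else:
            lo = mid   # g'(mid) too small: the root lies to the right
        if hi - lo < tol:
            break
    return 0.5 * (lo + hi)


# For softplus, g'(alpha) = sigmoid(alpha), so (g')^{-1}(v) = log(v / (1 - v)).
x = None
print(grad_fd(g, x, 0.0))            # close to sigmoid(0) = 0.5
print(inv_grad_bisect(g, x, 0.25))   # close to log(1/3)
```

Bisection is used here because it only needs monotonicity of g', which strict convexity guarantees. A faster solver such as Newton's method would also need g''.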
Basis functions
Polynomial basis
from genriesz import PolynomialBasis
psi = PolynomialBasis(degree=3)
Phi = psi(X) # (n,p)
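For intuition about what a polynomial feature map produces, here is a minimal sketch that builds every monomial of the columns of X up to a given degree, plus a constant column. It is an assumption that PolynomialBasis includes exactly these columns in this order. It may, for example, omit the constant or order the terms differently. polynomial_features is a hypothetical helper.

```python
import itertools
import numpy as np


def polynomial_features(X, degree):
    # All monomials of the columns of X up to `degree`, plus a constant column.
    X = np.asarray(X, dtype=float)
    n, d = X.shape
    cols = [np.ones(n)]
    for deg in range(1, degree + 1):
        # Each multiset of column indices of size `deg` gives one monomial.
        for idx in itertools.combinations_with_replacement(range(d), deg):
            cols.append(np.prod(X[:, list(idx)], axis=1))
    return np.column_stack(cols)


X = np.array([[2.0, 3.0, 5.0]])
Phi = polynomial_features(X, degree=2)
# Columns: 1, x1, x2, x3, x1^2, x1*x2, x1*x3, x2^2, x2*x3, x3^2
print(Phi)
```

The number of columns is C(d + degree, degree). This grows quickly with d, which is one reason to prefer the RKHS-style or learned bases for higher-dimensional X.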
RKHS-style bases
You can approximate an RBF kernel either with random Fourier features or a Nyström basis.
from genriesz import RBFRandomFourierBasis, RBFNystromBasis
rff = RBFRandomFourierBasis(n_features=500, sigma=1.0, standardize=True, random_state=0)
Phi_rff = rff(X)
nys = RBFNystromBasis(n_centers=500, sigma=1.0, standardize=True, random_state=0)
Phi_nys = nys(X)
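The random Fourier feature idea can be checked numerically: inner products of the features approximate the RBF kernel. This is a minimal sketch assuming the kernel k(x, y) = exp(-||x - y||^2 / (2 sigma^2)), whose spectral density is N(0, sigma^{-2} I). genriesz's sigma convention and its standardize option may differ. make_rff and rbf_kernel are hypothetical helpers.

```python
import numpy as np


def rbf_kernel(A, B, sigma):
    # Exact Gaussian kernel exp(-||a - b||^2 / (2 sigma^2)).
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-d2 / (2.0 * sigma**2))


def make_rff(d, n_features, sigma, seed=0):
    # Draw the random frequencies and phases once, then reuse them for any X.
    rng = np.random.default_rng(seed)
    W = rng.normal(scale=1.0 / sigma, size=(d, n_features))
    b = rng.uniform(0.0, 2.0 * np.pi, size=n_features)

    def features(X):
        # phi(x) = sqrt(2 / D) * cos(W^T x + b), so E[phi(x) . phi(y)] = k(x, y).
        return np.sqrt(2.0 / n_features) * np.cos(X @ W + b)

    return features


X = np.random.default_rng(1).normal(size=(20, 3))
phi = make_rff(d=3, n_features=20000, sigma=1.0)
Phi = phi(X)                                    # (20, 20000)
err = np.abs(Phi @ Phi.T - rbf_kernel(X, X, sigma=1.0)).max()
print(err)  # small; shrinks roughly like 1 / sqrt(n_features)
```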
Nearest-neighbor matching (kNN catchment-area basis)
Nearest-neighbor matching can be expressed using a catchment-area indicator basis
\(\phi_j(z) = \mathbf{1}\{c_j \in \mathrm{NN}_k(z)\}\),
where \(\{c_j\}\) is a set of centers and \(\mathrm{NN}_k(z)\) is the set of the k nearest centers of \(z\).
This library provides KNNCatchmentBasis:
from genriesz import KNNCatchmentBasis
centers, queries = X[:100], X  # illustrative (n_centers, d) and (n_queries, d) arrays
basis = KNNCatchmentBasis(n_neighbors=3).fit(centers)
Phi = basis(queries)  # dense (n_queries, n_centers)
See examples/ate_synthetic_nn_matching.py for an end-to-end matching-style ATE estimate.
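The indicator basis \(\phi_j(z)\) can be built directly with NumPy. This minimal sketch is not KNNCatchmentBasis itself. knn_catchment is a hypothetical helper, it uses Euclidean distance, and ties are broken arbitrarily by np.argpartition. Each row contains exactly k ones. Column j sums to the number of queries that count center j among their k nearest centers, which corresponds to the "number of times used as a match" in matching estimators.

```python
import numpy as np


def knn_catchment(queries, centers, k):
    # Phi[i, j] = 1 if center j is among the k nearest centers of query i.
    d2 = ((queries[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    nn = np.argpartition(d2, kth=k - 1, axis=1)[:, :k]  # k smallest per row (unordered)
    Phi = np.zeros_like(d2)
    np.put_along_axis(Phi, nn, 1.0, axis=1)
    return Phi


centers = np.array([[0.0], [1.0], [10.0]])
queries = np.array([[0.1], [9.0], [0.6]])
Phi = knn_catchment(queries, centers, k=2)
print(Phi)
# Column sums: how many queries count each center among their k nearest neighbors.
print(Phi.sum(axis=0))
```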
Random forest leaves (scikit-learn)
If you have scikit-learn installed, you can use a random forest as a feature map by encoding
leaf indices.
from sklearn.ensemble import RandomForestRegressor
from genriesz.sklearn_basis import RandomForestLeafBasis
rf = RandomForestRegressor(n_estimators=200, random_state=0)
leaf_basis = RandomForestLeafBasis(rf)
Phi_rf = leaf_basis(X)
Neural network features (PyTorch)
If you have PyTorch installed, you can use a neural network as a fixed feature map.
See src/genriesz/torch_basis.py for a minimal wrapper.
Included estimands
- ATE: grr_ate, or m=ATEFunctional(...)
- AME (average marginal effect / average derivative): grr_ame, or m=AverageDerivativeFunctional(...)
- Average policy effect: grr_policy_effect, or m=PolicyEffectFunctional(...)
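Two of the estimands above can be written as custom m(x, gamma) functions, following the row-wise convention in the grr_functional skeleton (x is one row, gamma is a callable). This is a minimal sketch assuming X = [D, Z] with the treatment in column 0. m_ate, m_ame and the outcome model gamma are hypothetical stand-ins. genriesz's ATEFunctional and AverageDerivativeFunctional may be implemented differently, for example with analytic derivatives. The sketch also shows the plug-in (DM) average of m over rows. The policy effect is not sketched because its exact definition is not given here.

```python
import numpy as np


def m_ate(x, gamma):
    # ATE functional for one row x = [d, z_1, ..., z_p]: gamma(1, z) - gamma(0, z).
    x1 = np.array(x, dtype=float)
    x0 = np.array(x, dtype=float)
    x1[0], x0[0] = 1.0, 0.0
    return gamma(x1) - gamma(x0)


def m_ame(x, gamma, j=0, h=1e-5):
    # Average-derivative functional: d gamma / d x_j at x, by central differences.
    xp = np.array(x, dtype=float)
    xm = np.array(x, dtype=float)
    xp[j] += h
    xm[j] -= h
    return (gamma(xp) - gamma(xm)) / (2.0 * h)


def gamma(w):
    # Hypothetical outcome model: gamma(d, z) = 2 d + z^2.
    return 2.0 * w[0] + w[1] ** 2


X = np.array([[0.0, 1.5], [1.0, -0.5], [1.0, 2.0]])
dm_ate = np.mean([m_ate(x, gamma) for x in X])        # plug-in ATE
dm_ame = np.mean([m_ame(x, gamma, j=1) for x in X])   # plug-in average derivative in z
print(dm_ate, dm_ame)
```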
Jupyter notebook
An end-to-end notebook with runnable examples is provided at:
notebooks/GRR_end_to_end_examples.ipynb
References
If you use genriesz in academic work, please cite:
- Masahiro Kato. Riesz Representer Fitting under Bregman Divergence: A Unified Framework for Debiased Machine Learning. https://arxiv.org/abs/2601.07752
- Note: https://arxiv.org/abs/2601.07752 consolidates earlier related drafts: https://arxiv.org/abs/2509.22122, https://arxiv.org/abs/2510.26783, and https://arxiv.org/abs/2510.23534.
- Masahiro Kato. Direct Bias-Correction Term Estimation for Propensity Scores and Average Treatment Effect Estimation. https://arxiv.org/abs/2509.22122
- Masahiro Kato. Nearest Neighbor Matching as Least Squares Density Ratio Estimation and Riesz Regression. https://arxiv.org/abs/2510.24433
License
GNU General Public License v3.0 (GPL-3.0).
File details
Details for the file genriesz-0.1.10-py3-none-any.whl.
File metadata
- Download URL: genriesz-0.1.10-py3-none-any.whl
- Upload date:
- Size: 58.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1a4dbb355e881a51500881b309a5d8b7d9c5a6320ff707b3cb9f351a2fb3ded7 |
| MD5 | d7e155e6af5176e72a5f89a2cea1989b |
| BLAKE2b-256 | 34ea2e518bb5278ac0384e155242069383182eecabc7a818aed8268555a9a7c3 |