Stochastic gradient descent implementation of permutation weighting
stochpw - Permutation Weighting for Causal Inference
Permutation weighting learns importance weights for causal inference by training a discriminator to distinguish between observed treatment-covariate pairs and artificially permuted pairs.
Installation
pip install stochpw
For development:
git clone https://github.com/ddimmery/stochpw.git
cd stochpw
poetry install
Quick Start
import jax.numpy as jnp
from stochpw import PermutationWeighter
# Your observational data
X = jnp.array(...) # Covariates, shape (n_samples, n_features)
A = jnp.array(...) # Treatments, shape (n_samples, 1)
# Fit permutation weighter (sklearn-style API)
weighter = PermutationWeighter(
    num_epochs=100,
    batch_size=256,
    random_state=42,
)
weighter.fit(X, A)
# Predict importance weights
weights = weighter.predict(X, A)
# Use weights for downstream task
# (tools for causal estimation not provided)
# ate = weighted_estimator(Y, A, weights)
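Since the package does not ship an estimator, the commented-out weighted_estimator call above needs a user-supplied function. As a minimal sketch, assuming a binary treatment coded 0/1 and plain Python sequences rather than JAX arrays, a normalized weighted difference in means could look like the following. The name weighted_mean_difference and the toy numbers are illustrative assumptions, not part of stochpw.

```python
def weighted_mean_difference(y, a, w):
    """Normalized weighted difference in mean outcomes (treated minus control).

    Hypothetical helper, not part of stochpw. Assumes a binary treatment coded 0/1.
    """
    num1 = sum(wi * yi for yi, ai, wi in zip(y, a, w) if ai == 1)
    den1 = sum(wi for ai, wi in zip(a, w) if ai == 1)
    num0 = sum(wi * yi for yi, ai, wi in zip(y, a, w) if ai == 0)
    den0 = sum(wi for ai, wi in zip(a, w) if ai == 0)
    if den1 == 0 or den0 == 0:
        raise ValueError("both treatment arms need positive total weight")
    # Dividing by the summed weights in each arm makes the estimate
    # invariant to rescaling all weights by a constant.
    return num1 / den1 - num0 / den0


# Toy outcomes, treatments, and weights (made up for illustration)
y = [3.0, 5.0, 1.0, 2.0]
a = [1, 1, 0, 0]
w = [1.0, 3.0, 1.0, 1.0]
ate = weighted_mean_difference(y, a, w)
```

With arrays returned by the weighter, the inputs would first need to be flattened to one dimension, since A is documented above with shape (n_samples, 1).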
How It Works
Permutation weighting estimates density ratios by:
- Training a discriminator to distinguish:
  - Permuted pairs: (X, A') with label C=1 (treatments shuffled)
  - Observed pairs: (X, A) with label C=0 (original data)
- Extracting weights from discriminator probabilities:
  w(a, x) = η(a, x) / (1 - η(a, x)), where η(a, x) = p(C=1 | a, x)
- Using the resulting weights as balancing weights in causal effect estimation
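The steps above can be sketched end to end in plain Python on a toy problem. This is not the package's implementation: it assumes a binary covariate and a binary treatment, uses the simplest possible discriminator (one logit per (a, x) cell), and trains it with full-batch gradient descent instead of stochastic minibatches. All names and numbers are illustrative.

```python
import math
import random

rng = random.Random(0)

# Toy observed (a, x) pairs in which treatment and covariate are strongly dependent.
observed = [(1, 1)] * 80 + [(1, 0)] * 20 + [(0, 1)] * 20 + [(0, 0)] * 80

# Step 1a: build the permuted sample by shuffling treatments across units.
a_perm = [a for a, _ in observed]
rng.shuffle(a_perm)  # shuffling A breaks its link to X
permuted = [(a, x) for a, (_, x) in zip(a_perm, observed)]

# Label observed pairs C=0 and permuted pairs C=1.
data = [(pair, 0) for pair in observed] + [(pair, 1) for pair in permuted]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Step 1b: fit the discriminator by gradient descent on the mean logistic loss.
# One free logit per (a, x) cell is only feasible because both variables are binary.
logits = {(a, x): 0.0 for a in (0, 1) for x in (0, 1)}
learning_rate = 4.0
for _ in range(300):
    grad = {cell: 0.0 for cell in logits}
    for cell, c in data:
        grad[cell] += (sigmoid(logits[cell]) - c) / len(data)
    for cell in logits:
        logits[cell] -= learning_rate * grad[cell]

# Step 2: w = eta / (1 - eta) is the odds of C=1, which equals exp(logit).
weights = {cell: math.exp(z) for cell, z in logits.items()}
```

In this toy setup the fitted weight for each cell converges to the ratio of permuted to observed counts, so over-represented cells such as (a=1, x=1) receive weights below 1 and under-represented cells such as (a=1, x=0) receive weights above 1.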
Composable Design
The package exposes low-level components for integration into larger models:
from stochpw import (
    BaseDiscriminator,
    LinearDiscriminator,
    MLPDiscriminator,
    create_training_batch,
    logistic_loss,
    extract_weights,
)
# Use in your custom architecture (e.g., DragonNet)
# my_discriminator, params, batch_indices, and rng_key come from your own code
batch = create_training_batch(X, A, batch_indices, rng_key)
logits = my_discriminator(params, batch.A, batch.X, batch.AX)
loss = logistic_loss(logits, batch.C)
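For readers wiring the loss into their own model, it may help to see what a logistic loss on raw logits typically computes. The sketch below is the standard mean binary cross-entropy in a numerically stable form, written in plain Python. It is a generic definition, and the package's own logistic_loss may differ in details such as reduction or input types. The name binary_logistic_loss is hypothetical.

```python
import math


def binary_logistic_loss(logits, labels):
    """Mean binary cross-entropy computed from raw logits (numerically stable)."""
    total = 0.0
    for z, c in zip(logits, labels):
        # Equivalent to -[c*log(sigmoid(z)) + (1-c)*log(1-sigmoid(z))],
        # rearranged so that exp() never receives a large positive argument.
        total += max(z, 0.0) - z * c + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


# A discriminator that outputs logit 0 everywhere cannot tell the classes apart.
loss_uninformative = binary_logistic_loss([0.0, 0.0], [0, 1])
# Very confident, correct predictions drive the loss toward zero without overflow.
loss_confident = binary_logistic_loss([1000.0, -1000.0], [1, 0])
```

An uninformative discriminator gives a loss of log(2), which is a useful baseline when monitoring training.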
Features
- JAX-based: Fast, GPU-compatible, auto-differentiable
- Sklearn-style API: Familiar .fit() and .predict() interface
- Composable: All components exposed for integration
- Flexible: Supports binary, continuous, and multi-dimensional treatments
- Diagnostic tools: ESS, SMD, and balance checks included
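To make the diagnostics concrete, here is a minimal sketch of two of them in plain Python: the Kish effective sample size (ESS) and a weighted standardized mean difference (SMD) for one covariate. It assumes a binary treatment coded 0/1 and uses the pooled unweighted standard deviation as the scale. The function names are hypothetical and the package's own diagnostics may use different definitions or signatures.

```python
import math
import statistics


def effective_sample_size(w):
    """Kish effective sample size: (sum of weights)^2 / (sum of squared weights)."""
    return sum(w) ** 2 / sum(wi * wi for wi in w)


def weighted_smd(x, a, w):
    """Weighted standardized mean difference of covariate x between arms.

    Assumes binary a in {0, 1}. The difference in weighted means is scaled
    by the pooled unweighted standard deviation of the two arms.
    """
    def weighted_mean(arm):
        num = sum(wi * xi for xi, ai, wi in zip(x, a, w) if ai == arm)
        den = sum(wi for ai, wi in zip(a, w) if ai == arm)
        return num / den

    var1 = statistics.variance([xi for xi, ai in zip(x, a) if ai == 1])
    var0 = statistics.variance([xi for xi, ai in zip(x, a) if ai == 0])
    return (weighted_mean(1) - weighted_mean(0)) / math.sqrt((var1 + var0) / 2.0)


# Toy covariate, treatment, and weights (made up for illustration)
x = [1.0, 2.0, 3.0, 4.0]
a = [1, 1, 0, 0]
uniform = [1.0, 1.0, 1.0, 1.0]

ess_uniform = effective_sample_size(uniform)                  # equals the sample size
ess_degenerate = effective_sample_size([1.0, 0.0, 0.0, 0.0])  # one unit carries all weight
smd = weighted_smd(x, a, uniform)
```

An ESS far below the sample size signals that a few units dominate the weights, and an SMD close to zero after weighting suggests the covariate is balanced across arms.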
References
Arbour, D., Dimmery, D., & Sondhi, A. (2021). Permutation Weighting. In Proceedings of the 38th International Conference on Machine Learning, PMLR 139:331-341.
License
Apache-2.0 License - see LICENSE file for details.
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
Citation
If you use this package, please cite the original paper:
@InProceedings{arbour21permutation,
title = {Permutation Weighting},
author = {Arbour, David and Dimmery, Drew and Sondhi, Arjun},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
pages = {331--341},
year = {2021},
editor = {Meila, Marina and Zhang, Tong},
volume = {139},
series = {Proceedings of Machine Learning Research},
month = {18--24 Jul},
publisher = {PMLR},
pdf = {http://proceedings.mlr.press/v139/arbour21a/arbour21a.pdf},
url = {https://proceedings.mlr.press/v139/arbour21a.html}
}