Stochastic gradient descent implementation of permutation weighting
stochpw - Permutation Weighting for Causal Inference
Permutation weighting learns importance weights for causal inference by training a discriminator to distinguish between observed treatment-covariate pairs and artificially permuted pairs.
Installation
```bash
pip install stochpw
```
For development:
```bash
git clone https://github.com/yourusername/stochpw.git
cd stochpw
poetry install
```
Quick Start
```python
import jax.numpy as jnp
from stochpw import PermutationWeighter

# Your observational data
X = jnp.array(...)  # Covariates, shape (n_samples, n_features)
A = jnp.array(...)  # Treatments, shape (n_samples, 1)

# Fit permutation weighter (sklearn-style API)
weighter = PermutationWeighter(
    num_epochs=100,
    batch_size=256,
    random_state=42,
)
weighter.fit(X, A)

# Predict importance weights
weights = weighter.predict(X, A)

# Use weights for causal inference (in an external package)
# ate = weighted_estimator(Y, A, weights)
```
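The weighted estimator is left to an external package above. For a binary treatment coded 0/1, one common choice is a Hájek-style weighted difference in means, where each arm's outcomes are averaged with weights normalized within that arm. The sketch below is a hypothetical helper (`weighted_difference_in_means` is not part of stochpw) and assumes `weights` holds one entry per observed unit.

```python
import jax.numpy as jnp


def weighted_difference_in_means(Y, A, weights):
    """Hajek-style weighted ATE estimate for a binary (0/1) treatment.

    Hypothetical helper, not part of stochpw.
    """
    Y, A, w = jnp.ravel(Y), jnp.ravel(A), jnp.ravel(weights)
    treated = A == 1
    # Normalize weights within each arm so each weighted mean is a proper average
    mu1 = jnp.sum(jnp.where(treated, w * Y, 0.0)) / jnp.sum(jnp.where(treated, w, 0.0))
    mu0 = jnp.sum(jnp.where(treated, 0.0, w * Y)) / jnp.sum(jnp.where(treated, 0.0, w))
    return mu1 - mu0
```

With the Quick Start variables and an outcome array `Y`, this would be called as `ate = weighted_difference_in_means(Y, A, weights)`. Normalizing within each arm keeps the estimate on the outcome's scale even when the weights do not average exactly to one.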
How It Works
Permutation weighting estimates density ratios in three steps:
- Train a discriminator to distinguish:
  - Permuted pairs: (X, A') with label C=1 (treatments shuffled across units)
  - Observed pairs: (X, A) with label C=0 (original data)
- Extract weights from the discriminator's probabilities:
  w(a, x) = η(a, x) / (1 - η(a, x)), where η(a, x) = p(C=1 | a, x)
- Use the weights for inverse probability weighting in causal effect estimation.
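Why the ratio works: with equally many observed and permuted pairs, each class has prior probability 1/2, so Bayes' rule gives η/(1 - η) = p(a)p(x) / p(a, x), the product of the marginals over the joint density. Shuffling A across units keeps its marginal distribution but breaks its link to X, which is what makes the permuted pairs a sample from p(a)p(x). The sketch below builds the stacked discriminator dataset and turns logits into weights. It illustrates the idea under these assumptions and is not stochpw's `create_training_batch` or `extract_weights`; the helper names are hypothetical. For a logistic discriminator, η/(1 - η) simplifies to exp(logit), which avoids dividing by a near-zero 1 - η.

```python
import jax
import jax.numpy as jnp


def make_permuted_pairs(X, A, key):
    """Stack observed pairs (C=0) on top of treatment-permuted pairs (C=1).

    Hypothetical helper illustrating the idea; not stochpw's create_training_batch.
    """
    n = X.shape[0]
    # Shuffling A across units keeps its marginal but breaks its dependence on X
    A_perm = A[jax.random.permutation(key, n)]
    X_all = jnp.concatenate([X, X], axis=0)
    A_all = jnp.concatenate([A, A_perm], axis=0)
    C = jnp.concatenate([jnp.zeros(n), jnp.ones(n)])
    return X_all, A_all, C


def weights_from_logits(logits):
    """w = eta / (1 - eta) with eta = sigmoid(logits), computed stably as exp(logits)."""
    return jnp.exp(logits)
```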
Composable Design
The package exposes low-level components for integration into larger models:
```python
from stochpw import (
    create_training_batch,
    logistic_loss,
    extract_weights,
    create_linear_discriminator,
)

# Use in your custom architecture (e.g., DragonNet).
# batch_indices, rng_key, my_discriminator, and params come from your own
# training loop and model; they are placeholders here.
batch = create_training_batch(X, A, batch_indices, rng_key)
logits = my_discriminator(params, batch.A, batch.X)
loss = logistic_loss(logits, batch.C)
```
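To show how such pieces fit into a stochastic gradient descent loop, here is a minimal, self-contained sketch written in plain JAX rather than with stochpw's own functions, whose exact signatures and defaults are not documented here. All names (`features`, `loss_fn`, `make_batch`, `fit_linear_pw`, `predict_weights`) are hypothetical, the treatment is assumed to be a single column, and the A·X interaction feature map is an illustrative choice. Each minibatch draws a fresh shuffle of treatments, so the discriminator sees new permuted pairs at every step.

```python
import jax
import jax.numpy as jnp


def features(X, A):
    # Hypothetical feature map, assuming A has shape (n, 1). A and X have the same
    # marginals in both classes, so the A * X interaction carries the signal.
    return jnp.concatenate([A, X, A * X], axis=1)


def loss_fn(params, X, A, C):
    # Mean logistic (binary cross-entropy) loss, computed from logits for stability
    w, b = params
    logits = features(X, A) @ w + b
    return jnp.mean(-C * jax.nn.log_sigmoid(logits) - (1.0 - C) * jax.nn.log_sigmoid(-logits))


def make_batch(X, A, idx, key):
    # Observed rows get C=0; the same rows with treatments shuffled get C=1
    Xb, Ab = X[idx], A[idx]
    Ap = jax.random.permutation(key, Ab, axis=0)
    m = Xb.shape[0]
    X2 = jnp.concatenate([Xb, Xb], axis=0)
    A2 = jnp.concatenate([Ab, Ap], axis=0)
    C = jnp.concatenate([jnp.zeros(m), jnp.ones(m)])
    return X2, A2, C


def fit_linear_pw(X, A, key, num_epochs=50, batch_size=128, lr=0.1):
    n, d = X.shape
    params = (jnp.zeros(1 + 2 * d), jnp.zeros(()))
    value_and_grad = jax.jit(jax.value_and_grad(loss_fn))
    for _ in range(num_epochs):
        key, order_key = jax.random.split(key)
        order = jax.random.permutation(order_key, n)
        for start in range(0, n, batch_size):
            key, batch_key = jax.random.split(key)
            X2, A2, C = make_batch(X, A, order[start:start + batch_size], batch_key)
            _, grads = value_and_grad(params, X2, A2, C)
            # Plain SGD update applied to every parameter
            params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params


def predict_weights(params, X, A):
    # eta / (1 - eta) with eta = sigmoid(logit) simplifies to exp(logit)
    w, b = params
    return jnp.exp(features(X, A) @ w + b)


# Usage on simulated confounded data
key = jax.random.PRNGKey(0)
kx, ka, kfit = jax.random.split(key, 3)
X = jax.random.normal(kx, (1000, 2))
# Units with larger X[:, 0] are treated more often
A = (X[:, :1] + jax.random.normal(ka, (1000, 1)) > 0).astype(jnp.float32)
params = fit_linear_pw(X, A, kfit)
weights = predict_weights(params, X, A)
```

A linear discriminator keeps the example short; in practice the discriminator could be any differentiable model, which is the point of exposing the loss and batch construction separately.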
Features
- JAX-based: Fast, GPU-compatible, auto-differentiable
- Sklearn-style API: Familiar .fit() and .predict() interface
- Composable: All components exposed for integration
- Flexible: Supports binary, continuous, and multi-dimensional treatments
- Diagnostic tools: effective sample size (ESS), standardized mean difference (SMD), and balance checks included
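The diagnostics named above are standard quantities. As a sketch with hypothetical helpers (not stochpw's API), Kish's effective sample size is (Σw)² / Σw², and a weighted standardized mean difference compares the weighted covariate means of the two arms of a binary treatment, scaled here by the square root of the average of the two weighted within-arm variances. SMD conventions differ (some use unweighted variances in the denominator), so treat this as one reasonable choice.

```python
import jax.numpy as jnp


def effective_sample_size(weights):
    # Kish's effective sample size: (sum w)^2 / sum(w^2); equals n for uniform weights
    w = jnp.ravel(weights)
    return jnp.sum(w) ** 2 / jnp.sum(w ** 2)


def weighted_smd(x, A, weights):
    """Weighted standardized mean difference of one covariate between arms (binary A)."""
    x, A, w = jnp.ravel(x), jnp.ravel(A), jnp.ravel(weights)
    w1 = jnp.where(A == 1, w, 0.0)
    w0 = jnp.where(A == 1, 0.0, w)
    m1 = jnp.sum(w1 * x) / jnp.sum(w1)
    m0 = jnp.sum(w0 * x) / jnp.sum(w0)
    # Weighted within-arm variances, pooled by simple averaging
    v1 = jnp.sum(w1 * (x - m1) ** 2) / jnp.sum(w1)
    v0 = jnp.sum(w0 * (x - m0) ** 2) / jnp.sum(w0)
    return (m1 - m0) / jnp.sqrt((v1 + v0) / 2.0)
```

An ESS far below the sample size signals that a few units dominate the weights; an |SMD| below 0.1 is a commonly used rule of thumb for adequate balance on a covariate.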
References
Arbour, D., Dimmery, D., & Sondhi, A. (2021). Permutation Weighting. International Conference on Machine Learning (ICML).
License
MIT License - see LICENSE file for details.
Contributing
See CONTRIBUTING.md for guidelines.
Citation
If you use this package, please cite:
```bibtex
@software{stochpw2024,
  title = {stochpw: Permutation Weighting for Causal Inference},
  author = {Your Name},
  year = {2024},
  url = {https://github.com/yourusername/stochpw}
}
```
File details
Details for the file stochpw-0.2.0.tar.gz.
File metadata
- Download URL: stochpw-0.2.0.tar.gz
- Upload date:
- Size: 16.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/2.1.3 CPython/3.13.5 Darwin/23.2.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f2ee9b9dc650fa15cdeed843787653bf3a68190ff8676a0fffb67747c2390925 |
| MD5 | 29ca7d1f3f6116186b34688bb32ac76f |
| BLAKE2b-256 | 950fd021f2e892f3e073171ea8c43f787ac16aeacef0296835416c7875f9db52 |
File details
Details for the file stochpw-0.2.0-py3-none-any.whl.
File metadata
- Download URL: stochpw-0.2.0-py3-none-any.whl
- Upload date:
- Size: 20.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/2.1.3 CPython/3.13.5 Darwin/23.2.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bf122d01981d1ea3ffc267632e2985a32508a0ea3c5779f98c314131a669de11 |
| MD5 | 9bb077ab46ec094af78ceea6ffbb4546 |
| BLAKE2b-256 | 32f6b86cb9b832e4f413622698ddbf7d09781bea5a8a4a5e6245def9aa8fd25c |