
Stochastic gradient descent implementation of permutation weighting


stochpw - Permutation Weighting for Causal Inference

Requires Python 3.10+ and JAX.

Permutation weighting learns importance weights for causal inference by training a discriminator (a binary classifier) to distinguish observed treatment-covariate pairs from pairs whose treatments have been randomly permuted across units.

Installation

pip install stochpw

For development:

git clone https://github.com/yourusername/stochpw.git
cd stochpw
poetry install

Quick Start

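import jax
import jax.numpy as jnp
from stochpw import PermutationWeighter

# Your observational data (synthetic, confounded example; replace with your own)
key_x, key_a = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (1000, 5))  # Covariates, shape (n_samples, n_features)
A = (X[:, :1] + jax.random.normal(key_a, (1000, 1)) > 0).astype(jnp.float32)  # Treatments, shape (n_samples, 1)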

# Fit permutation weighter (sklearn-style API)
weighter = PermutationWeighter(
    num_epochs=100,
    batch_size=256,
    random_state=42
)
weighter.fit(X, A)

# Predict importance weights
weights = weighter.predict(X, A)

# Use weights for causal inference (in external package)
# ate = weighted_estimator(Y, A, weights)

How It Works

Permutation weighting estimates density ratios by:

  1. Training a discriminator to distinguish:

    • Permuted pairs: (X, A') with label C=1 (treatments shuffled)
    • Observed pairs: (X, A) with label C=0 (original data)
  2. Extracting weights from discriminator probabilities:

    w(a, x) = η(a, x) / (1 - η(a, x))
    

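    where η(a, x) = p(C=1 | a, x). When the permuted and observed sets contain the same number of pairs (as when the data are permuted once), these odds estimate the density ratio p(a)p(x) / p(a, x).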

  3. Using weights for inverse probability weighting in causal effect estimation
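The two steps above can be sketched in plain JAX as a minimal, full-batch illustration. This is not stochpw's implementation: the feature map (intercept, treatment, covariates, and treatment-covariate interactions), the logistic-regression discriminator, the plain gradient-descent loop, and the names features, permutation_weights, num_steps and lr are all assumptions made for this example. For a logistic model η/(1 - η) = exp(logit), so the weights are simply the exponentiated logits of the observed pairs.

```python
import jax
import jax.numpy as jnp


def features(a, x):
    # Assumed feature map: intercept, treatment, covariates, and treatment-covariate
    # interactions (the interactions let a linear model detect a-x dependence)
    return jnp.concatenate([jnp.ones_like(a), a, x, a * x], axis=1)


def permutation_weights(X, A, key, num_steps=500, lr=0.1):
    n = X.shape[0]
    # Step 1: permuted pairs (C=1) break the A-X dependence; observed pairs keep it (C=0)
    A_perm = A[jax.random.permutation(key, n)]
    Z = jnp.concatenate([features(A_perm, X), features(A, X)], axis=0)
    C = jnp.concatenate([jnp.ones(n), jnp.zeros(n)])

    def loss(theta):
        logits = Z @ theta
        # Binary cross-entropy with logits: softplus(z) - C * z
        return jnp.mean(jax.nn.softplus(logits) - C * logits)

    # Full-batch gradient descent on the convex logistic loss, starting at theta = 0
    grad = jax.grad(loss)
    theta = jax.lax.fori_loop(
        0, num_steps, lambda _, t: t - lr * grad(t), jnp.zeros(Z.shape[1])
    )

    # Step 2: w = eta / (1 - eta) = exp(logit), evaluated at the observed pairs
    weights = jnp.exp(features(A, X) @ theta)
    return weights, loss(theta)


# Usage on synthetic confounded data: treatment probability rises with X[:, 0]
key_x, key_a, key_p = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(key_x, (2000, 2))
A = jax.random.bernoulli(key_a, jax.nn.sigmoid(2.0 * X[:, :1])).astype(jnp.float32)
weights, final_loss = permutation_weights(X, A, key_p)
```

A linear discriminator with interaction terms only approximates the true density ratio; a more flexible classifier (for example a small neural network) can replace it without changing steps 1 and 2.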

Composable Design

The package exposes low-level components for integration into larger models:

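import jax
from stochpw import (
    create_training_batch,
    logistic_loss,
    extract_weights,
    create_linear_discriminator
)

# Sample minibatch row indices and a PRNG key for create_training_batch
idx_key, rng_key = jax.random.split(jax.random.PRNGKey(0))
batch_indices = jax.random.choice(idx_key, X.shape[0], shape=(256,), replace=False)

# Use in your custom architecture (e.g., DragonNet);
# my_discriminator and params stand for your own model and its parameters
batch = create_training_batch(X, A, batch_indices, rng_key)
logits = my_discriminator(params, batch.A, batch.X)
loss = logistic_loss(logits, batch.C)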
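The project summary describes stochpw as a stochastic gradient descent implementation, and the pipeline above (a minibatch from create_training_batch, fed through a discriminator into logistic_loss) suggests a minibatch training loop. The sketch below shows one way such a loop could look in plain JAX; it is not stochpw's code. Batch, make_batch, linear_discriminator, bce_with_logits and sgd_step are hypothetical stand-ins, and permuting treatments within each minibatch, a linear discriminator on (a, x, a * x) for a single treatment column, and plain SGD with a fixed learning rate are all assumptions.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Batch(NamedTuple):
    # Hypothetical container mirroring the batch.A / batch.X / batch.C fields
    A: jax.Array
    X: jax.Array
    C: jax.Array


def make_batch(X, A, idx, key):
    # Same rows twice: treatments shuffled within the batch (C=1), then observed (C=0)
    x, a = X[idx], A[idx]
    m = idx.shape[0]
    a_perm = a[jax.random.permutation(key, m)]
    return Batch(
        A=jnp.concatenate([a_perm, a]),
        X=jnp.concatenate([x, x]),
        C=jnp.concatenate([jnp.ones(m), jnp.zeros(m)]),
    )


def linear_discriminator(params, a, x):
    # Logit of "permuted" from a, x, and a * x (assumes a single treatment column)
    z = jnp.concatenate([a, x, a * x], axis=1)
    return z @ params["w"] + params["b"]


def bce_with_logits(logits, c):
    return jnp.mean(jax.nn.softplus(logits) - c * logits)


@jax.jit
def sgd_step(params, X, A, idx, key):
    def loss_fn(p):
        batch = make_batch(X, A, idx, key)
        return bce_with_logits(linear_discriminator(p, batch.A, batch.X), batch.C)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Plain SGD update with an assumed fixed learning rate of 0.1
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


# Training loop on synthetic confounded data (n=2000, d=3, minibatches of 256)
key = jax.random.PRNGKey(0)
key, k_x, k_a = jax.random.split(key, 3)
n, d = 2000, 3
X = jax.random.normal(k_x, (n, d))
A = (X[:, :1] + jax.random.normal(k_a, (n, 1)) > 0).astype(jnp.float32)
params = {"w": jnp.zeros(2 * d + 1), "b": jnp.zeros(())}

losses = []
for _ in range(500):
    key, k_idx, k_perm = jax.random.split(key, 3)
    idx = jax.random.choice(k_idx, n, shape=(256,), replace=False)
    params, loss = sgd_step(params, X, A, idx, k_perm)
    losses.append(float(loss))

# eta / (1 - eta) = exp(logit), evaluated on all observed pairs
weights = jnp.exp(linear_discriminator(params, A, X))
```

Because each step only touches one minibatch, memory use is independent of n, and the discriminator could be swapped for any differentiable model, such as a head on a larger network.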

Features

  • JAX-based: Fast, GPU-compatible, auto-differentiable
  • Sklearn-style API: Familiar .fit() and .predict() interface
  • Composable: All components exposed for integration
  • Flexible: Supports binary, continuous, and multi-dimensional treatments
  • Diagnostic tools: ESS, SMD, and balance checks included
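The ESS and SMD diagnostics listed above have standard definitions, sketched below in JAX. The function names effective_sample_size and weighted_smd are hypothetical and are not stochpw's diagnostic API. ESS here is Kish's (Σw)² / Σw², and the SMD divides the weighted treated-minus-control mean gap by the pooled standard deviation sqrt((s_t² + s_c²) / 2), which is one common convention; the package may use another.

```python
import jax.numpy as jnp


def effective_sample_size(w):
    # Kish's effective sample size: (sum w)^2 / sum w^2
    return jnp.sum(w) ** 2 / jnp.sum(w ** 2)


def weighted_smd(x, a, w):
    # Per-covariate standardized mean difference between treated (a=1) and
    # control (a=0) after weighting; x: (n, d), a: (n,) binary, w: (n,)
    t, c = w * a, w * (1 - a)
    mean_t = (t @ x) / jnp.sum(t)
    mean_c = (c @ x) / jnp.sum(c)
    var_t = (t @ (x - mean_t) ** 2) / jnp.sum(t)
    var_c = (c @ (x - mean_c) ** 2) / jnp.sum(c)
    return (mean_t - mean_c) / jnp.sqrt((var_t + var_c) / 2)


# Tiny example: one covariate, two control and two treated units, uniform weights
x = jnp.array([[0.0], [2.0], [1.0], [3.0]])
a = jnp.array([0.0, 0.0, 1.0, 1.0])
w = jnp.ones(4)
ess = effective_sample_size(w)  # -> 4.0
smd = weighted_smd(x, a, w)     # -> [1.0] (mean gap 1.0, pooled sd 1.0)
```

An ESS far below n signals that a few units dominate the weights, and an absolute SMD below about 0.1 is a commonly used rule of thumb for adequate covariate balance.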

References

Arbour, D., Dimmery, D., & Sondhi, A. (2021). Permutation Weighting. International Conference on Machine Learning (ICML).

License

MIT License - see LICENSE file for details.

Contributing

See CONTRIBUTING.md for guidelines.

Citation

If you use this package, please cite:

@software{stochpw2024,
  title = {stochpw: Permutation Weighting for Causal Inference},
  author = {Your Name},
  year = {2024},
  url = {https://github.com/yourusername/stochpw}
}



Download files


Source Distribution

stochpw-0.1.0.tar.gz (14.1 kB)

Built Distribution


stochpw-0.1.0-py3-none-any.whl (16.6 kB)

File details

Details for the file stochpw-0.1.0.tar.gz.

File metadata

  • Download URL: stochpw-0.1.0.tar.gz
  • Size: 14.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.13.5 Darwin/23.2.0

File hashes

Hashes for stochpw-0.1.0.tar.gz
  • SHA256: 0373d910bcedc43288320e5e9abd46405f261bf1d4a7b05669b1546bcf147fac
  • MD5: e3f5ec190ffdae8d1023d566342d7ce8
  • BLAKE2b-256: 9dee970522f7593cef3e0651ba671e96823a3b7e352bb03980ecff6d1f84f642


File details

Details for the file stochpw-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: stochpw-0.1.0-py3-none-any.whl
  • Size: 16.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.13.5 Darwin/23.2.0

File hashes

Hashes for stochpw-0.1.0-py3-none-any.whl
  • SHA256: bc51a69df644046372bdfdb96b39878cb45ea164977002541f2414f16d28b026
  • MD5: 22e45da51c47377d8e70685e03a6bcdf
  • BLAKE2b-256: 579c820fa3c0f6ed723c3f0896487e769bfa91c4d32d92b73617e10e3ec91971

