Stochastic gradient descent implementation of permutation weighting

stochpw - Permutation Weighting for Causal Inference

Requires Python 3.10+ and JAX.

Permutation weighting learns importance weights for causal inference by training a discriminator (a binary classifier) to distinguish observed treatment-covariate pairs from pairs in which the treatments have been randomly permuted across units.
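A minimal sketch of the two kinds of pairs, written in plain JAX rather than with the stochpw API (`make_pairs` is a hypothetical helper, and the toy data are invented for illustration). Shuffling the treatment column across units keeps both marginal distributions intact but removes any dependence between treatment and covariates, which is the difference the discriminator learns to detect.

```python
import jax
import jax.numpy as jnp


def make_pairs(key, X, A):
    """Stack observed pairs (label C=0) on top of permuted pairs (label C=1).

    Hypothetical helper for illustration; not part of the stochpw API.
    """
    n = A.shape[0]
    perm = jax.random.permutation(key, n)
    A_perm = A[perm]  # shuffle treatments across units, keep covariates fixed
    X_all = jnp.concatenate([X, X], axis=0)
    A_all = jnp.concatenate([A, A_perm], axis=0)
    C = jnp.concatenate([jnp.zeros(n), jnp.ones(n)])
    return X_all, A_all, C


# Confounded toy data: the treatment depends strongly on the first covariate.
key_x, key_a, key_p = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(key_x, (2000, 2))
A = X[:, :1] + 0.1 * jax.random.normal(key_a, (2000, 1))

X_all, A_all, C = make_pairs(key_p, X, A)
obs_corr = jnp.corrcoef(X[:, 0], A[:, 0])[0, 1]
perm_corr = jnp.corrcoef(X_all[2000:, 0], A_all[2000:, 0])[0, 1]
print(f"observed corr: {float(obs_corr):.2f}, permuted corr: {float(perm_corr):.2f}")
```

In the observed half the treatment is almost perfectly correlated with the first covariate; in the permuted half that correlation is close to zero.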

Installation

pip install stochpw

For development:

git clone https://github.com/yourusername/stochpw.git
cd stochpw
poetry install

Quick Start

import jax.numpy as jnp
from stochpw import PermutationWeighter

# Your observational data
X = jnp.array(...)  # Covariates, shape (n_samples, n_features)
A = jnp.array(...)  # Treatments, shape (n_samples, 1)

# Fit permutation weighter (sklearn-style API)
weighter = PermutationWeighter(
    num_epochs=100,
    batch_size=256,
    random_state=42
)
weighter.fit(X, A)

# Predict importance weights
weights = weighter.predict(X, A)

# Use weights for causal inference (in external package)
# ate = weighted_estimator(Y, A, weights)
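The final step above is left to an external package. As a hedged illustration only, a normalized (Hájek-style) weighted difference in means for a binary treatment could look like the following. `weighted_ate` is a hypothetical name, not a stochpw function, and it assumes `A` holds 0/1 treatments.

```python
import jax.numpy as jnp


def weighted_ate(Y, A, weights):
    """Normalized weighted difference in means for a binary (0/1) treatment.

    Hypothetical helper for illustration; not part of stochpw.
    Y: outcomes, shape (n,); A: treatments, shape (n,) or (n, 1);
    weights: importance weights, shape (n,) or (n, 1).
    """
    a = jnp.ravel(A).astype(bool)
    w = jnp.ravel(weights)
    y = jnp.ravel(Y)
    mean_treated = jnp.sum(w * y * a) / jnp.sum(w * a)
    mean_control = jnp.sum(w * y * ~a) / jnp.sum(w * ~a)
    return mean_treated - mean_control


# Tiny worked example: treated mean (3 + 4) / 2 = 3.5,
# weighted control mean (1 * 1 + 3 * 2) / (1 + 3) = 1.75, so the estimate is 1.75.
Y = jnp.array([3.0, 1.0, 4.0, 2.0])
A = jnp.array([[1], [0], [1], [0]])
w = jnp.array([1.0, 1.0, 1.0, 3.0])
ate = weighted_ate(Y, A, w)
```

Normalizing by the sum of weights within each arm keeps the estimate on the outcome scale even when the weights do not average exactly to one.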

How It Works

Permutation weighting estimates density ratios by:

  1. Training a discriminator to distinguish:

    • Permuted pairs: (X, A') with label C=1 (treatments shuffled)
    • Observed pairs: (X, A) with label C=0 (original data)
  2. Extracting weights from discriminator probabilities:

    w(a, x) = η(a, x) / (1 - η(a, x))
    

    where η(a, x) = p(C=1 | a, x)

  3. Using the weights for inverse probability weighting in causal effect estimation
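The odds in step 2 are a density ratio. Assuming equally many permuted and observed pairs (a class prior of 1/2), Bayes' rule gives the first line below. The second line holds if the discriminator produces η as the sigmoid of a logit f(a, x), as a logistic model does, so the weights can be read directly off the logits.

```latex
\eta(a, x) = \frac{p(a)\,p(x)}{p(a)\,p(x) + p(a, x)}
\;\Longrightarrow\;
w(a, x) = \frac{\eta(a, x)}{1 - \eta(a, x)} = \frac{p(a)\,p(x)}{p(a, x)} = \frac{p(a)}{p(a \mid x)}

\eta(a, x) = \sigma\bigl(f(a, x)\bigr) = \frac{1}{1 + e^{-f(a, x)}}
\;\Longrightarrow\;
w(a, x) = e^{f(a, x)}
```

For a binary treatment, p(a | x) is the propensity score, so p(a) / p(a | x) is the familiar stabilized inverse propensity weight used in step 3.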

Composable Design

The package exposes low-level components for integration into larger models:

from stochpw import (
    create_training_batch,
    logistic_loss,
    extract_weights,
    create_linear_discriminator
)

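# Use in your custom architecture (e.g., DragonNet).
# X, A: covariates and treatments as in the Quick Start
# batch_indices: indices of the samples in the current mini-batch
# rng_key: a JAX PRNG key, e.g. jax.random.PRNGKey(0)
batch = create_training_batch(X, A, batch_indices, rng_key)
# my_discriminator and params are your own model function and its parameters
logits = my_discriminator(params, batch.A, batch.X)
loss = logistic_loss(logits, batch.C)  # batch.C: observed (0) / permuted (1) labels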
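Components like these are meant to be combined in a stochastic training loop. A self-contained sketch of such a loop, written in plain JAX rather than with stochpw's own functions (whose exact behavior is not documented here), might look as follows. Everything in it is an illustrative assumption: the linear discriminator on (a, x, a·x) features, the names `logits_fn`, `loss_fn`, `sgd_step` and `fit`, the learning rate of 0.5, and drawing a fresh permutation within each mini-batch.

```python
import jax
import jax.numpy as jnp


def logits_fn(params, A, X):
    """Hypothetical linear discriminator on [a, x, a * x] features."""
    feats = jnp.concatenate([A, X, A * X], axis=1)  # assumes A has one column
    return feats @ params["w"] + params["b"]


def loss_fn(params, A, X, C):
    """Mean logistic loss: softplus(z) - C * z is the negative log-likelihood of C."""
    z = logits_fn(params, A, X)
    return jnp.mean(jax.nn.softplus(z) - C * z)


@jax.jit
def sgd_step(params, key, X_batch, A_batch, lr=0.5):
    """One SGD step on a mini-batch of observed pairs plus a fresh permutation of it."""
    n = A_batch.shape[0]
    A_perm = A_batch[jax.random.permutation(key, n)]
    X_pair = jnp.concatenate([X_batch, X_batch], axis=0)
    A_pair = jnp.concatenate([A_batch, A_perm], axis=0)
    C = jnp.concatenate([jnp.zeros(n), jnp.ones(n)])  # observed: 0, permuted: 1
    grads = jax.grad(loss_fn)(params, A_pair, X_pair, C)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def fit(X, A, num_epochs=30, batch_size=256, seed=0):
    """Plain mini-batch SGD over shuffled data; returns the discriminator parameters."""
    key = jax.random.PRNGKey(seed)
    params = {"w": jnp.zeros(1 + 2 * X.shape[1]), "b": jnp.zeros(())}
    n = X.shape[0]
    for _ in range(num_epochs):
        key, shuffle_key = jax.random.split(key)
        order = jax.random.permutation(shuffle_key, n)
        # Drop the last partial batch so every jitted step sees the same shapes.
        for start in range(0, n - batch_size + 1, batch_size):
            idx = order[start:start + batch_size]
            key, perm_key = jax.random.split(key)
            params = sgd_step(params, perm_key, X[idx], A[idx])
    return params


# Toy confounded data: P(A = 1 | x) = sigmoid(x_0).
key_x, key_a = jax.random.split(jax.random.PRNGKey(1))
X = jax.random.normal(key_x, (2000, 2))
A = jax.random.bernoulli(key_a, jax.nn.sigmoid(X[:, :1])).astype(jnp.float32)

params = fit(X, A)
weights = jnp.exp(logits_fn(params, A, X))  # eta / (1 - eta) = exp(logit)
```

Because η/(1 − η) equals the exponential of the logit for a logistic discriminator, the sketch computes the weights as `jnp.exp(logits)` instead of from probabilities, which avoids dividing by a number close to zero when η approaches 1.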

Features

  • JAX-based: GPU-compatible and auto-differentiable
  • scikit-learn-style API: familiar .fit() and .predict() interface
  • Composable: all components exposed for integration into larger models
  • Flexible: supports binary, continuous, and multi-dimensional treatments
  • Diagnostic tools: effective sample size (ESS), standardized mean difference (SMD), and balance checks included
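The package's own diagnostic functions are not documented on this page, so the following is an independent sketch of the two standard formulas behind those acronyms: Kish's effective sample size and a weighted SMD per covariate for a binary treatment. The function names and the pooled-variance convention (averaging the two weighted group variances) are assumptions.

```python
import jax.numpy as jnp


def effective_sample_size(weights):
    """Kish's effective sample size: (sum w)^2 / sum(w^2)."""
    w = jnp.ravel(weights)
    return jnp.sum(w) ** 2 / jnp.sum(w ** 2)


def weighted_smd(X, A, weights):
    """Weighted standardized mean difference of each covariate (binary A).

    Uses weighted group means and the average of the two weighted group variances.
    """
    a = jnp.ravel(A).astype(bool)
    w = jnp.ravel(weights)[:, None]

    def group_stats(mask):
        wm = w * mask[:, None]  # zero out units outside the group
        mean = jnp.sum(wm * X, axis=0) / jnp.sum(wm)
        var = jnp.sum(wm * (X - mean) ** 2, axis=0) / jnp.sum(wm)
        return mean, var

    m1, v1 = group_stats(a)
    m0, v0 = group_stats(~a)
    return (m1 - m0) / jnp.sqrt((v1 + v0) / 2)


# Uniform weights keep the full sample size (4); one heavy weight lowers it to 36 / 12 = 3.
ess_uniform = effective_sample_size(jnp.ones(4))
ess_skewed = effective_sample_size(jnp.array([1.0, 1.0, 1.0, 3.0]))

# Treated x = {0, 2}, control x = {1, 3}: means 1 and 2, both variances 1, so SMD = -1.
X = jnp.array([[0.0], [2.0], [1.0], [3.0]])
A = jnp.array([1, 1, 0, 0])
smd = weighted_smd(X, A, jnp.ones(4))
```

A common rule of thumb treats an absolute SMD below about 0.1 as adequate balance, while a low ESS relative to the sample size signals that a few units dominate the weighted estimate.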

References

Arbour, D., Dimmery, D., & Sondhi, A. (2021). Permutation Weighting. International Conference on Machine Learning (ICML).

License

MIT License - see LICENSE file for details.

Contributing

See CONTRIBUTING.md for guidelines.

Citation

If you use this package, please cite:

@software{stochpw2024,
  title = {stochpw: Permutation Weighting for Causal Inference},
  author = {Your Name},
  year = {2024},
  url = {https://github.com/yourusername/stochpw}
}

Download files

Download the file for your platform.

Source Distribution

stochpw-0.2.0.tar.gz (16.3 kB)

Uploaded Source

Built Distribution

stochpw-0.2.0-py3-none-any.whl (20.4 kB)

Uploaded Python 3

File details

Details for the file stochpw-0.2.0.tar.gz.

File metadata

  • Download URL: stochpw-0.2.0.tar.gz
  • Size: 16.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.13.5 Darwin/23.2.0

File hashes

Hashes for stochpw-0.2.0.tar.gz
Algorithm Hash digest
SHA256 f2ee9b9dc650fa15cdeed843787653bf3a68190ff8676a0fffb67747c2390925
MD5 29ca7d1f3f6116186b34688bb32ac76f
BLAKE2b-256 950fd021f2e892f3e073171ea8c43f787ac16aeacef0296835416c7875f9db52

File details

Details for the file stochpw-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: stochpw-0.2.0-py3-none-any.whl
  • Size: 20.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.13.5 Darwin/23.2.0

File hashes

Hashes for stochpw-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 bf122d01981d1ea3ffc267632e2985a32508a0ea3c5779f98c314131a669de11
MD5 9bb077ab46ec094af78ceea6ffbb4546
BLAKE2b-256 32f6b86cb9b832e4f413622698ddbf7d09781bea5a8a4a5e6245def9aa8fd25c
