
Stochastic gradient descent implementation of permutation weighting


stochpw - Permutation Weighting for Causal Inference

Requires Python 3.12+.

Permutation weighting learns importance weights for causal inference by training a discriminator to distinguish between observed treatment-covariate pairs and artificially permuted pairs.

Installation

pip install stochpw

For development:

git clone https://github.com/ddimmery/stochpw.git
cd stochpw
poetry install

Quick Start

import jax.numpy as jnp
from stochpw import PermutationWeighter

# Your observational data
X = jnp.array(...)  # Covariates, shape (n_samples, n_features)
A = jnp.array(...)  # Treatments, shape (n_samples, 1)

# Fit permutation weighter (sklearn-style API)
weighter = PermutationWeighter(
    num_epochs=100,
    batch_size=256,
    random_state=42
)
weighter.fit(X, A)

# Predict importance weights
weights = weighter.predict(X, A)

# Use the weights in a downstream causal estimator
# (stochpw does not provide causal effect estimators;
#  weighted_estimator is a placeholder for your own)
# ate = weighted_estimator(Y, A, weights)
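The Quick Start leaves the downstream estimator open. As one hedged illustration for a binary treatment, the sketch below computes a weighted (Hajek-style) difference in means, which is one common way to use balancing weights. `weighted_difference_in_means` is a hypothetical helper written for this example; it is not part of stochpw.

```python
import jax.numpy as jnp


def weighted_difference_in_means(Y, A, weights):
    """Hypothetical helper: Hajek-style weighted difference in means for a binary A."""
    Y = jnp.reshape(Y, -1)
    A = jnp.reshape(A, -1)
    w = jnp.reshape(weights, -1)
    # Normalize the weights separately within the treated and control groups.
    treated_mean = jnp.sum(w * A * Y) / jnp.sum(w * A)
    control_mean = jnp.sum(w * (1 - A) * Y) / jnp.sum(w * (1 - A))
    return treated_mean - control_mean


Y = jnp.array([1.0, 2.0, 3.0, 4.0])
A = jnp.array([1.0, 1.0, 0.0, 0.0])
ate_unweighted = weighted_difference_in_means(Y, A, jnp.ones(4))
ate_weighted = weighted_difference_in_means(Y, A, jnp.array([3.0, 1.0, 1.0, 1.0]))
```

With uniform weights this reduces to the ordinary difference in means; with the permutation weights from `weighter.predict`, it targets the average treatment effect under the usual no-unmeasured-confounding assumption.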

How It Works

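Permutation weighting estimates the density ratio p(a)p(x) / p(a, x), between the product of the treatment and covariate marginals and their observed joint distribution, by: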

  1. Training a discriminator to distinguish:

    • Permuted pairs: (X, A') with label C=1 (treatments shuffled)
    • Observed pairs: (X, A) with label C=0 (original data)
  2. Extracting weights from discriminator probabilities:

    w(a, x) = η(a, x) / (1 - η(a, x))
    

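    where η(a, x) = p(C=1 | a, x). When there are equal numbers of permuted and observed pairs, these odds estimate the density ratio.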

  3. Using the resulting weights as balancing weights in causal effect estimation
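The three steps above can be sketched end to end in plain JAX. This is a minimal illustration under our own assumptions, not stochpw's implementation: it uses a linear-logistic discriminator on treatment, covariate, and treatment-covariate interaction features, full-batch gradient descent with a freshly drawn permutation at every step, and synthetic confounded data. The helper names (`features`, `make_batch`, `loss_fn`, `fit_weights`) are hypothetical.

```python
import jax
import jax.numpy as jnp


def features(X, A):
    """Concatenate treatment, covariates, and treatment-covariate interactions."""
    AX = (A[:, :, None] * X[:, None, :]).reshape(X.shape[0], -1)
    return jnp.concatenate([A, X, AX], axis=1)


def make_batch(X, A, key):
    """Stack observed pairs (label C=0) with treatment-permuted pairs (label C=1)."""
    n = X.shape[0]
    A_perm = A[jax.random.permutation(key, n)]
    X_all = jnp.concatenate([X, X], axis=0)
    A_all = jnp.concatenate([A, A_perm], axis=0)
    C = jnp.concatenate([jnp.zeros(n), jnp.ones(n)])
    return X_all, A_all, C


def loss_fn(params, X, A, C):
    """Binary cross-entropy of a linear-logistic discriminator, computed on logits."""
    w, b = params
    logits = features(X, A) @ w + b
    return jnp.mean(jnp.logaddexp(0.0, logits) - C * logits)


def fit_weights(X, A, key, num_steps=300, lr=0.5):
    """Train the discriminator, then return w = eta / (1 - eta) on the observed pairs."""
    d = features(X[:1], A[:1]).shape[1]
    params = (jnp.zeros(d), jnp.array(0.0))
    grad_fn = jax.jit(jax.grad(loss_fn))
    for _ in range(num_steps):
        key, sub = jax.random.split(key)
        Xb, Ab, Cb = make_batch(X, A, sub)  # a fresh permutation at every step
        grads = grad_fn(params, Xb, Ab, Cb)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    w, b = params
    logits = features(X, A) @ w + b
    # eta / (1 - eta) equals exp(logit); this avoids dividing by a tiny 1 - eta.
    return jnp.exp(logits)


key = jax.random.PRNGKey(0)
kx, ka, kfit = jax.random.split(key, 3)
X = jax.random.normal(kx, (500, 2))
# Synthetic confounded binary treatment: P(A=1 | X) = sigmoid(X[:, 0]).
A = jax.random.bernoulli(ka, jax.nn.sigmoid(X[:, :1])).astype(jnp.float32)
weights = fit_weights(X, A, kfit)
```

Redrawing the permutation at every step exposes the discriminator to many permuted versions of the data rather than a single fixed one. On this synthetic data, the learned weights should shrink the weighted association between A and X[:, 0], which is the balancing behavior step 3 relies on.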

Composable Design

The package exposes low-level components for integration into larger models:

from stochpw import (
    BaseDiscriminator,
    LinearDiscriminator,
    MLPDiscriminator,
    create_training_batch,
    logistic_loss,
    extract_weights,
)

# Use in your custom architecture (e.g., DragonNet).
# batch_indices, rng_key, my_discriminator, and params come from your own model code.
batch = create_training_batch(X, A, batch_indices, rng_key)
logits = my_discriminator(params, batch.A, batch.X, batch.AX)
loss = logistic_loss(logits, batch.C)
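`my_discriminator` in the snippet above is left to the user. As a hedged sketch, here is one callable with that calling pattern: a small two-layer MLP that maps treatment, covariates, and interactions to one logit per row. We assume `batch.A`, `batch.X`, and `batch.AX` are 2-D arrays of shape (batch, d_a), (batch, d_x), and (batch, d_a * d_x); we have not confirmed stochpw's exact shapes. `init_mlp_params` and `my_discriminator` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def init_mlp_params(key, in_dim, hidden=16):
    """Hypothetical initializer: small random weights for a two-layer MLP."""
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (in_dim, hidden)) * 0.1,
        "b1": jnp.zeros(hidden),
        "W2": jax.random.normal(k2, (hidden,)) * 0.1,
        "b2": jnp.array(0.0),
    }


def my_discriminator(params, a, x, ax):
    """Return one logit per row from treatment, covariates, and interactions."""
    h = jnp.concatenate([a, x, ax], axis=1)
    h = jax.nn.relu(h @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]


# Toy inputs with the assumed shapes: 8 rows, 1 treatment column, 3 covariates.
k_a, k_x, k_p = jax.random.split(jax.random.PRNGKey(0), 3)
a = jax.random.normal(k_a, (8, 1))
x = jax.random.normal(k_x, (8, 3))
ax = (a[:, :, None] * x[:, None, :]).reshape(8, -1)
params = init_mlp_params(k_p, in_dim=1 + 3 + 3)
logits = my_discriminator(params, a, x, ax)
```

Because the function is a pure JAX transformation of `params`, it can be differentiated with `jax.grad` inside a larger model's loss.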

Features

  • JAX-based: Fast, GPU-compatible, auto-differentiable
  • Sklearn-style API: Familiar .fit() and .predict() interface
  • Composable: All components exposed for integration
  • Flexible: Supports binary, continuous, and multi-dimensional treatments
  • Diagnostic tools: ESS, SMD, and balance checks included
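The diagnostics named above have standard definitions. The sketch below shows the Kish effective sample size, (Σw)² / Σw², and a weighted standardized mean difference for a binary treatment that uses the pooled weighted variance in the denominator (one common convention). These are illustrative helpers with hypothetical names, not stochpw's own functions, whose exact conventions may differ.

```python
import jax.numpy as jnp


def effective_sample_size(weights):
    """Kish effective sample size: (sum of w)^2 / sum of w^2."""
    return jnp.sum(weights) ** 2 / jnp.sum(weights ** 2)


def weighted_smd(x, a, weights):
    """Weighted standardized mean difference of covariate x, treated (a=1) vs control (a=0)."""
    x = jnp.reshape(x, -1)
    a = jnp.reshape(a, -1)
    w = jnp.reshape(weights, -1)

    def weighted_mean_var(mask):
        ww = w * mask
        mean = jnp.sum(ww * x) / jnp.sum(ww)
        var = jnp.sum(ww * (x - mean) ** 2) / jnp.sum(ww)
        return mean, var

    m1, v1 = weighted_mean_var(a)
    m0, v0 = weighted_mean_var(1.0 - a)
    return (m1 - m0) / jnp.sqrt((v1 + v0) / 2.0)


ess_uniform = effective_sample_size(jnp.ones(10))
ess_degenerate = effective_sample_size(jnp.array([1.0, 0.0, 0.0, 0.0]))
smd = weighted_smd(jnp.array([0.0, 2.0, 1.0, 3.0]), jnp.array([1.0, 1.0, 0.0, 0.0]), jnp.ones(4))
```

Uniform weights keep the full sample size, while weight concentrated on a single unit collapses the ESS to 1; a weighted SMD near zero after weighting indicates better covariate balance.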

References

Arbour, D., Dimmery, D., & Sondhi, A. (2021). Permutation Weighting. In Proceedings of the 38th International Conference on Machine Learning, PMLR 139:331-341.

License

Apache-2.0 License - see LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Citation

If you use this package, please cite the original paper:

@InProceedings{arbour21permutation,
  title = {Permutation Weighting},
  author = {Arbour, David and Dimmery, Drew and Sondhi, Arjun},
  booktitle = {Proceedings of the 38th International Conference on Machine Learning},
  pages = {331--341},
  year = {2021},
  editor = {Meila, Marina and Zhang, Tong},
  volume = {139},
  series = {Proceedings of Machine Learning Research},
  month = {18--24 Jul},
  publisher = {PMLR},
  pdf = {http://proceedings.mlr.press/v139/arbour21a/arbour21a.pdf},
  url = {https://proceedings.mlr.press/v139/arbour21a.html}
}


Download files


Source Distribution

stochpw-0.3.0.tar.gz (24.7 kB)

Built Distribution

stochpw-0.3.0-py3-none-any.whl (31.6 kB)

File details

Details for the file stochpw-0.3.0.tar.gz.

File metadata

  • Download URL: stochpw-0.3.0.tar.gz
  • Size: 24.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.13.5 Darwin/23.2.0

File hashes

Hashes for stochpw-0.3.0.tar.gz
Algorithm Hash digest
SHA256 b64216abfe516d2d9ca9f37037895e91290816487f4196252f679ce8fb677649
MD5 52c2a46d00752e7dcdba372016accfc1
BLAKE2b-256 9a6f656ffc6b8bc14507c4441f89f0db9ab8e17330fee01899ed58a3dd3a01f6


File details

Details for the file stochpw-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: stochpw-0.3.0-py3-none-any.whl
  • Size: 31.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.13.5 Darwin/23.2.0

File hashes

Hashes for stochpw-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 c42b479fde594e60c2b9f64dc1413e5fa50dff012dbd5d7739cd4d0c191573f0
MD5 559a540ddefee1a62bb5b4e9bdec2c87
BLAKE2b-256 d18ce7727d7ec3604ea8bea9c4da746e0ffcf54c5bab1b492e7982e7562c3d0b
