
skcausal

skcausal: A Machine Learning Library for Causal Inference

Overview

skcausal is a Python library of machine learning tools for causal inference. It estimates causal effects with techniques such as propensity score weighting and generalized propensity scores, and supports hyperparameter tuning of the resulting models. Built on top of polars, optuna, and pytorch-lightning, skcausal offers scalable and flexible implementations of causal response estimation methods.

Features

  • Causal Estimation Models: Implements various approaches to causal effect estimation, including direct response modeling, propensity weighting, and GPS-based methods.
  • Hyperparameter Optimization: Integrates optuna for tuning causal models efficiently.
  • Propensity Score Weighting: Supports several weighting techniques, including synthetic classifier-based estimators and neural network-based density ratio estimation.
  • Flexible Treatment Modeling: Supports binary and continuous treatment variables, as well as multi-dimensional treatments.
  • Seamless Integration with Machine Learning Pipelines: Provides an sklearn-like API and supports modern ML techniques for regression and classification.
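To make the propensity weighting idea concrete, here is a minimal pure-Python sketch of normalized inverse propensity weighting (IPW) for a binary treatment. It is an illustration, not skcausal's implementation: the data-generating process, the helper names simulate, naive_difference, and ipw_ate, and the use of the true propensity score are all assumptions. In practice the propensity has to be estimated, which is the job of the weight estimators described later.

```python
import random


def simulate(n, seed=0):
    """Hypothetical confounded data: x raises both the chance of treatment and the outcome."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = rng.random()
        e = 0.2 + 0.6 * x  # true propensity P(T = 1 | x), assumed known here
        t = 1 if rng.random() < e else 0
        y = 2.0 * t + 3.0 * x + rng.gauss(0.0, 1.0)  # true ATE = 2
        rows.append((x, e, t, y))
    return rows


def naive_difference(rows):
    """Difference in mean outcomes between treated and control units (ignores confounding)."""
    treated = [y for _, _, t, y in rows if t == 1]
    control = [y for _, _, t, y in rows if t == 0]
    return sum(treated) / len(treated) - sum(control) / len(control)


def ipw_ate(rows):
    """Normalized (Hajek) IPW: weight treated units by 1/e and controls by 1/(1 - e)."""
    w1 = [t / e for _, e, t, _ in rows]
    w0 = [(1 - t) / (1 - e) for _, e, t, _ in rows]
    mu1 = sum(w * row[3] for w, row in zip(w1, rows)) / sum(w1)
    mu0 = sum(w * row[3] for w, row in zip(w0, rows)) / sum(w0)
    return mu1 - mu0


rows = simulate(20_000)
print("naive:", round(naive_difference(rows), 2))
print("IPW:  ", round(ipw_ate(rows), 2))
```

Units with larger x are both more likely to be treated and have higher outcomes, so the naive difference in means is biased upward (its expectation is 2.6 in this setup), while the weighted estimate should land close to the true effect of 2.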

Installation

To install skcausal, use:

pip install skcausal

Modules

1. Causal Estimators (skcausal.causal_estimators)

This module provides the core classes for estimating the average dose-response function (ADRF) and individual treatment effects (ITE).
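The ADRF is the expected outcome if every unit received dose t, mu(t) = E_X[ E[Y | X, T = t] ]. Under the usual unconfoundedness and overlap assumptions it can be estimated by averaging a fitted outcome model over the observed covariates, and an individual effect is the difference of two predictions for the same x. A minimal sketch of that plug-in idea follows; mu_hat is a hypothetical closed-form stand-in for any fitted outcome model, and adrf and ite are illustrative helpers, not skcausal APIs.

```python
from statistics import fmean


def adrf(outcome_model, covariates, t_grid):
    """Plug-in ADRF: for each dose t, average the model's prediction over all observed covariates."""
    return [fmean(outcome_model(x, t) for x in covariates) for t in t_grid]


def ite(outcome_model, x, t_from, t_to):
    """Individual effect of moving one unit with covariates x from dose t_from to dose t_to."""
    return outcome_model(x, t_to) - outcome_model(x, t_from)


def mu_hat(x, t):
    # Hypothetical fitted outcome model for E[Y | X = x, T = t]
    return t ** 2 + x * t


X = [0.0, 0.5, 1.0, 1.5]
print(adrf(mu_hat, X, [0.0, 1.0, 2.0]))  # [0.0, 1.75, 5.5]
print(ite(mu_hat, 1.0, 0.0, 2.0))        # 6.0
```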

Key Classes:

  • BaseCausalResponseEstimator - Abstract base class for all causal estimators.
  • GPS - Implements the Generalized Propensity Score method.
  • PropensityWeightingDiscrete - Uses Propensity Score Weighting for discrete treatments.
  • PropensityWeightingContinuous - Extends Propensity Score Weighting for continuous treatments.
  • BinaryDoublyRobust - Combines outcome regression and propensity weighting for doubly robust estimation.
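BinaryDoublyRobust is described as combining outcome regression with propensity weighting. The textbook form of that idea is the augmented IPW (AIPW) estimator, sketched here in pure Python as an assumption about the general technique, not as what skcausal computes internally. The simulated data, the helper names, and the deliberately misspecified models are hypothetical; the point is the doubly robust property, where the estimate stays near the true effect of 2 if either the outcome models or the propensity model is correct.

```python
import random
from statistics import fmean


def simulate(n, seed=1):
    """Hypothetical confounded data: x raises both treatment probability and outcome; true ATE = 2."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = rng.random()
        t = 1 if rng.random() < 0.2 + 0.6 * x else 0
        y = 2.0 * t + 3.0 * x + rng.gauss(0.0, 1.0)
        data.append((x, t, y))
    return data


def aipw_ate(data, mu1, mu0, propensity):
    """Augmented IPW: outcome-model contrast plus propensity-weighted residual corrections."""
    terms = []
    for x, t, y in data:
        e, m1, m0 = propensity(x), mu1(x), mu0(x)
        terms.append(m1 - m0 + t * (y - m1) / e - (1 - t) * (y - m0) / (1 - e))
    return fmean(terms)


data = simulate(20_000)

# Correct propensity, deliberately wrong (constant) outcome models
ate_bad_outcome = aipw_ate(data, mu1=lambda x: 5.0, mu0=lambda x: 0.0,
                           propensity=lambda x: 0.2 + 0.6 * x)
# Correct outcome models, deliberately wrong (constant) propensity
ate_bad_propensity = aipw_ate(data, mu1=lambda x: 2.0 + 3.0 * x, mu0=lambda x: 3.0 * x,
                              propensity=lambda x: 0.5)
print(round(ate_bad_outcome, 2), round(ate_bad_propensity, 2))
```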

2. Hyperparameter Tuning (skcausal.tuning)

Provides tools for hyperparameter tuning using optuna.

Key Classes:

  • OptunaCausalResponseEstimator - Wraps a causal estimator with an Optuna-based tuning procedure.

3. Weight Estimators (skcausal.weight_estimators)

Contains methods for estimating balancing weights used in causal inference.

Key Classes:

  • BaseBalancingWeightRegressor - Base class for balancing weight estimation.
  • BinaryClassifierWeightRegressor - Learns propensity scores via binary classification.
  • DiscriminativeWeightRegressor - Creates a synthetic classification problem to estimate inverse probability of treatment weights (IPTW).
  • TreatmentDensityRatioRegressor - Uses a deep learning model to estimate density ratio weights.
  • InterpolateNeuralWeightRegressor - A neural network-based method for weight estimation with linear interpolation.
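DiscriminativeWeightRegressor is described as turning weight estimation into a synthetic classification problem. One common version of that trick is sketched below in pure Python, as an assumption rather than skcausal's actual code: pair each covariate row with its observed treatment (label 1) and with a shuffled treatment (label 0), so the two classes follow p(x, t) and p(x) p(t). With balanced classes the classifier's odds c / (1 - c) estimate p(x, t) / (p(x) p(t)), so the stabilized weight p(t) / p(t | x) is the inverse odds. The logistic model with quadratic features, the Newton solver, and every helper name are illustrative choices; weighting should shrink the correlation between the confounder x and the continuous treatment t toward zero.

```python
import math
import random


def features(x, t):
    # Quadratic terms: for the Gaussian example below, the true log density ratio is linear in these
    return [1.0, x * t, x * x, t * t]


def solve(a, b):
    """Solve a @ z = b by Gaussian elimination with partial pivoting (small dense systems)."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    z = [0.0] * n
    for r in range(n - 1, -1, -1):
        z[r] = (m[r][n] - sum(m[r][c] * z[c] for c in range(r + 1, n))) / m[r][r]
    return z


def fit_logistic(rows, labels, iters=10):
    """Logistic regression fitted by Newton-Raphson (IRLS)."""
    k = len(rows[0])
    beta = [0.0] * k
    for _ in range(iters):
        grad = [0.0] * k
        hess = [[0.0] * k for _ in range(k)]
        for f, label in zip(rows, labels):
            p = 1.0 / (1.0 + math.exp(-sum(b * v for b, v in zip(beta, f))))
            for i in range(k):
                grad[i] += (label - p) * f[i]
                for j in range(k):
                    hess[i][j] += p * (1.0 - p) * f[i] * f[j]
        step = solve(hess, grad)
        beta = [b + s for b, s in zip(beta, step)]
    return beta


def density_ratio_weights(x, t, seed=0):
    rng = random.Random(seed)
    t_shuffled = t[:]
    rng.shuffle(t_shuffled)  # shuffling breaks the x-t link: these pairs follow p(x) p(t)
    rows = [features(xi, ti) for xi, ti in zip(x, t)] + [features(xi, ti) for xi, ti in zip(x, t_shuffled)]
    labels = [1] * len(x) + [0] * len(x)
    beta = fit_logistic(rows, labels)
    # Weight = inverse odds = exp(-logit) for each observed (x, t) pair
    return [math.exp(-sum(b * v for b, v in zip(beta, features(xi, ti)))) for xi, ti in zip(x, t)]


def weighted_corr(x, t, w):
    sw = sum(w)
    mx = sum(wi * xi for wi, xi in zip(w, x)) / sw
    mt = sum(wi * ti for wi, ti in zip(w, t)) / sw
    cov = sum(wi * (xi - mx) * (ti - mt) for wi, xi, ti in zip(w, x, t))
    vx = sum(wi * (xi - mx) ** 2 for wi, xi in zip(w, x))
    vt = sum(wi * (ti - mt) ** 2 for wi, ti in zip(w, t))
    return cov / math.sqrt(vx * vt)


rng = random.Random(42)
x = [rng.gauss(0.0, 1.0) for _ in range(3000)]
t = [0.5 * xi + rng.gauss(0.0, 1.0) for xi in x]  # continuous treatment confounded by x
w = density_ratio_weights(x, t)
print("unweighted corr:", round(weighted_corr(x, t, [1.0] * len(x)), 3))
print("weighted corr:  ", round(weighted_corr(x, t, w), 3))
```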

4. Polars Utility Functions (skcausal.polars)

Helper functions for data preprocessing, including:

  • convert_categorical_to_dummies() - Converts categorical features to dummy variables.
  • to_dummies() - One-hot encodes categorical features.
  • assert_schema_equal() - Ensures consistency between data schemas.
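Dummy encoding and schema checks go together: if categories are re-derived from new data at prediction time, a category absent from that data silently drops a column. A small pure-Python sketch of the pattern follows; fit_categories, encode_dummies, and schema are hypothetical helpers, not the skcausal.polars functions, and the column names are made up. It learns the categories once, reuses them, and compares schemas before predicting.

```python
def fit_categories(rows, column):
    """Record the sorted set of categories seen in the training data."""
    return sorted({row[column] for row in rows})


def encode_dummies(rows, column, categories):
    """Replace `column` with one 0/1 indicator per known category."""
    encoded = []
    for row in rows:
        new_row = {key: value for key, value in row.items() if key != column}
        for category in categories:
            new_row[f"{column}_{category}"] = int(row[column] == category)
        encoded.append(new_row)
    return encoded


def schema(rows):
    """Column names in order, a stand-in for a full name-to-dtype schema."""
    return list(rows[0])


train = [{"age": 30, "city": "Lyon"}, {"age": 41, "city": "Oslo"}, {"age": 25, "city": "Rome"}]
test = [{"age": 52, "city": "Oslo"}]

categories = fit_categories(train, "city")
train_enc = encode_dummies(train, "city", categories)
refit_enc = encode_dummies(test, "city", fit_categories(test, "city"))  # categories re-derived from test
reuse_enc = encode_dummies(test, "city", categories)                    # training categories reused

print(schema(train_enc) == schema(refit_enc))  # False
print(schema(train_enc) == schema(reuse_enc))  # True
```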

Example Usage

Causal Inference with Generalized Propensity Score (GPS)

import numpy as np
import polars as pl
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

from skcausal.causal_estimators import GPS
from skcausal.weight_estimators import BinaryClassifierWeightRegressor

# Generate synthetic data: 5 covariates, a randomized binary treatment,
# and an outcome whose true treatment effect is 2
rng = np.random.default_rng(0)
n_samples = 1000
X = rng.random((n_samples, 5))
t = rng.choice([0, 1], size=n_samples)
y = 2 * t + rng.standard_normal(n_samples)

X_df = pl.DataFrame(X, schema=[f"x{i}" for i in range(X.shape[1])])
t_df = pl.DataFrame({"treatment": t})
y_df = pl.DataFrame({"outcome": y})

# Define the weight estimator (learns propensity scores with a classifier)
treatment_regressor = BinaryClassifierWeightRegressor(RandomForestClassifier())

# Define the GPS estimator; the outcome is continuous, so it needs a regressor
gps_estimator = GPS(treatment_regressor, outcome_regressor=RandomForestRegressor())
gps_estimator.fit(X_df, y_df, t_df)

# Predict the average treatment effect (the true value in this simulation is 2)
ate = gps_estimator.predict_average_treatment_effect(X_df, t_df)
print("Estimated ATE:", ate)

Hyperparameter Tuning with Optuna

from optuna.distributions import FloatDistribution, IntDistribution

from skcausal.causal_estimators.direct_dynamicnet import DirectDynamicNet
from skcausal.tuning import OptunaCausalResponseEstimator

# EMSE and SyntheticBidimensionalDataset live in the causal_experiment package,
# outside the skcausal namespace
from causal_experiment.datasets.synthetic_wang import SyntheticBidimensionalDataset
from causal_experiment.evaluation.metrics.dose_response import EMSE

# Define model, evaluation metric, and dataset
model = DirectDynamicNet(n_epochs=10)
metric = EMSE()
dataset = SyntheticBidimensionalDataset().prepare(n=1000)

# Search space: log-uniform learning rate and integer batch size
param_grid = {
    "learning_rate": FloatDistribution(1e-5, 1e-1, log=True),
    "batch_size": IntDistribution(32, 512),
}

# Run hyperparameter tuning
optuna_estimator = OptunaCausalResponseEstimator(model, metric, param_grid, dataset, n_evals=10)
optuna_estimator.tune()
print("Best Parameters:", optuna_estimator.best_params_)

Contributing

We welcome contributions to skcausal! If you want to contribute, please follow these steps:

  1. Fork the repository.
  2. Create a new branch for your feature or bugfix.
  3. Write tests for your code.
  4. Submit a pull request.

License

skcausal is licensed under the MIT License. See the LICENSE file for details.



Download files


Source Distribution

skcausal-0.0.1.tar.gz (34.4 kB)

Uploaded Source

Built Distribution


skcausal-0.0.1-py3-none-any.whl (48.1 kB)

Uploaded Python 3

File details

Details for the file skcausal-0.0.1.tar.gz.

File metadata

  • Download URL: skcausal-0.0.1.tar.gz
  • Upload date:
  • Size: 34.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.5 CPython/3.11.11 Darwin/24.6.0

File hashes

Hashes for skcausal-0.0.1.tar.gz
Algorithm Hash digest
SHA256 a73d2691173d63b6a5be403bcba6b8642d48b3ad7a007d7282663c0ce26927b2
MD5 6dfbc865e343fd843e76c3f21db11ff7
BLAKE2b-256 666fd5d9ff033a2797a048a538c03abe18b107faa218192d217d07dc3988ae29


File details

Details for the file skcausal-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: skcausal-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 48.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.5 CPython/3.11.11 Darwin/24.6.0

File hashes

Hashes for skcausal-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 cd35ca5a9c21121a344560738979bc514f189bee4f6bad4a766f754fb7daf63e
MD5 fe0e588edc053404b7c776f96d4e6539
BLAKE2b-256 62534480c6f800e685818959c1eec128d859a4a09f689629c2179405440b08ab

