Causal Inference Matching Package.
Introduction
PyMALTS is an interpretable, learning-to-match method for causal inference. It implements the MALTS algorithm proposed by Harsh Parikh, Cynthia Rudin, and Alexander Volfovsky in their 2019 paper "MALTS: Matching After Learning to Stretch".
Dependencies
PyMALTS is a Python 3 library that requires NumPy, pandas, SciPy, scikit-learn, Matplotlib, and seaborn.
import pymalts
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
np.random.seed(0)
sns.set()
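Before running the example, it can help to confirm that the listed dependencies are importable. The following is a minimal sketch using only the standard library; the import names (for example `sklearn` for scikit-learn) are the usual ones for these packages, and the `required` and `missing` variables are introduced here for illustration.

```python
import importlib.util

# Import names of the dependencies listed above
# (scikit-learn is imported as "sklearn").
required = ["numpy", "pandas", "scipy", "sklearn", "matplotlib", "seaborn"]

# find_spec returns None when a top-level module cannot be located.
missing = [name for name in required if importlib.util.find_spec(name) is None]

if missing:
    print("Missing dependencies:", ", ".join(missing))
else:
    print("All dependencies found.")
```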
Reading Data
df = pd.read_csv('example/example_data.csv',index_col=0)
print(df.shape)
df.head()
(2500, 20)
| | X1 | X2 | X3 | X4 | X5 | X6 | X7 | X8 | X9 | X10 | X11 | X12 | X13 | X14 | X15 | X16 | X17 | X18 | outcome | treated |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1355 | 1.881335 | 1.684164 | 0.532332 | 2.002254 | 1.435032 | 1.450196 | 1.974763 | 1.321659 | 0.709443 | -1.141244 | 0.883130 | 0.956721 | 2.498229 | 2.251677 | 0.375271 | -0.545129 | 3.334220 | 0.081259 | -15.679894 | 0 |
| 1320 | 0.666476 | 1.263065 | 0.657558 | 0.498780 | 1.096135 | 1.002569 | 0.881916 | 0.740392 | 2.780857 | -0.765889 | 1.230980 | -1.214324 | -0.040029 | 1.554477 | 4.235513 | 3.596213 | 0.959022 | 0.513409 | -7.068587 | 0 |
| 1233 | -0.193200 | 0.961823 | 1.652723 | 1.117316 | 0.590318 | 0.566765 | 0.775715 | 0.938379 | -2.055124 | 1.942873 | -0.606074 | 3.329552 | -1.822938 | 3.240945 | 2.106121 | 0.857190 | 0.577264 | -2.370578 | -5.133200 | 0 |
| 706 | 1.378660 | 1.794625 | 0.701158 | 1.815518 | 1.129920 | 1.188477 | 0.845063 | 1.217270 | 5.847379 | 0.566517 | -0.045607 | 0.736230 | 0.941677 | 0.835420 | -0.560388 | 0.427255 | 2.239003 | -0.632832 | 39.684984 | 1 |
| 438 | 0.434297 | 0.296656 | 0.545785 | 0.110366 | 0.151758 | -0.257326 | 0.601965 | 0.499884 | -0.973684 | -0.552586 | -0.778477 | 0.936956 | 0.831105 | 2.060040 | 3.153799 | 0.027665 | 0.376857 | -1.221457 | -2.954324 | 0 |
Using MALTS
Distance Metric Learning
Set up the model that learns the distance metric.
- Name of the outcome column: 'outcome'
- Name of the treatment column: 'treated'
- The data is held in the pandas DataFrame df
m = pymalts.malts_mf(outcome='outcome', treatment='treated', data=df) # running MALTS with the default settings
Matched Groups
The matched group matrix (MG_matrix) is an N×N matrix in which each row corresponds to a query unit and each column to a candidate matched unit. Cell (i,j) holds the weight of unit j in the matched group of unit i. The weight is the number of times unit j is included in the matched group of unit i across the M folds.
m.MG_matrix
| | 1355 | 1320 | 1233 | 706 | 438 | 184 | 1108 | 1612 | 816 | 131 | ... | 1181 | 1698 | 916 | 59 | 2267 | 1520 | 1408 | 909 | 603 | 2285 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1355 | 4.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 3.0 |
| 1320 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 4.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1233 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 706 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 |
| 438 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1520 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1408 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 |
| 909 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 |
| 603 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 |
| 2285 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 |
2500 rows × 2500 columns
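To make the weight definition concrete, here is a toy sketch (not PyMALTS internals) that builds a small matched group matrix by counting how often each unit appears in a query unit's matched group across folds. The unit labels and the `fold_matches` list are made up for illustration.

```python
import pandas as pd

units = ["a", "b", "c", "d"]

# Hypothetical matched groups: one dict per fold, query unit -> matched units.
fold_matches = [
    {"a": ["a", "b"], "b": ["b", "c"], "c": ["c", "d"], "d": ["d", "a"]},
    {"a": ["a", "b"], "b": ["b", "a"], "c": ["c", "b"], "d": ["d", "c"]},
]

# Start from an all-zero N x N matrix (rows: query units, columns: matched units).
mg = pd.DataFrame(0.0, index=units, columns=units)

# Add 1 to cell (i, j) each time unit j is in the matched group of unit i.
for matches in fold_matches:
    for query, group in matches.items():
        for matched in group:
            mg.loc[query, matched] += 1

print(mg)
```

In this toy matrix, `mg.loc["a"]` plays the role of one row of `m.MG_matrix`: unit "b" has weight 2 because it was matched to "a" in both folds.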
Visualizing the matched group matrix as a heatmap
fig = plt.figure(figsize=(10,10))
sns.heatmap(m.MG_matrix)
Accessing the matched group of an example unit with index "1" and visualizing the weights as a bar chart
MG1 = m.MG_matrix.loc[1] #matched group for unit "1"
MG1[MG1>1].sort_values(ascending=False).plot(kind='bar',figsize=(20,5)) #Visualizing all the units matched to unit 1 more than once
ATE and CATE Estimates
m.CATE_df #each row is the CATE estimate for the corresponding unit
| | avg.CATE | std.CATE | outcome | treated |
|---|---|---|---|---|
| 0 | 47.232061 | 21.808950 | -15.313091 | 0.0 |
| 1 | 40.600643 | 21.958906 | -16.963202 | 0.0 |
| 2 | 40.877320 | 22.204570 | 9.527929 | 1.0 |
| 3 | 37.768578 | 19.740320 | -3.940218 | 0.0 |
| 4 | 39.920257 | 21.744433 | -8.011915 | 0.0 |
| ... | ... | ... | ... | ... |
| 2495 | 49.227788 | 21.581176 | -14.529871 | 0.0 |
| 2496 | 42.352355 | 21.385861 | 19.570055 | 1.0 |
| 2497 | 43.737763 | 19.859275 | -16.342666 | 0.0 |
| 2498 | 41.189297 | 20.346711 | -9.165242 | 0.0 |
| 2499 | 45.427037 | 23.762884 | -17.604829 | 0.0 |
2500 rows × 4 columns
Estimate Average Treatment Effect (ATE)
ATE = m.CATE_df['avg.CATE'].mean()
ATE
42.29673993471417
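The ATE above is the plain average of the unit-level CATE estimates. As a sketch on a toy frame (the first five `avg.CATE` values from the table above, not the full data), the same averaging can be restricted to treated units. Reading that subgroup mean as an estimate of the average treatment effect on the treated (ATT) is an assumption of this example, not something the PyMALTS output reports.

```python
import pandas as pd

# Toy stand-in for m.CATE_df: the first five rows of the table above.
cate_df = pd.DataFrame(
    {
        "avg.CATE": [47.232061, 40.600643, 40.877320, 37.768578, 39.920257],
        "treated": [0.0, 0.0, 1.0, 0.0, 0.0],
    }
)

# ATE: mean of the unit-level CATE estimates.
ate = cate_df["avg.CATE"].mean()

# Mean CATE among treated units only (an ATT-style summary).
att = cate_df.loc[cate_df["treated"] == 1, "avg.CATE"].mean()

print(ate, att)
```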
Visualizing the ATE and the estimated probability density of the CATE (using a KDE plot)
fig = plt.figure(figsize=(10,5))
sns.kdeplot(m.CATE_df['avg.CATE'],fill=True) #fill=True replaces the deprecated shade=True (seaborn 0.11 and later)
plt.axvline(ATE,c='black')
plt.text(ATE-4,0.04,r'$\hat{ATE}$',rotation=90) #raw string so that \h is not read as an escape sequence
Text(38.29673993471417, 0.04, '$\\hat{ATE}$')
MALTS Arguments
- outcome: Name of the outcome variable column in the data
- treatment: Name of the treatment variable column in the data
- data: Data as a pandas DataFrame
- discrete: List of column names which are discrete (dummified); Default=[]
- C: Regularization constant; Default=1
- k_tr: Size of matched group in training step; Default=15
- k_est: Size of matched group in estimation step; Default=50
- estimator: CATE estimator inside a matched group; Default='linear'; Options: 'linear','mean' or 'RF'
- smooth_cate: Boolean; whether to smooth the CATE estimates by fitting a regression model; Default=True
- reweight: Reweight the treated and control groups according to their respective sample sizes in the training step; Default=False
- n_splits: Number of splits of the data for the n_splits-fold procedure; Default=5
- n_repeats: Number of repeats of the whole procedure; Default=1
- output_format: Output format of CATE dataframe; Default='brief'; Options: 'brief' or 'full'
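As a sketch of how these arguments fit together, the snippet below collects a few non-default settings and passes them to `malts_mf`. The keyword names are taken from the list above; the specific values are arbitrary illustrations, and `fit_malts` is a hypothetical wrapper, not part of PyMALTS.

```python
# Keyword arguments named in the list above; values are illustrative only.
malts_kwargs = {
    "discrete": [],           # assumed: no dummified columns in the data
    "C": 5,                   # regularization constant
    "k_tr": 15,               # matched-group size in the training step
    "k_est": 50,              # matched-group size in the estimation step
    "estimator": "mean",      # one of 'linear', 'mean' or 'RF'
    "smooth_cate": True,
    "reweight": False,
    "n_splits": 5,
    "n_repeats": 2,           # repeat the whole procedure twice
    "output_format": "full",  # 'brief' or 'full'
}


def fit_malts(df, outcome="outcome", treatment="treated", **kwargs):
    """Hypothetical helper: run MALTS on df with the given settings."""
    import pymalts  # imported lazily so the settings can be inspected without it

    return pymalts.malts_mf(outcome=outcome, treatment=treatment, data=df, **kwargs)


# Usage (requires pymalts and a DataFrame df such as the example data):
# m = fit_malts(df, **malts_kwargs)
```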
Visualization
Looking Inside a Matched-Group
Plotting the (X1, X2) marginal of the matched group of unit "0"
MG0 = m.MG_matrix.loc[0] #fetching the matched group
matched_units_idx = MG0[MG0!=0].index #getting the indices of the matched units
matched_units = df.loc[matched_units_idx] #fetching the data of matched units
sns.lmplot(x='X1', y='X2', hue='treated', data=matched_units,palette="Set1") #plotting the MG on (X1,X2)
plt.scatter(x=[df.loc[0,'X1']],y=[df.loc[0,'X2']],c='black',s=100) #plotting the unit-0 on (X1,X2)
plt.title('Matched Group for Unit-0') #setting title of the plot
Text(0.5, 1, 'Matched Group for Unit-0')
Plotting CATE versus covariate
Plotting CATE vs. X1
data_w_cate = df.join(m.CATE_df, rsuffix='_').drop(columns=['outcome_','treated_']) #joining cate dataframe with data
sns.regplot(x='X1', y='avg.CATE', data=data_w_cate, scatter_kws={'alpha':0.5,'s':2}, line_kws={'color':'black'}, order=2) #fitting a degree-2 polynomial of avg.CATE on X1
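For readers who want the fitted curve itself rather than a plot, a degree-2 polynomial can be fitted directly with NumPy. This is a sketch on synthetic data (the coefficients and noise level below are invented), standing in for the `X1` and `avg.CATE` columns of `data_w_cate`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins for data_w_cate['X1'] and data_w_cate['avg.CATE'].
x1 = rng.normal(loc=1.0, scale=1.0, size=500)
cate = 2.0 * x1**2 - 3.0 * x1 + 40.0 + rng.normal(scale=0.1, size=500)

# Least-squares fit of a degree-2 polynomial; coefficients are returned
# highest degree first: [a, b, c] for a*x^2 + b*x + c.
coeffs = np.polyfit(x1, cate, deg=2)
fitted = np.poly1d(coeffs)

print(coeffs)
print(fitted(1.0))  # fitted CATE at X1 = 1
```

With the real data, `x1` and `cate` would be replaced by `data_w_cate['X1']` and `data_w_cate['avg.CATE']`.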