rswjax
A JAX implementation of optimal representative sample weighting, drawing heavily on the original rsw implementation of Barratt, Angeris, and Boyd.
Thanks to rewriting some core operations in JAX, it is significantly faster than rsw for medium-to-large datasets, especially those with many columns (e.g., 5k+ rows, 20+ columns). For more on performance, see the section below. In addition, it adds a number of quality-of-life improvements: early stopping if the optimization produces NaNs, warnings for common issues like not including loss functions for every column, and a broader test suite.
Installation
$ pip install rswjax
Usage
rswjax has one main user-facing function, rsw, with the signature:
"""Optimal representative sample weighting.
Arguments:
- df: Pandas dataframe
- funs: functions to apply to each row of df. Function warns if len(losses) != ncols.
- losses: list of losses, each one of rswjax.EqualityLoss, rswjax.InequalityLoss, rswjax.LeastSquaresLoss,
or rswjax.KLLoss.
- regularizer: One of rswjax.ZeroRegularizer, rswjax.EntropyRegularizer,
rswjax.KLRegularizer, or rswjax.BooleanRegularizer.
- lam (optional): Regularization hyper-parameter (default=1).
- kwargs (optional): additional arguments to be sent to solver. For example: verbose=True,
maxiter=5000, rho=50, eps_rel=1e-5, eps_abs=1e-5.
Returns:
- w: Final sample weights.
- out: Final induced expected values as a list of numpy arrays.
- sol: Dictionary of final ADMM variables. Can be ignored.
"""
Example usage, fitting a set of weights to simulated data:
import pandas as pd
import numpy as np
import rswjax
np.random.seed(605)
n = 5000
age = np.random.randint(20, 30, size=n) * 1.
sex = np.random.choice([0., 1.], p=[.4, .6], size=n)
height = np.random.normal(5, 1, size=n)
df = pd.DataFrame({
"age": age,
"sex": sex,
"height": height
})
funs = [
lambda x: x.age,
lambda x: x.sex == 0 if not np.isnan(x.sex) else np.nan,
lambda x: x.height
]
losses = [rswjax.EqualityLoss(25), rswjax.EqualityLoss(.5),
rswjax.EqualityLoss(5.3)]
regularizer = rswjax.EntropyRegularizer()
w, out, sol = rswjax.rsw(df, funs, losses, regularizer, 0.01, eps_abs=1e-8, verbose=True)
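Once rsw returns, the induced expected values can also be checked by hand: the weighted mean of each column is just the weight vector dotted with that column. The sketch below uses uniform weights on a tiny toy frame as a stand-in for the fitted w, so it runs without calling rswjax at all.

```python
import numpy as np
import pandas as pd

# Sanity-check sketch: compute the weighted column means induced by a
# weight vector. Uniform weights stand in for the `w` returned by
# rswjax.rsw, so this snippet runs without fitting anything.
df = pd.DataFrame({"age": [24.0, 26.0], "height": [5.1, 5.5]})
w = np.full(len(df), 1.0 / len(df))  # replace with the fitted weights

# Multiply each row by its weight, then sum down the columns.
weighted_means = df.mul(w, axis=0).sum()
print(weighted_means["age"], weighted_means["height"])
```

With the fitted weights in place of the uniform ones, these values should match the targets passed to the equality losses (up to solver tolerance).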
For more details on how one might use the package to do survey weighting, check out my recent talk at NYOSP. The talk uses the original rsw, but all the ideas should transfer over cleanly.
Performance
rswjax is generally faster than rsw for medium-to-large datasets, especially those with many columns. As both packages take negligible amounts of time for datasets of ~3k rows or fewer, rswjax should be superior for many, but not all, applications. Of course, having to also install JAX will not be worth it in many situations.
Here is a simple scaling test in n (# of rows), with structure similar to the simulated example in /examples:
rsw
n=1,000 - 109 ms ± 8.38 ms per loop (mean ± std. dev. of 7 runs)
n=10,000 - 4.29 s ± 706 ms per loop (mean ± std. dev. of 7 runs)
n=100,000 - 46.4 s ± 3.75 s per loop (mean ± std. dev. of 7 runs)
n=1,000,000 - 3min 13s ± 6.78 s per loop (mean ± std. dev. of 7 runs)
rswjax
n=1,000 - 140 ms ± 10.2 ms per loop (mean ± std. dev. of 7 runs)
n=10,000 - 1.24 s ± 123 ms per loop (mean ± std. dev. of 7 runs)
n=100,000 - 2.71 s ± 235 ms per loop (mean ± std. dev. of 7 runs)
n=1,000,000 - 24.6 s ± 579 ms per loop (mean ± std. dev. of 7 runs)
For a rough sense of scaling in the number of columns m, consider these results on a simple test with n = 10,000 rows, and m = 20/50/100/200 columns to weight on:
rsw
m=20 - 28 s ± 7.47 s per loop (mean ± std. dev. of 7 runs)
m=50 - 57.9 s ± 9.15 s per loop (mean ± std. dev. of 7 runs)
m=100 - 1min 48s ± 9.62 s per loop (mean ± std. dev. of 7 runs)
m=200 - 2min 22s ± 8.88 s per loop (mean ± std. dev. of 7 runs)
rswjax
m=20 - 5.64 s ± 357 ms per loop (mean ± std. dev. of 7 runs)
m=50 - 7.59 s ± 692 ms per loop (mean ± std. dev. of 7 runs)
m=100 - 32.3 s ± 915 ms per loop (mean ± std. dev. of 7 runs)
m=200 - 1min 14s ± 6.6 s per loop (mean ± std. dev. of 7 runs)
(Note that this test case uses randomly generated targets and data, and is therefore hard to weight in high dimensions. Most well-specified real-world examples should run significantly faster due to early termination.)
This speed is achieved by doing the core QDLDL factorization and solve with the qdldl package, while using JITed (just-in-time compiled) rewrites of many pieces of the ADMM solver, the losses, and the regularizers. There are still some minor opportunities to speed the package up further, by refactoring the code so that larger portions can be JITed, or by optimizing how and when data is converted back and forth between numpy and jax.numpy.
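To give a flavor of what those JITed pieces look like, here is an illustrative sketch (not rswjax's actual code) of a jitted proximal step for a least-squares loss, the kind of small array operation an ADMM solver applies on every iteration and that JAX compiles once and reuses:

```python
import jax
import jax.numpy as jnp

# Illustrative only: the prox of g(x) = (lam/2)*||x - fdes||^2 with
# ADMM penalty rho, i.e. argmin_x g(x) + (rho/2)*||x - f||^2, which
# has the closed form below. @jax.jit compiles it on first call.
@jax.jit
def least_squares_prox(f, fdes, rho, lam=1.0):
    return (lam * fdes + rho * f) / (lam + rho)

f = jnp.array([1.0, 2.0, 3.0])
fdes = jnp.zeros(3)
print(least_squares_prox(f, fdes, 50.0))  # pulled slightly toward fdes
```

Because the function is pure and operates on plain arrays, jit can trace it once; subsequent calls with same-shaped inputs run the compiled kernel directly.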
Running the examples
There are two examples, one on simulated data and one on the CDC BRFSS dataset. Both are due to the original package authors.
Simulated
To run the simulated example, after installing rswjax, navigate to the examples folder and run:
$ python simulated.py
CDC BRFSS
To run the CDC BRFSS example, first download the data:
$ cd examples/data
$ wget https://www.cdc.gov/brfss/annual_data/2018/files/LLCP2018XPT.zip
$ unzip LLCP2018XPT.zip
Then, in the examples folder, run all the examples from the paper with:
$ python brfss.py
Contributing
Interested in contributing? Check out the contributing guidelines. Please note that this project is released with a Code of Conduct. By contributing to this project, you agree to abide by its terms.
Some possible ideas for contributing would be:
- Making a jittable version of the steps that update f in the solver. This would require a decent amount of refactoring (passing data instead of objects into the f-update logic, a requirement for jit).
- Finding opportunities to convert data back and forth less between numpy and jax.numpy.
- Adding additional losses, regularizers, or examples.
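The first idea above hinges on a JAX constraint worth spelling out: jit traces functions over arrays (pytrees), not arbitrary Python objects, so an update step that receives a loss object must be refactored to receive that object's parameters as plain arrays. A hypothetical sketch (the names here are illustrative, not rswjax's actual internals):

```python
import jax
import jax.numpy as jnp

# Hypothetical object-style loss: holds its target as an attribute.
class EqualityLossObj:
    def __init__(self, fdes):
        self.fdes = fdes

    def prox(self, f, rho):
        # Prox of the indicator of {f == fdes} is just the target.
        return self.fdes

def update_f_object(loss, f, rho):
    # Passing the object in keeps this step un-jittable as written:
    # EqualityLossObj is not a pytree, so jit cannot trace it.
    return loss.prox(f, rho)

@jax.jit
def update_f_data(fdes, f, rho):
    # Passing plain arrays lets the whole step be jitted.
    return fdes  # equality loss: project onto the target

f, fdes = jnp.ones(3), jnp.zeros(3)
print(update_f_object(EqualityLossObj(fdes), f, 1.0))
print(update_f_data(fdes, f, 1.0))  # same result, but compiled
```

Both calls compute the same projection; the refactor is purely about what crosses the jit boundary.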
License
rswjax was created by Andrew Timm. It is licensed under the terms of the Apache License 2.0. See the NOTICE file for attributions due to the original rsw authors, whose code and paper are the primary origin of most logic in my package.
Credits
rswjax was created with cookiecutter and the py-pkgs-cookiecutter template.