Engression Modelling
Engression
Engression is a nonlinear regression methodology proposed in the paper "Engression: Extrapolation for Nonlinear Regression?" by Xinwei Shen and Nicolai Meinshausen. This directory contains the Python implementation of engression.
Consider targets $Y\in\mathbb{R}^k$ and predictors $X\in\mathbb{R}^d$; both variables can be univariate or multivariate. Engression can be used to
- estimate the conditional mean $\mathbb{E}[Y|X=x]$ (as in least-squares regression),
- estimate the conditional quantiles of $Y$ given $X=x$ (as in quantile regression), and
- sample from the fitted conditional distribution of $Y$ given $X=x$ (as a generative model).
The results in the paper show the advantages of engression over existing regression approaches in terms of extrapolation.
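The three uses listed above are related: once one can sample from a conditional distribution of $Y$ given $X=x$, the conditional mean and quantiles can be approximated by summarizing those samples. The sketch below illustrates this idea with NumPy only. The sampler `sample_conditional` is a hypothetical toy stand-in (it draws $Y=(x+\eta)^2$ with standard normal $\eta$); it is not the engression model and not part of the package.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_conditional(x, n_samples, rng):
    """Hypothetical toy sampler standing in for a fitted generative model.

    Draws Y = (x + eta)^2 with eta ~ N(0, 1). This is an illustration only,
    not the engression model.
    """
    eta = rng.standard_normal(n_samples)
    return (x + eta) ** 2


x0 = 1.5
samples = sample_conditional(x0, n_samples=100_000, rng=rng)

# Monte Carlo summaries of the conditional distribution of Y given X = x0.
cond_mean = samples.mean()  # analytic value for this toy model: x0**2 + 1 = 3.25
cond_quantiles = np.quantile(samples, [0.025, 0.5, 0.975])

print("approximate conditional mean:", cond_mean)
print("approximate 2.5%, 50%, 97.5% quantiles:", cond_quantiles)
```

The same pattern applies to any model that can generate samples of $Y$ for a given $x$: the more samples drawn, the closer these summaries are to the mean and quantiles of the fitted distribution.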
Installation
The latest release of the Python package can be installed through pip:
pip install engression
The development version can be installed from GitHub:
pip install -e "git+https://github.com/xwshen51/engression#egg=engression&subdirectory=engression-python"
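After either install command, one way to confirm that the package is visible to the current Python environment is to query its installed version with the standard library. This is a minimal sketch; the helper name `installed_version` is ours and is not part of engression.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package_name):
    """Return the installed version string of a package, or None if absent."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return None


# Prints a version string if the install succeeded, otherwise None.
print("engression version:", installed_version("engression"))
```

If this prints `None`, the package was most likely installed into a different environment than the one running the script.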
Usage Example
Python
Below is a simple demonstration. See this tutorial for more details on simulated data and this tutorial for a real-data example. Another tutorial demonstrates how to fit a bagged engression model, which also helps with hyperparameter tuning.
import torch
from engression import engression
from engression.data.simulator import preanm_simulator
## Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
## Simulate data: train on x in [0, 2], evaluate on the wider range [0, 4]
x, y = preanm_simulator("square", n=10000, x_lower=0, x_upper=2, noise_std=1, train=True, device=device)
x_eval, y_eval_med, y_eval_mean = preanm_simulator("square", n=1000, x_lower=0, x_upper=4, noise_std=1, train=False, device=device)
## Fit an engression model
engressor = engression(x, y, lr=0.01, num_epoches=500, batch_size=1000, device=device)
## Summarize model information
engressor.summary()
## Evaluation
print("L2 loss:", engressor.eval_loss(x_eval, y_eval_mean, loss_type="l2"))
print("correlation between predicted and true means:", engressor.eval_loss(x_eval, y_eval_mean, loss_type="cor"))
## Predictions
y_pred_mean = engressor.predict(x_eval, target="mean") ## for the conditional mean
y_pred_med = engressor.predict(x_eval, target="median") ## for the conditional median
y_pred_quant = engressor.predict(x_eval, target=[0.025, 0.5, 0.975]) ## for the conditional 2.5%, 50%, and 97.5% quantiles
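The 2.5% and 97.5% quantile predictions together form a 95% prediction interval, and a common check is the fraction of observed responses that fall inside it. The sketch below shows that calculation on plain NumPy arrays. The helper names and the toy data are ours, not part of engression. To apply it to `y_pred_quant`, one would first convert the predictions to NumPy and select the lower and upper quantile slices; the array layout returned by `predict` is not stated here, so inspect its shape before indexing.

```python
import numpy as np


def interval_coverage(y, lower, upper):
    """Fraction of observations y that lie inside [lower, upper]."""
    y, lower, upper = (np.asarray(a, dtype=float).ravel() for a in (y, lower, upper))
    inside = (y >= lower) & (y <= upper)
    return inside.mean()


def mean_interval_width(lower, upper):
    """Average width of the prediction intervals."""
    lower = np.asarray(lower, dtype=float).ravel()
    upper = np.asarray(upper, dtype=float).ravel()
    return (upper - lower).mean()


# Toy stand-in data (an assumption for illustration): standard normal
# responses with the exact 95% interval of N(0, 1) as the "prediction".
rng = np.random.default_rng(0)
y_obs = rng.standard_normal(5000)
lower = np.full_like(y_obs, -1.96)
upper = np.full_like(y_obs, 1.96)

coverage = interval_coverage(y_obs, lower, upper)
width = mean_interval_width(lower, upper)
print("empirical coverage:", coverage)
print("mean interval width:", width)
```

Coverage close to the nominal 95% on held-out data suggests the quantiles are well calibrated; comparing coverage inside and outside the training range of `x` is one way to probe extrapolation behaviour.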
Contact information
If you encounter any problems with the code, please submit an issue or contact Xinwei Shen.