
Engression Modelling


Engression

Engression is a nonlinear regression methodology proposed in the paper "Engression: Extrapolation for Nonlinear Regression?" by Xinwei Shen and Nicolai Meinshausen. This package contains the Python implementation of engression.
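As a rough illustration based on our reading of the paper, engression fits a generative model $g(x, \varepsilon)$ that maps predictors $x$ and random noise $\varepsilon$ to samples of the target. It does so by minimizing the energy score $\mathbb{E}\big[\|Y - g(X,\varepsilon)\| - \tfrac{1}{2}\|g(X,\varepsilon) - g(X,\varepsilon')\|\big]$, where $\varepsilon$ and $\varepsilon'$ are independent noise draws. The NumPy function below is a minimal sketch of the empirical loss for one batch with two noise draws per observation and the plain Euclidean norm. It is not the package's implementation, and the name `energy_loss` is hypothetical.

```python
import numpy as np


def energy_loss(y, y_hat1, y_hat2):
    """Empirical energy-score loss for one batch (hypothetical helper, not the package API).

    y      : array of shape (n, k), observed targets
    y_hat1 : array of shape (n, k), model samples g(x_i, eps_i)
    y_hat2 : array of shape (n, k), model samples g(x_i, eps_i') with independent noise
    """
    # Accuracy term: average distance between each observation and its two model samples
    term1 = 0.5 * (np.linalg.norm(y - y_hat1, axis=1) + np.linalg.norm(y - y_hat2, axis=1))
    # Spread term: distance between the two independent model samples
    term2 = np.linalg.norm(y_hat1 - y_hat2, axis=1)
    return float(np.mean(term1 - 0.5 * term2))


# A model whose samples always equal the targets has zero loss
y = np.array([[1.0], [2.0], [3.0]])
print(energy_loss(y, y, y))  # 0.0
```

The spread term rewards samples that differ from each other, so the minimizer is a full conditional distribution rather than a single point prediction.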

Consider targets $Y\in\mathbb{R}^k$ and predictors $X\in\mathbb{R}^d$; both variables can be univariate or multivariate. Engression can be used to

  • estimate the conditional mean $\mathbb{E}[Y|X=x]$ (as in least-squares regression),
  • estimate the conditional quantiles of $Y$ given $X=x$ (as in quantile regression), and
  • sample from the fitted conditional distribution of $Y$ given $X=x$ (as a generative model).
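Because a fitted engression model is a sampler of $Y$ given $X=x$, all three uses listed above can be read off from Monte Carlo draws at a fixed $x$. The sketch below shows the idea with a hypothetical toy generator `toy_generator(x, eps) = x**2 + eps` standing in for a fitted model. The package's own `predict` method may do this differently (for example, in how many draws it uses).

```python
import numpy as np

rng = np.random.default_rng(0)


def toy_generator(x, eps):
    # Hypothetical stand-in for a fitted model g(x, eps): Y = x^2 + Gaussian noise
    return x ** 2 + eps


def summarize(x, n_draws=100_000, quantiles=(0.025, 0.5, 0.975)):
    # Draw noise and push it through the generator to get samples of Y given X = x
    eps = rng.standard_normal(n_draws)
    samples = toy_generator(x, eps)
    return {
        "mean": samples.mean(),                        # conditional mean estimate
        "quantiles": np.quantile(samples, quantiles),  # conditional quantile estimates
        "samples": samples[:5],                        # a few draws, as from a generative model
    }


out = summarize(2.0)
print(out["mean"])  # should be close to 4 = 2**2
```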

The results in the paper show the advantages of engression over existing regression approaches in terms of extrapolation.

Installation

The latest release of the Python package can be installed through pip:

pip install engression

The development version can be installed from GitHub:

pip install -e "git+https://github.com/xwshen51/engression#egg=engression&subdirectory=engression-python" 
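To confirm which version ended up in your environment after either install command, the standard-library `importlib.metadata` module can be queried. This is a small convenience sketch; the helper name `installed_version` is our own.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name="engression"):
    # Return the installed version string of a distribution, or None if it is absent
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


print(installed_version())  # a version string, or None if engression is not installed
```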

Usage Example

Python

Below is a simple demonstration. Separate tutorials give more details on simulated data, walk through a real-data example, and show how to fit a bagged engression model, which also helps with hyperparameter tuning.

import torch
from engression import engression
from engression.data.simulator import preanm_simulator

## Use a GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

## Simulate data: training inputs lie in [0, 2], evaluation inputs extend to [0, 4]
x, y = preanm_simulator("square", n=10000, x_lower=0, x_upper=2, noise_std=1, train=True, device=device)
x_eval, y_eval_med, y_eval_mean = preanm_simulator("square", n=1000, x_lower=0, x_upper=4, noise_std=1, train=False, device=device)

## Fit an engression model
engressor = engression(x, y, lr=0.01, num_epoches=500, batch_size=1000, device=device)
## Summarize model information
engressor.summary()

## Evaluation
print("L2 loss:", engressor.eval_loss(x_eval, y_eval_mean, loss_type="l2"))
print("correlation between predicted and true means:", engressor.eval_loss(x_eval, y_eval_mean, loss_type="cor"))

## Predictions
y_pred_mean = engressor.predict(x_eval, target="mean") ## for the conditional mean
y_pred_med = engressor.predict(x_eval, target="median") ## for the conditional median
y_pred_quant = engressor.predict(x_eval, target=[0.025, 0.5, 0.975]) ## for the conditional 2.5%, 50% and 97.5% quantiles
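For readers who want to see what the two evaluation metrics measure, here is a plain-NumPy analogue: the mean squared error between predicted and true conditional means, and their Pearson correlation. This is an assumption about what `eval_loss` computes for `loss_type="l2"` and `loss_type="cor"`. The package's exact definitions (for instance, whether the L2 loss is square-rooted) may differ, and the names `l2_loss` and `correlation` are hypothetical.

```python
import numpy as np


def l2_loss(pred, truth):
    # Mean squared difference between predicted and true values
    pred, truth = np.ravel(pred), np.ravel(truth)
    return float(np.mean((pred - truth) ** 2))


def correlation(pred, truth):
    # Pearson correlation between predicted and true values
    pred, truth = np.ravel(pred), np.ravel(truth)
    return float(np.corrcoef(pred, truth)[0, 1])


truth = np.array([0.0, 1.0, 4.0, 9.0])
pred = np.array([0.0, 1.0, 4.0, 8.0])
print(l2_loss(pred, truth))  # 0.25
```

Assuming `predict` returns a PyTorch tensor, the example's outputs would need converting first, e.g. `l2_loss(y_pred_mean.cpu().numpy(), y_eval_mean.cpu().numpy())`.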

Contact information

If you encounter any problems with the code, please submit an issue or contact Xinwei Shen.



Download files


Source Distribution

engression-0.1.9.tar.gz (16.2 kB)

Built Distribution

engression-0.1.9-py3-none-any.whl (19.6 kB)

File details

Details for the file engression-0.1.9.tar.gz.

File metadata

  • Download URL: engression-0.1.9.tar.gz
  • Upload date:
  • Size: 16.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.2

File hashes

Hashes for engression-0.1.9.tar.gz
  • SHA256: d4d07a4f6712655a4fb8b74010f9b59324bba5b7901bd6196a69dc81547930e9
  • MD5: 3d750096c9989e4c4dd520988114b1bf
  • BLAKE2b-256: 3448de16777d0618c76537945122d077239999f65e066ea7faa8866e4ea34117
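To check a downloaded file against the SHA256 digest listed above, one option is Python's standard `hashlib` module. This is a minimal sketch: the function names are our own, and `path` should point to wherever the file was actually downloaded.

```python
import hashlib
import os


def sha256_of(path, chunk_size=1 << 16):
    # Hash the file in fixed-size chunks so large files need not fit in memory
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_hex):
    # hexdigest() is lowercase, so normalize the expected digest before comparing
    return sha256_of(path) == expected_hex.strip().lower()


# SHA256 digest listed above for engression-0.1.9.tar.gz
EXPECTED = "d4d07a4f6712655a4fb8b74010f9b59324bba5b7901bd6196a69dc81547930e9"
path = "engression-0.1.9.tar.gz"  # adjust to the actual download location
if os.path.exists(path):
    print("hash matches:", verify(path, EXPECTED))
```

pip can also enforce this check automatically in its hash-checking mode: a requirements-file entry such as `engression==0.1.9 --hash=sha256:<digest>` makes pip refuse a download whose digest does not match.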


File details

Details for the file engression-0.1.9-py3-none-any.whl.

File metadata

  • Download URL: engression-0.1.9-py3-none-any.whl
  • Upload date:
  • Size: 19.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.2

File hashes

Hashes for engression-0.1.9-py3-none-any.whl
  • SHA256: a6a7ae4da3a4baff20a8d77bac4bcf9c465857b7672163e1d77f747e6ea21d5c
  • MD5: d4f8923d9a388848e47468bb73725e00
  • BLAKE2b-256: 35416c1ce56e37b5665a3b7639c11056b29a62232bb2c1b22b1d6d56bd9ae4d3

