
Engression Modelling

Project description

Engression

Engression is a neural network-based distributional regression method proposed in the paper "Engression: Extrapolation through the Lens of Distributional Regression" by Xinwei Shen and Nicolai Meinshausen (2023). This repository contains implementations of engression in both R and Python.
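Following the paper, engression fits a conditional generative model $g(x, \varepsilon)$, with $\varepsilon$ a random noise input, by minimizing the energy score

$$\mathbb{E}\|Y - g(X,\varepsilon)\| - \tfrac{1}{2}\,\mathbb{E}\|g(X,\varepsilon) - g(X,\varepsilon')\|,$$

where $\varepsilon'$ is an independent copy of $\varepsilon$. The first term pulls model samples towards the observed targets and the second rewards spread between samples, so the objective is minimized in expectation by the true conditional distribution rather than only by its mean. Below is a minimal NumPy sketch of a Monte Carlo estimate of this objective from $m \ge 2$ samples per observation. It is not the package's code; the function name `empirical_energy_loss` and the toy data model $Y = X^2 + \mathcal{N}(0,1)$ are assumptions made for illustration.

```python
import numpy as np


def empirical_energy_loss(y, y_samples):
    """Monte Carlo estimate of the energy-score loss (illustrative sketch).

    y:         observed targets, shape (n, k)
    y_samples: m >= 2 model samples per observation, shape (n, m, k)
    """
    y = np.asarray(y, dtype=float)
    s = np.asarray(y_samples, dtype=float)
    _, m, _ = s.shape
    if m < 2:
        raise ValueError("need at least two samples per observation")
    # Term 1: average distance between each observed y_i and its m model samples
    term1 = np.linalg.norm(s - y[:, None, :], axis=-1).mean(axis=1)
    # Term 2: average distance between distinct pairs of samples at the same x_i
    # (the diagonal j == l contributes zero, so dividing by m * (m - 1) averages over j != l)
    pairwise = np.linalg.norm(s[:, :, None, :] - s[:, None, :, :], axis=-1)
    term2 = pairwise.sum(axis=(1, 2)) / (m * (m - 1))
    return float(np.mean(term1 - 0.5 * term2))


# Assumed toy data model (not from the package): Y = X^2 + N(0, 1)
rng = np.random.default_rng(0)
n, m = 2000, 10
x = rng.uniform(0, 2, size=n)
y = (x**2 + rng.normal(size=n))[:, None]
# Samples from the true conditional distribution of Y given X = x
good = (x[:, None] ** 2 + rng.normal(size=(n, m)))[:, :, None]
# A degenerate "model" that always returns the conditional mean
point = np.repeat((x**2)[:, None, None], m, axis=1)

loss_good = empirical_energy_loss(y, good)
loss_point = empirical_energy_loss(y, point)
print(loss_good, loss_point)
```

Both printed values should be close to their population counterparts, $1/\sqrt{\pi} \approx 0.56$ for the correct sampler and $\sqrt{2/\pi} \approx 0.80$ for the point prediction: because the energy score is a proper scoring rule, a model that ignores the noise scores worse than one that reproduces it.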

Consider targets $Y\in\mathbb{R}^k$ and predictors $X\in\mathbb{R}^d$; both variables can be univariate or multivariate, continuous or discrete. Engression can be used to

  • estimate the conditional mean $\mathbb{E}[Y|X=x]$ (as in least-squares regression),
  • estimate the conditional quantiles of $Y$ given $X=x$ (as in quantile regression), and
  • sample from the fitted conditional distribution of $Y$ given $X=x$ (as a generative model).
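Since a fitted model of this kind is a sampler for the conditional distribution of $Y$ given $X=x$, all three uses can be read off one set of samples: the conditional mean as a sample average and the conditional quantiles as empirical quantiles. The NumPy sketch below illustrates this idea; `toy_sampler` is a hypothetical stand-in for a fitted model (it draws from the assumed toy model $Y = X^2 + \mathcal{N}(0,1)$), and whether the package computes its predictions exactly this way is an assumption.

```python
import numpy as np


def toy_sampler(x, num_samples, rng):
    """Hypothetical stand-in for a fitted conditional generator.

    Returns an array of shape (len(x), num_samples) with draws of Y | X = x
    from the toy model Y = X^2 + N(0, 1).
    """
    x = np.asarray(x, dtype=float)
    return x[:, None] ** 2 + rng.normal(size=(x.shape[0], num_samples))


def summarize(samples, quantiles=(0.025, 0.5, 0.975)):
    """Monte Carlo estimates of the conditional mean and quantiles.

    samples: array of shape (n, num_samples), one row per value of x.
    Returns (mean with shape (n,), quantiles with shape (n, len(quantiles))).
    """
    mean = samples.mean(axis=1)
    quants = np.quantile(samples, quantiles, axis=1).T
    return mean, quants


rng = np.random.default_rng(0)
x_eval = np.array([0.0, 1.0, 2.0])
samples = toy_sampler(x_eval, num_samples=20_000, rng=rng)  # generative-model use
mean, quants = summarize(samples)                           # mean and quantile regression uses
print(mean)    # approximately [0, 1, 4]
print(quants)  # each row approximately [mu - 1.96, mu, mu + 1.96]
```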

The results in the paper show the advantages of engression over existing regression approaches in terms of extrapolation.

Installation

The latest release of the Python package can be installed through pip:

pip install engression

The development version can be installed from GitHub:

pip install -e "git+https://github.com/xwshen51/engression#egg=engression&subdirectory=engression-python" 

Usage Example

Python

Below is a simple demonstration. Separate tutorials cover simulated data in more detail, give a real-data example, and show how to fit a bagged engression model, which also helps with hyperparameter tuning.

import torch
from engression import engression
from engression.data.simulator import preanm_simulator

## Use a GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

## Simulate data: training covariates in [0, 2], evaluation covariates in [0, 4] (beyond the training range)
x, y = preanm_simulator("square", n=10000, x_lower=0, x_upper=2, noise_std=1, train=True, device=device)
x_eval, y_eval_med, y_eval_mean = preanm_simulator("square", n=1000, x_lower=0, x_upper=4, noise_std=1, train=False, device=device)

## Fit an engression model
engressor = engression(x, y, lr=0.01, num_epochs=500, batch_size=1000, device=device)
## Summarize model information
engressor.summary()

## Evaluation
print("L2 loss:", engressor.eval_loss(x_eval, y_eval_mean, loss_type="l2"))
print("correlation between predicted and true means:", engressor.eval_loss(x_eval, y_eval_mean, loss_type="cor"))

## Predictions
y_pred_mean = engressor.predict(x_eval, target="mean") ## for the conditional mean
y_pred_med = engressor.predict(x_eval, target="median") ## for the conditional median
y_pred_quant = engressor.predict(x_eval, target=[0.025, 0.5, 0.975]) ## for the conditional 2.5%, 50% and 97.5% quantiles
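The two evaluation calls above report an L2 loss and a correlation between predicted and true conditional means. For reference, here is a minimal NumPy sketch of one common definition of each (mean squared error and Pearson correlation). The helper names `l2_loss` and `correlation` are hypothetical, and the package's `eval_loss` may normalize or aggregate differently, for example over multivariate targets.

```python
import numpy as np


def l2_loss(pred, truth):
    """Mean squared error between predictions and targets (one common L2 definition)."""
    pred = np.ravel(np.asarray(pred, dtype=float))
    truth = np.ravel(np.asarray(truth, dtype=float))
    return float(np.mean((pred - truth) ** 2))


def correlation(pred, truth):
    """Pearson correlation between predictions and targets."""
    pred = np.ravel(np.asarray(pred, dtype=float))
    truth = np.ravel(np.asarray(truth, dtype=float))
    return float(np.corrcoef(pred, truth)[0, 1])


# Small made-up arrays, only to show the calls
truth = np.array([0.0, 1.0, 4.0, 9.0])
pred = np.array([0.1, 0.9, 4.2, 8.5])
print(l2_loss(pred, truth), correlation(pred, truth))
```

If the model outputs are PyTorch tensors, convert them with `.detach().cpu().numpy()` before passing them to these helpers.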

Contact information

If you run into any problems with the code, please submit an issue or contact Xinwei Shen.

Citation

If you refer to or build on our work, please cite the following paper:

@article{10.1093/jrsssb/qkae108,
    author = {Shen, Xinwei and Meinshausen, Nicolai},
    title = {Engression: extrapolation through the lens of distributional regression},
    journal = {Journal of the Royal Statistical Society Series B: Statistical Methodology},
    pages = {qkae108},
    year = {2024},
    month = {11},
    issn = {1369-7412},
    doi = {10.1093/jrsssb/qkae108},
    url = {https://doi.org/10.1093/jrsssb/qkae108},
    eprint = {https://academic.oup.com/jrsssb/advance-article-pdf/doi/10.1093/jrsssb/qkae108/60827977/qkae108.pdf},
}


Download files


Source Distribution

engression-0.1.14.tar.gz (16.1 kB)

Uploaded Source

Built Distribution


engression-0.1.14-py3-none-any.whl (17.0 kB)

Uploaded Python 3

File details

Details for the file engression-0.1.14.tar.gz.

File metadata

  • Download URL: engression-0.1.14.tar.gz
  • Size: 16.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.2

File hashes

Hashes for engression-0.1.14.tar.gz
Algorithm Hash digest
SHA256 48fdcb3cb858f17da8e8162e73337a341f0ead5836c74e9547054074b3d47363
MD5 c510121bb41e060af832157cb9589cae
BLAKE2b-256 54508d7774dd402fa31d2b0e6e82584b47644a3b11f3f4734477a690435457cf


File details

Details for the file engression-0.1.14-py3-none-any.whl.

File metadata

  • Download URL: engression-0.1.14-py3-none-any.whl
  • Size: 17.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.2

File hashes

Hashes for engression-0.1.14-py3-none-any.whl
Algorithm Hash digest
SHA256 58609a0a4b91e1594247eacccb656c5847729fdf600acd34fc7c0bc8e369f15e
MD5 05590306b3ec5f0056264039fe1422bd
BLAKE2b-256 371d3b1098abed5edc74565ce0f21d9357e6506db2d072d4eac8129abf3b367c
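The SHA256 digests listed above can be used to check a downloaded file before installing it. Below is a minimal sketch using Python's standard `hashlib`; the helper names `sha256_of` and `verify` are hypothetical, and it assumes the file keeps its original name so the expected digest can be looked up.

```python
import hashlib
from pathlib import Path

# SHA256 digests listed above for the engression 0.1.14 release files
EXPECTED_SHA256 = {
    "engression-0.1.14.tar.gz": "48fdcb3cb858f17da8e8162e73337a341f0ead5836c74e9547054074b3d47363",
    "engression-0.1.14-py3-none-any.whl": "58609a0a4b91e1594247eacccb656c5847729fdf600acd34fc7c0bc8e369f15e",
}


def sha256_of(path, chunk_size=1 << 20):
    """Compute a file's SHA256 hex digest, reading it in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path):
    """Return True if the file's SHA256 matches the listed digest for its file name."""
    name = Path(path).name
    if name not in EXPECTED_SHA256:
        raise KeyError(f"no listed digest for {name}")
    return sha256_of(path) == EXPECTED_SHA256[name]
```

For example, `verify("engression-0.1.14-py3-none-any.whl")` returns `True` for an intact copy of the wheel in the current directory. pip also has a hash-checking mode for requirements files, which performs the same comparison at install time.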

