The engression loss (energy score) proposed by Shen and Meinshausen for distributional regression, with a few convenient wrappers, in PyTorch.

Reason this release was yanked: Learning

Engression - Pytorch

The paper's original code by Xinwei Shen is available here.

Install

pip install engression-pytorch

Usage

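import torch
from torch import nn
from engression_pytorch import EnergyScoreLoss, gConcat

# Example sizes; replace with your own.
batch_size, input_dim, out_dim, noise_dim = 32, 8, 1, 100

x = torch.randn(batch_size, input_dim)
y = torch.randn(batch_size, out_dim)

# Any torch.nn.Module can be wrapped. Assumption: gConcat concatenates the
# noise to x, so the model takes input_dim + noise_dim features.
model = nn.Sequential(
    nn.Linear(input_dim + noise_dim, 64),
    nn.ReLU(),
    nn.Linear(64, out_dim),
)

g = gConcat(
    model = model,
    noise_dim = noise_dim,
    noise_type = 'normal',
    noise_scale = 1.0,
    m_train = 2,
    m_predict = 512,
)

g.train() # sets m to m_train
preds = g(x) # (batch_size, m_train, out_dim)

# loss = energy_score(y, preds, beta = 1.0, p = 2)
loss = EnergyScoreLoss(beta = 1.0, p = 2)(y, preds)
loss.backward()

g.eval() # sets m to m_predict
sample = g(x) # (batch_size, m_predict, out_dim)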
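
For readers who want to see what the two pieces in the usage above plausibly compute, here is a minimal sketch in plain PyTorch. It is not the package's source code. `NoiseConcat` and `energy_score_sketch` are hypothetical names, and the sketch assumes that the wrapper concatenates Gaussian noise to each input and that the loss is the usual sample-based energy score: the mean distance from the samples to the target, minus half the mean distance between distinct pairs of samples.

```python
import torch
from torch import nn


def energy_score_sketch(y, preds, beta=1.0, p=2):
    """Sample-based energy score, averaged over the batch (illustrative only).

    y:     (batch, out_dim) observed targets
    preds: (batch, m, out_dim) m samples per input, with m >= 2
    """
    m = preds.shape[1]
    # Accuracy term: mean p-norm distance (to the power beta) from samples to target.
    fit = torch.linalg.vector_norm(preds - y.unsqueeze(1), ord=p, dim=-1).pow(beta).mean(dim=1)
    # Spread term: mean distance over distinct sample pairs (j < k).
    j, k = torch.triu_indices(m, m, offset=1)
    spread = torch.linalg.vector_norm(preds[:, j] - preds[:, k], ord=p, dim=-1).pow(beta).mean(dim=1)
    return (fit - 0.5 * spread).mean()


class NoiseConcat(nn.Module):
    """Hypothetical stand-in for a gConcat-style wrapper (an assumption, not the real class)."""

    def __init__(self, model, noise_dim, m_train=2, m_predict=512):
        super().__init__()
        self.model = model
        self.noise_dim = noise_dim
        self.m_train = m_train
        self.m_predict = m_predict

    def forward(self, x):
        # Draw m_train samples in training mode and m_predict samples in eval mode.
        m = self.m_train if self.training else self.m_predict
        # Repeat each input m times and pair every copy with fresh Gaussian noise.
        x_rep = x.unsqueeze(1).expand(-1, m, -1)  # (batch, m, input_dim)
        eps = torch.randn(x.shape[0], m, self.noise_dim, device=x.device, dtype=x.dtype)
        return self.model(torch.cat([x_rep, eps], dim=-1))  # (batch, m, out_dim)


torch.manual_seed(0)
input_dim, out_dim, noise_dim = 8, 1, 100
net = nn.Sequential(nn.Linear(input_dim + noise_dim, 64), nn.ReLU(), nn.Linear(64, out_dim))
g_sketch = NoiseConcat(net, noise_dim=noise_dim)

x_demo, y_demo = torch.randn(32, input_dim), torch.randn(32, out_dim)

g_sketch.train()
train_preds = g_sketch(x_demo)  # (32, 2, 1)
demo_loss = energy_score_sketch(y_demo, train_preds)
demo_loss.backward()

g_sketch.eval()
with torch.no_grad():
    demo_samples = g_sketch(x_demo)  # (32, 512, 1)
```

The spread term is what rewards the model for producing diverse samples rather than collapsing to a single point prediction, which is why at least two samples per input are needed during training.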

Todos

  • better name for gConcat
  • test with multidim Xs?

Citations

@misc{shen2024engressionextrapolationlensdistributional,
      title={Engression: Extrapolation through the Lens of Distributional Regression},
      author={Xinwei Shen and Nicolai Meinshausen},
      year={2024},
      eprint={2307.00835},
      archivePrefix={arXiv},
      primaryClass={stat.ME},
      url={https://arxiv.org/abs/2307.00835},
}
@article{KNEIB202399,
      title={Rage Against the Mean – A Review of Distributional Regression Approaches},
      author={Thomas Kneib and Alexander Silbersdorff and Benjamin Säfken},
      journal={Econometrics and Statistics},
      volume={26},
      pages={99-123},
      year={2023},
      issn={2452-3062},
      doi={10.1016/j.ecosta.2021.07.006},
      url={https://www.sciencedirect.com/science/article/pii/S2452306221000824},
}
