Engression - Pytorch

The engression loss (energy score) proposed by Shen et al. for distributional regression, with a few convenient wrappers, in PyTorch.

The paper's original code by Xinwei Shen is available here.
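
For orientation, the energy score behind the loss can be written as below. This is the standard textbook form, with X and X' independent draws from the predictive distribution P and x_1, ..., x_m the m samples produced for one input. How this package maps its `beta` and `p` arguments onto the exponent and the norm is an assumption here, not something the README states.

```latex
% Population energy score (lower is better), for 0 < \beta < 2:
\mathrm{ES}_\beta(P, y) = \mathbb{E}\,\lVert X - y \rVert^{\beta}
  - \tfrac{1}{2}\,\mathbb{E}\,\lVert X - X' \rVert^{\beta}

% Sample estimate from m >= 2 draws x_1, \dots, x_m:
\widehat{\mathrm{ES}}_\beta(y) = \frac{1}{m} \sum_{j=1}^{m} \lVert x_j - y \rVert^{\beta}
  - \frac{1}{2\,m\,(m-1)} \sum_{j \neq k} \lVert x_j - x_k \rVert^{\beta}
```

The first term rewards samples that land close to the observed target; the second term rewards spread among the samples, which keeps the model from collapsing to a point prediction.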

Install

pip install engression-pytorch

Usage

import torch
from torch import nn  # needed for nn.Linear below
from engression_pytorch import EnergyScoreLoss, gConcat

batch_size, input_dim, out_dim = 32, 1, 1
noise_dim = 100

x = torch.randn(batch_size, input_dim)
y = torch.randn(batch_size, out_dim)

model = nn.Linear(input_dim + noise_dim, out_dim)

g = gConcat(
    model = model,
    noise_dim = noise_dim,
    noise_type = 'normal',
    noise_scale = 1.0,
    m_train = 2, 
    m_eval = 512,
)

g.train() # changes m to m_train
preds = g(x) # (batch_size, m_train, out_dim)

# loss = energy_score(y, preds, beta = 1.0, p = 2)
loss = EnergyScoreLoss(beta = 1.0, p = 2)(y, preds)
loss.backward()

g.eval() # changes m to m_eval
sample = g(x) # (batch_size, m_eval, out_dim)
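
To make the shapes above concrete, here is a minimal sketch of a sample-based energy score in plain PyTorch. It is not the package's implementation: `energy_score_sketch` is a hypothetical helper, it assumes the Euclidean norm, and it takes `y` of shape (batch_size, out_dim) and `preds` of shape (batch_size, m, out_dim) with m >= 2, matching the comments in the usage example.

```python
import torch


def energy_score_sketch(y: torch.Tensor, preds: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: sample-based energy score, averaged over the batch.

    y:     (batch_size, out_dim) observed targets
    preds: (batch_size, m, out_dim) samples per input, with m >= 2
    """
    m = preds.shape[1]

    # Fit term: distance from every sample to its observed target.
    fit = torch.linalg.vector_norm(preds - y.unsqueeze(1), dim=-1).pow(beta)  # (batch, m)
    fit = fit.mean(dim=1)                                                     # (batch,)

    # Spread term: distances between all pairs of samples for the same input.
    diffs = preds.unsqueeze(1) - preds.unsqueeze(2)                           # (batch, m, m, out_dim)
    pairwise = torch.linalg.vector_norm(diffs, dim=-1).pow(beta)              # (batch, m, m)
    # The diagonal (a sample paired with itself) is zero, so divide by m * (m - 1).
    spread = pairwise.sum(dim=(1, 2)) / (m * (m - 1))                         # (batch,)

    return (fit - 0.5 * spread).mean()


# Tiny hand-checkable case: target 0, two samples at +1 and -1.
y_demo = torch.zeros(1, 1)
preds_demo = torch.tensor([[[1.0], [-1.0]]])
score_demo = energy_score_sketch(y_demo, preds_demo)
```

In the demo, the fit term is 1 and the mean pairwise distance is 2, so the score is 1 - 0.5 * 2 = 0. If both samples sat at +1 instead, the spread term would vanish and the score would rise to 1, which shows how the loss penalizes collapsed samples.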

Citations

@misc{shen2024engressionextrapolationlensdistributional,
      title={Engression: Extrapolation through the Lens of Distributional Regression}, 
      author={Xinwei Shen and Nicolai Meinshausen},
      year={2024},
      eprint={2307.00835},
      archivePrefix={arXiv},
      primaryClass={stat.ME},
      url={https://arxiv.org/abs/2307.00835}, 
}
@article{KNEIB202399,
title = {Rage Against the Mean – A Review of Distributional Regression Approaches},
journal = {Econometrics and Statistics},
volume = {26},
pages = {99-123},
year = {2023},
issn = {2452-3062},
doi = {https://doi.org/10.1016/j.ecosta.2021.07.006},
url = {https://www.sciencedirect.com/science/article/pii/S2452306221000824},
author = {Thomas Kneib and Alexander Silbersdorff and Benjamin Säfken},
}
