Probabilistic Parametric Regression Loss (PROPEL)
PRObabilistic Parametric rEgression Loss (PROPEL) is a loss function that enables probabilistic regression with neural networks. It does this by training a network to predict the parameters of a mixture of Gaussian distributions.
Further details about the loss can be found in the paper: PROPEL: Probabilistic Parametric Regression Loss for Convolutional Neural Networks
This repository provides the official PyTorch implementation of PROPEL.
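As a minimal sketch, a network's final layer could be shaped to emit the tensor layout used in the usage example: `[num_batch, num_gaussians, 2 * num_dims]`, with the means in the first `num_dims` channels and the variances in the last `num_dims` channels. The `GaussianMixtureHead` class, its `in_features` argument and the use of `softplus` to keep the variances positive are hypothetical choices for illustration, not part of the torchpropel API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GaussianMixtureHead(nn.Module):
    """Hypothetical output head: maps features to [batch, num_gaussians, 2 * num_dims]."""

    def __init__(self, in_features: int, num_gaussians: int, num_dims: int):
        super().__init__()
        self.num_gaussians = num_gaussians
        self.num_dims = num_dims
        # One linear layer predicts a mean and a variance per dimension for every Gaussian
        self.linear = nn.Linear(in_features, num_gaussians * 2 * num_dims)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.linear(x).view(-1, self.num_gaussians, 2 * self.num_dims)
        means = out[:, :, : self.num_dims]
        # Assumption: softplus plus a small floor keeps the predicted variances strictly positive
        variances = F.softplus(out[:, :, self.num_dims :]) + 1e-6
        return torch.cat([means, variances], dim=2)


head = GaussianMixtureHead(in_features=16, num_gaussians=6, num_dims=3)
feat = head(torch.randn(4, 16))
print(feat.shape)  # torch.Size([4, 6, 6])
```

The resulting `feat` has the same layout as the hand-built prediction in the usage example, so it could be passed to `propel_loss(feat, y)` in the same way.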
Installation Instructions
PROPEL can be installed from PyPI using the following command:
pip install torchpropel
or directly from the GitHub repository:
pip install git+https://github.com/masadcv/PROPEL.git
Usage Example
import torch
import numpy as np
from torchpropel import PROPEL
# Our example has a neural network with
# output [num_batch, num_gaussians, 2 * num_dims]:
# means in the first num_dims channels, variances in the last num_dims
num_batch = 4
num_gaussians = 6
num_dims = 3
# setting ground-truth variance sigma_gt=0.2
sigma_gt = 0.2
propel_loss = PROPEL(sigma_gt)
# ground truth targets for loss
y = torch.ones((num_batch, num_dims)) * 0.5
# example prediction; in practice this would be the output of a neural network
feat_g = np.random.randn(num_batch, num_gaussians, 2 * num_dims) * 0.5
# set the predicted variances (last num_dims channels) to 0.2
feat_g[:, :, num_dims:] = 0.2
feat = torch.tensor(feat_g, dtype=y.dtype)
# compute the loss
L = propel_loss(feat, y)
print(L)
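To inspect what such a prediction represents, the sketch below builds the corresponding Gaussian mixture with `torch.distributions` and evaluates its log-density at the targets. It assumes equal mixture weights across the `num_gaussians` components, independent output dimensions, and that the last `num_dims` channels are variances (so the standard deviation is their square root). These are illustrative assumptions, not a description of how `PROPEL` computes its loss, and `mixture_from_feat` is a hypothetical helper.

```python
import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal


def mixture_from_feat(feat: torch.Tensor, num_dims: int) -> MixtureSameFamily:
    """Hypothetical helper: equal-weight mixture of diagonal Gaussians from [B, G, 2*D] output."""
    means = feat[:, :, :num_dims]
    variances = feat[:, :, num_dims:]
    num_batch, num_gaussians = feat.shape[0], feat.shape[1]
    # Assumption: every Gaussian has the same weight (zero logits give uniform probabilities)
    weights = Categorical(logits=torch.zeros(num_batch, num_gaussians))
    # Independent(..., 1) treats the num_dims dimensions as one diagonal multivariate Gaussian
    components = Independent(Normal(means, variances.sqrt()), 1)
    return MixtureSameFamily(weights, components)


num_batch, num_gaussians, num_dims = 4, 6, 3
means = torch.randn(num_batch, num_gaussians, num_dims) * 0.5
variances = torch.full((num_batch, num_gaussians, num_dims), 0.2)
feat = torch.cat([means, variances], dim=2)
y = torch.ones((num_batch, num_dims)) * 0.5

mixture = mixture_from_feat(feat, num_dims)
print(mixture.log_prob(y))  # one log-density per sample, shape [num_batch]
print(mixture.mean)         # mixture mean, shape [num_batch, num_dims]
```

With equal weights, the mixture mean is simply the average of the predicted means over the Gaussians, which gives a single point estimate when one is needed.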
Documentation
Documentation for each function implemented in PROPEL is hosted at: https://masadcv.github.io/PROPEL/index.html.
Citing PROPEL
A pre-print of PROPEL is available: PROPEL: Probabilistic Parametric Regression Loss for Convolutional Neural Networks
If you use PROPEL in your research, please cite:
BibTeX:
@inproceedings{asad2020propel,
title={PROPEL: Probabilistic Parametric Regression Loss for Convolutional Neural Networks},
author={Asad, Muhammad and Basaru, Rilwan and Arif, SM and Slabaugh, Greg},
booktitle={25th International Conference on Pattern Recognition},
pages={},
year={2020}}
Source Distribution
File details
Details for the file torchpropel-0.0.3.tar.gz
File metadata
- Download URL: torchpropel-0.0.3.tar.gz
- Upload date:
- Size: 6.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.8 tqdm/4.63.0 importlib-metadata/4.11.3 keyring/23.5.0 rfc3986/2.0.0 colorama/0.4.4 CPython/3.9.10
File hashes
Algorithm | Hash digest
---|---
SHA256 | 15942aaf94ace32e6955ea0adbb83ac4d6af897a80d258aaa481ee71fa378e58
MD5 | 84aa93cfc7da37239ebee7127c0d2adb
BLAKE2b-256 | 0e93725fb07173d6e278535384941944e4d073f8ac387783ed29b8586aff308d