Region Mutual Information loss in PyTorch
PyTorch implementation of the loss from the paper "Region Mutual Information Loss for Semantic Segmentation".
Example usage
With logits:
```python
import torch
from rmi import RMILoss

# with_logits=True: the prediction is passed as raw, unnormalized scores
loss = RMILoss(with_logits=True)

batch_size, classes, height, width = 5, 4, 64, 64
pred = torch.rand(batch_size, classes, height, width, requires_grad=True)
# Binary target (0. or 1.) with the same shape as the prediction
target = torch.empty(batch_size, classes, height, width).random_(2)

output = loss(pred, target)
output.backward()
```
With probabilities:
```python
import torch
from torch import nn
from rmi import RMILoss

# with_logits=False: the prediction must already be probabilities in [0, 1]
m = nn.Sigmoid()
loss = RMILoss(with_logits=False)

batch_size, classes, height, width = 5, 4, 64, 64
pred = torch.randn(batch_size, classes, height, width, requires_grad=True)
target = torch.empty(batch_size, classes, height, width).random_(2)

# Apply the sigmoid first to turn the raw scores into probabilities
output = loss(m(pred), target)
output.backward()
```
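Both examples use a binary target with the same (batch, classes, height, width) shape as the prediction, one channel per class. If your labels are instead an integer class map of shape (batch, height, width), one way to build such a target is one-hot encoding. This is a minimal sketch assuming that target layout; `labels_to_target` is a hypothetical helper, not part of the `rmi` package.

```python
import torch
import torch.nn.functional as F


def labels_to_target(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: convert an integer label map of shape (N, H, W)
    into a float one-hot target of shape (N, C, H, W) with 0/1 entries."""
    one_hot = F.one_hot(labels.long(), num_classes=num_classes)  # (N, H, W, C)
    return one_hot.permute(0, 3, 1, 2).float()  # move classes to dim 1


# Labels for a batch of 5 images of 64x64 pixels with 4 classes
batch_size, classes, height, width = 5, 4, 64, 64
labels = torch.randint(0, classes, (batch_size, height, width))
target = labels_to_target(labels, classes)
```

`F.one_hot` requires every label to lie in `[0, num_classes)`, so "ignore" labels such as 255 would need to be handled before conversion.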
Graphs
Plot of the loss value between the prediction and the target, excluding the BCE component. The target is a random binary 256x256 matrix. For "Random", the prediction is a 256x256 matrix of probabilities drawn uniformly at random. For "All zero", the prediction is a 256x256 matrix of zeros. For "1 - target", the prediction is the inverse of the target. In each case the prediction is interpolated toward the target as input_i = (1 - α) * input + α * target, so α = 0 gives the original prediction and α = 1 gives the target itself.
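The setup behind this plot can be sketched as follows. This is a minimal pure-Python illustration under stated assumptions: the helpers `make_target`, `baselines` and `interpolate` are hypothetical, the grid is kept small instead of 256x256, and the loss itself is not computed. With PyTorch tensors the same blend is simply `(1 - alpha) * pred + alpha * target`.

```python
import random


def make_target(size, seed=0):
    """Hypothetical helper: random binary size x size matrix of 0.0/1.0 values."""
    rng = random.Random(seed)
    return [[float(rng.randint(0, 1)) for _ in range(size)] for _ in range(size)]


def baselines(target, seed=1):
    """The three starting predictions described in the plot caption."""
    rng = random.Random(seed)
    size = len(target)
    return {
        "Random": [[rng.random() for _ in range(size)] for _ in range(size)],
        "All zero": [[0.0] * size for _ in range(size)],
        "1 - target": [[1.0 - t for t in row] for row in target],
    }


def interpolate(pred, target, alpha):
    """Element-wise blend: (1 - alpha) * pred + alpha * target."""
    return [
        [(1.0 - alpha) * p + alpha * t for p, t in zip(prow, trow)]
        for prow, trow in zip(pred, target)
    ]


if __name__ == "__main__":
    target = make_target(8)
    for name, pred in baselines(target).items():
        for step in range(11):
            alpha = step / 10  # 0.0, 0.1, ..., 1.0
            pred_i = interpolate(pred, target, alpha)
            # pred_i would be fed to the loss together with `target` here;
            # the loss computation itself is outside this sketch.
        # At alpha = 1 the interpolated prediction equals the target exactly.
        assert interpolate(pred, target, 1.0) == target
```

Note that the "1 - target" baseline passes through a constant 0.5 matrix at α = 0.5, halfway between the inverse and the target.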
Difference between the loss computed by this implementation and by the implementation in the official Git repository, with EPSILON = 0.0005 and pool='max'.
Execution time on tensors with a batch size of 8 and 21 classes.
Tensor size (B x C x H x W) | This implementation (ms) | Official implementation (ms)
---|---|---
8x21x32x32 | 6.5722 | 6.3261
8x21x64x64 | 11.8159 | 12.6169
8x21x128x128 | 39.9946 | 40.3798
8x21x256x256 | 160.0352 | 160.9543
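The page does not say how these timings were measured. A hypothetical harness for producing comparable numbers might look like the sketch below: a few untimed warm-up calls, then the mean of repeated calls timed with `time.perf_counter`. Because CUDA kernels run asynchronously, GPU timings need a synchronization call (for example `torch.cuda.synchronize`) before each clock read; the `sync` parameter is an assumed hook for that.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark_ms(
    fn: Callable[[], object],
    repeats: int = 20,
    warmup: int = 3,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Hypothetical harness: mean wall-clock time of fn() in milliseconds.

    `sync` is called before each clock read so that asynchronous work
    (e.g. queued CUDA kernels) has finished; pass None on CPU.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    for _ in range(warmup):  # warm-up calls are not timed
        fn()
    if sync is not None:
        sync()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued work before stopping the clock
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.mean(timings)
```

For example, with `loss`, `pred` and `target` built as in the usage examples on a CUDA device, `benchmark_ms(lambda: loss(pred, target), sync=torch.cuda.synchronize)` would time the forward pass. Whether the table above includes the backward pass, and on which hardware it was measured, is not stated, so results from this harness are not directly comparable to it.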
Download files
Source distribution: rmi-pytorch-0.1.1.tar.gz
Built distribution: rmi_pytorch-0.1.1-py3-none-any.whl
File details
Details for the file rmi-pytorch-0.1.1.tar.gz
File metadata
- Download URL: rmi-pytorch-0.1.1.tar.gz
- Upload date:
- Size: 4.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.0 setuptools/51.1.2 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.7.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 553b69fbd08f89e0628c940159c39526dff2ea6750fbff38a83e70ce32067717
MD5 | c0282077198a26e84f01578530b5de31
BLAKE2b-256 | a40e0c66ab6f0197b99f5117728207d8e642ab8b89430c2e316366e0a35b673a
File details
Details for the file rmi_pytorch-0.1.1-py3-none-any.whl
File metadata
- Download URL: rmi_pytorch-0.1.1-py3-none-any.whl
- Upload date:
- Size: 5.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.0 setuptools/51.1.2 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.7.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | efe0e41ccdda7828745ebc548e135709db151425a5c07c578f1e5a3a77cf24b1
MD5 | c74b23cf3756e7c39419047645a13b8f
BLAKE2b-256 | 4a3c65abcde92e9cc2cea3f39c6b488ccf63ffdd9fd7701a6cd8d6a37fd0d127
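A downloaded file can be checked against the digests listed above with Python's standard `hashlib` module. This is a minimal sketch; the helper names `file_digest` and `verify` are hypothetical. It assumes PyPI's "BLAKE2b-256" label means BLAKE2b with a 32-byte (256-bit) digest.

```python
import hashlib

# Digest names as used in the hash tables, mapped to hashlib constructors.
# "BLAKE2b-256" is BLAKE2b truncated to a 32-byte (256-bit) digest.
_ALGORITHMS = {
    "SHA256": lambda: hashlib.sha256(),
    "MD5": lambda: hashlib.md5(),
    "BLAKE2b-256": lambda: hashlib.blake2b(digest_size=32),
}


def file_digest(path, algorithm="SHA256", chunk_size=65536):
    """Hex digest of the file at `path`, read in chunks to bound memory use."""
    h = _ALGORITHMS[algorithm]()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected, algorithm="SHA256"):
    """True if the file's digest matches the expected hex string."""
    return file_digest(path, algorithm) == expected.strip().lower()
```

For an untampered download, `verify("rmi_pytorch-0.1.1-py3-none-any.whl", "efe0e41ccdda7828745ebc548e135709db151425a5c07c578f1e5a3a77cf24b1")` should return `True`. pip can also enforce this itself through its hash-checking mode (`--hash=sha256:...` entries in a requirements file).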