Region Mutual Information loss in PyTorch
Project description
A PyTorch implementation of the Region Mutual Information (RMI) loss from the paper "Region Mutual Information Loss for Semantic Segmentation".
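For orientation, here is a rough outline of the formulation described in the RMI paper. It should not be read as a specification of this package. Each pixel is represented together with its neighbours as a d-dimensional point (d = R × R values from a square region, with R = 3 in the paper). Y holds the target values of such a region and P the predicted probabilities. Σ_Y and Σ_P are their covariance matrices, B is the batch size, C is the number of classes, and λ weights the binary cross-entropy term (the "BCE component" mentioned under Graphs). The package's exact normalisation and weighting may differ from this sketch.

$$
\Sigma_{Y|P} \approx \Sigma_Y - \operatorname{Cov}(Y,P)\,\Sigma_P^{-1}\,\operatorname{Cov}(Y,P)^{\top},
\qquad
\mathcal{L}_{\mathrm{rmi}}(Y,P) = \tfrac{1}{2}\log\det\left(\Sigma_{Y|P}\right)
$$

$$
\mathcal{L}_{\mathrm{all}} = \lambda\,\mathcal{L}_{\mathrm{bce}} + (1-\lambda)\,\frac{1}{B}\sum_{b=1}^{B}\sum_{c=1}^{C}\mathcal{L}_{\mathrm{rmi}}\left(Y_b^{c}, P_b^{c}\right)
$$

Minimising log det(Σ_{Y|P}) shrinks the uncertainty left in the target once the prediction is known, which corresponds to maximising a lower bound on the mutual information between Y and P.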
Example usage
With logits:
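```python
import torch
from rmi import RMILoss

# with_logits=True: pred holds raw scores, not probabilities
loss = RMILoss(with_logits=True)

batch_size, classes, height, width = 5, 4, 64, 64
# Raw network outputs (logits), one channel per class
pred = torch.rand(batch_size, classes, height, width, requires_grad=True)
# Random binary target (values 0.0 or 1.0) with the same shape as pred
target = torch.empty(batch_size, classes, height, width).random_(2)

output = loss(pred, target)
output.backward()
```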
With probabilities:
```python
import torch
from torch import nn
from rmi import RMILoss

m = nn.Sigmoid()
# with_logits=False: the loss expects probabilities in [0, 1]
loss = RMILoss(with_logits=False)

batch_size, classes, height, width = 5, 4, 64, 64
pred = torch.randn(batch_size, classes, height, width, requires_grad=True)
target = torch.empty(batch_size, classes, height, width).random_(2)

# Convert the logits to probabilities before computing the loss
output = loss(m(pred), target)
output.backward()
```
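The following sketch illustrates what the RMI term measures. It is deliberately reduced to the degenerate case where every "region" is a single pixel, so the covariance matrices collapse to scalars and the conditional variance becomes var(Y) - cov(Y, P)² / var(P). This is not the package's code. `scalar_rmi_loss` and `EPSILON` are hypothetical names. The use of population covariances (divided by N) and the placement of the small stabilising epsilon (set here to the 0.0005 quoted below for the comparison plot) are assumptions. The BCE component is left out.

```python
import math

# Hypothetical stabiliser; 0.0005 is the EPSILON value quoted for the comparison plot
EPSILON = 0.0005


def _mean(xs):
    return sum(xs) / len(xs)


def _cov(xs, ys):
    """Population covariance (divided by N) of two equally long sequences."""
    mx, my = _mean(xs), _mean(ys)
    return sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / len(xs)


def scalar_rmi_loss(pred, target, eps=EPSILON):
    """Hypothetical radius-1 RMI term: 0.5 * log(var(Y|P)).

    var(Y|P) is approximated by var(Y) - cov(Y, P)^2 / var(P),
    the scalar version of Sigma_Y - Cov(Y,P) Sigma_P^-1 Cov(Y,P)^T.
    """
    var_y = _cov(target, target)
    var_p = _cov(pred, pred) + eps          # keep the division well defined
    cov_yp = _cov(target, pred)
    cond_var = var_y - cov_yp * cov_yp / var_p
    return 0.5 * math.log(cond_var + eps)   # keep the log finite


# Flattened binary target and three simple predictions
target = [0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0]
perfect = scalar_rmi_loss(target, target)
inverse = scalar_rmi_loss([1.0 - t for t in target], target)
zeros = scalar_rmi_loss([0.0] * len(target), target)
```

In this simplified form, a perfect prediction and its inverse get the same, lowest value, because a linear relation (positive or negative) leaves no conditional variance. An all-zero prediction carries no information about the target and gets 0.5 · log(var(Y) + ε).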
Graphs
Plot of the loss value between the prediction and the target, without the BCE component. The target is a random binary 256x256 matrix. For "Random", the prediction is a 256x256 matrix of probabilities initialized uniformly at random. For "All zero", the prediction is a 256x256 matrix of zeros. For "1 - target", the prediction is the inverse of the target. Each prediction is interpolated with the target as input_i = (1 - α) * input + α * target.
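The inputs for these curves can be generated with a few lines of code. The sketch below is a plain-Python illustration on a small flattened grid instead of 256x256, and it is not the script used to produce the plot. `interpolate` and `make_predictions` are hypothetical helpers, and the α grid of 0.0, 0.1, …, 1.0 is an assumption.

```python
import random


def interpolate(pred, target, alpha):
    """Elementwise input_i = (1 - alpha) * input + alpha * target."""
    return [(1.0 - alpha) * p + alpha * t for p, t in zip(pred, target)]


def make_predictions(target, seed=0):
    """The three starting predictions described for the plot (flattened)."""
    rng = random.Random(seed)
    return {
        "Random": [rng.random() for _ in target],       # uniform in [0, 1)
        "All zero": [0.0 for _ in target],
        "1 - target": [1.0 - t for t in target],        # inverse of the target
    }


size = 16  # the plot uses 256x256; a small grid keeps the sketch fast
rng = random.Random(42)
target = [float(rng.randint(0, 1)) for _ in range(size * size)]
alphas = [i / 10 for i in range(11)]  # assumed grid 0.0, 0.1, ..., 1.0

# One interpolated input per alpha, for each starting prediction
curves = {
    name: [interpolate(pred, target, a) for a in alphas]
    for name, pred in make_predictions(target).items()
}
```

At α = 0 each curve starts from its own prediction, and at α = 1 all three coincide with the target. To plot a curve, each interpolated list would be reshaped to an N×C×H×W tensor and passed to the loss, giving one loss value per α.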
Plot of the difference between this implementation and the implementation in the official Git repository, with EPSILON = 0.0005 and pool='max'.
Execution time on tensors with a batch size of 8 and 21 classes.
Size (N x C x H x W) | This implementation (ms) | Official implementation (ms)
---|---|---
8x21x32x32 | 6.5722 | 6.3261
8x21x64x64 | 11.8159 | 12.6169
8x21x128x128 | 39.9946 | 40.3798
8x21x256x256 | 160.0352 | 160.9543
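A quick way to read the timing table is to compare the two implementations row by row and check how the time grows with the image size. The numbers below are copied from the table. The dictionary layout and variable names are only for illustration.

```python
# Timings in ms copied from the table: size -> (this implementation, official)
timings = {
    32: (6.5722, 6.3261),
    64: (11.8159, 12.6169),
    128: (39.9946, 40.3798),
    256: (160.0352, 160.9543),
}

# A ratio below 1 means this implementation was faster at that size
ratios = {size: this / official for size, (this, official) in timings.items()}

# Going from 128x128 to 256x256 quadruples the number of pixels
scale_this = timings[256][0] / timings[128][0]
scale_official = timings[256][1] / timings[128][1]

for size, r in ratios.items():
    print(f"{size}x{size}: this/official = {r:.3f}")
print(f"128 -> 256 time growth: this {scale_this:.2f}x, official {scale_official:.2f}x")
```

On these measurements the two implementations stay within about 7% of each other at every size. At the two largest sizes, four times as many pixels take roughly four times as long for both.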
Built Distribution
Hashes for rmi_pytorch-0.1.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | efe0e41ccdda7828745ebc548e135709db151425a5c07c578f1e5a3a77cf24b1
MD5 | c74b23cf3756e7c39419047645a13b8f
BLAKE2b-256 | 4a3c65abcde92e9cc2cea3f39c6b488ccf63ffdd9fd7701a6cd8d6a37fd0d127