
Region Mutual Information loss in PyTorch

Project description

PyTorch implementation of the Region Mutual Information Loss for Semantic Segmentation.
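Installation

The distribution files listed under "Download files" below indicate that the package is published on PyPI as rmi-pytorch, so it can presumably be installed with:

pip install rmi-pytorch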

Example usage

With logits:

import torch
from rmi import RMILoss

# with_logits=True: pass raw, unnormalized predictions (compare the probabilities example below)
loss = RMILoss(with_logits=True)

batch_size, classes, height, width = 5, 4, 64, 64
pred = torch.rand(batch_size, classes, height, width, requires_grad=True)
target = torch.empty(batch_size, classes, height, width).random_(2)

output = loss(pred, target)
output.backward()

With probabilities:

import torch
from torch import nn
from rmi import RMILoss

m = nn.Sigmoid()
# with_logits=False: the loss expects probabilities, so the sigmoid is applied explicitly
loss = RMILoss(with_logits=False)

batch_size, classes, height, width = 5, 4, 64, 64
pred = torch.randn(batch_size, classes, height, width, requires_grad=True)
target = torch.empty(batch_size, classes, height, width).random_(2)

output = loss(m(pred), target)
output.backward()
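For context, here is a minimal sketch of how the loss could sit inside a training step. The convolutional "model", the optimizer, and the dummy data are placeholders for illustration, not part of this package:

import torch
from torch import nn
from rmi import RMILoss

# Toy segmentation "model": a single conv layer producing one logit map per class.
model = nn.Conv2d(3, 4, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = RMILoss(with_logits=True)

images = torch.randn(5, 3, 64, 64)              # dummy image batch
targets = torch.empty(5, 4, 64, 64).random_(2)  # dummy binary masks, one per class

optimizer.zero_grad()
logits = model(images)                          # shape: 5 x 4 x 64 x 64
loss = loss_fn(logits, targets)
loss.backward()
optimizer.step()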

Graphs

Plot of the loss between the prediction and the target, without the BCE component. The target is a random binary 256x256 matrix. For 'Random', the prediction is a 256x256 matrix of probabilities drawn uniformly at random. For 'All zero', the prediction is a 256x256 matrix of zeros. For '1 - target', the prediction is the inverse of the target. Each prediction is interpolated toward the target as input_α = (1 - α) * input + α * target.

https://raw.githubusercontent.com/RElbers/region-mutual-information-pytorch/main/imgs/loss.png
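A rough sketch of how a curve like this could be reproduced with the API shown above. The original plot excludes the BCE component, which this sketch does not attempt to configure; the shapes and the number of interpolation steps are assumptions:

import torch
from rmi import RMILoss

loss = RMILoss(with_logits=False)

# Random binary 256x256 target (single image, single channel).
target = torch.empty(1, 1, 256, 256).random_(2)

# The three starting predictions from the plot.
predictions = {
    'Random': torch.rand(1, 1, 256, 256),    # probabilities drawn uniformly at random
    'All zero': torch.zeros(1, 1, 256, 256),
    '1 - target': 1 - target,
}

# Interpolate each prediction toward the target and record the loss at each step.
for name, pred in predictions.items():
    for alpha in torch.linspace(0, 1, steps=11):
        a = float(alpha)
        interpolated = (1 - a) * pred + a * target
        value = loss(interpolated, target).item()
        print(f'{name}: alpha={a:.1f} loss={value:.4f}')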

Difference between this implementation and the implementation in the official git repository, with EPSILON = 0.0005 and pool='max'.

https://raw.githubusercontent.com/RElbers/region-mutual-information-pytorch/main/imgs/diff.png

Execution time on tensors with a batch size of 8 and 21 classes.

Size           This         Official
8x21x32x32     6.5722 ms    6.3261 ms
8x21x64x64     11.8159 ms   12.6169 ms
8x21x128x128   39.9946 ms   40.3798 ms
8x21x256x256   160.0352 ms  160.9543 ms
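A minimal sketch of how comparable numbers could be collected. The device, warm-up, and number of repetitions behind the table above are not stated, so the measurement protocol here is an assumption:

import time
import torch
from rmi import RMILoss

loss = RMILoss(with_logits=True)

for size in (32, 64, 128, 256):
    pred = torch.randn(8, 21, size, size)
    target = torch.empty(8, 21, size, size).random_(2)

    loss(pred, target)                      # warm-up run
    start = time.perf_counter()
    runs = 10
    for _ in range(runs):
        loss(pred, target)
    elapsed_ms = (time.perf_counter() - start) / runs * 1000
    print(f'8x21x{size}x{size}: {elapsed_ms:.4f} ms')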


Download files

Download the file for your platform.

Source Distribution

rmi-pytorch-0.1.1.tar.gz (4.6 kB)

Uploaded Source

Built Distribution

rmi_pytorch-0.1.1-py3-none-any.whl (5.0 kB)

Uploaded Python 3

File details

Details for the file rmi-pytorch-0.1.1.tar.gz.

File metadata

  • Download URL: rmi-pytorch-0.1.1.tar.gz
  • Upload date:
  • Size: 4.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.0 setuptools/51.1.2 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.7.2

File hashes

Hashes for rmi-pytorch-0.1.1.tar.gz:

  • SHA256: 553b69fbd08f89e0628c940159c39526dff2ea6750fbff38a83e70ce32067717
  • MD5: c0282077198a26e84f01578530b5de31
  • BLAKE2b-256: a40e0c66ab6f0197b99f5117728207d8e642ab8b89430c2e316366e0a35b673a

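As a minimal example, the SHA256 digest above can be checked against a local download with Python's standard hashlib module (the file path is a placeholder for wherever the archive was saved):

import hashlib

# SHA256 digest published above for the source distribution.
expected = '553b69fbd08f89e0628c940159c39526dff2ea6750fbff38a83e70ce32067717'

with open('rmi-pytorch-0.1.1.tar.gz', 'rb') as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print('OK' if digest == expected else 'hash mismatch')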

File details

Details for the file rmi_pytorch-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: rmi_pytorch-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 5.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.0 setuptools/51.1.2 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.7.2

File hashes

Hashes for rmi_pytorch-0.1.1-py3-none-any.whl:

  • SHA256: efe0e41ccdda7828745ebc548e135709db151425a5c07c578f1e5a3a77cf24b1
  • MD5: c74b23cf3756e7c39419047645a13b8f
  • BLAKE2b-256: 4a3c65abcde92e9cc2cea3f39c6b488ccf63ffdd9fd7701a6cd8d6a37fd0d127

