
ACB-MSE

Automatic-Class-Balanced MSE Loss function for PyTorch (ACB-MSE) to combat class imbalanced datasets.


Introduction

This repository contains the PyTorch implementation of the ACB-MSE loss function, which stands for Automatic Class Balanced Mean Squared Error, originally developed for the DEEPCLEAN3D Denoiser to combat class imbalance and stabilise loss gradient fluctuation due to dramatically varying class frequencies.

Installation

Available on PyPI:

pip install acb_mse

Requirements

  • Python 3.x
  • PyTorch (tested with version 2.0.1)

Usage

Class Parameters

  • zero_weighting (float, optional): Weighting coefficient for MSE loss of zero pixels. Default is 1.
  • nonzero_weighting (float, optional): Weighting coefficient for MSE loss of non-zero pixels. Default is 1.

Inputs

  • Input (torch.Tensor): $( * )$, where $( * )$ means any number of dimensions.
  • Target (torch.Tensor): $( * )$, same shape as the input.

Returns

  • Output (float): Calculated loss value.
Example Code
import torch
from acb_mse import ACBLoss

# Select a weighting for each class if you do not want the default 1:1 weighting
zero_weighting = 1.0
nonzero_weighting = 1.2

# Create an instance of the ACBMSE loss function with specified weighting coefficients
loss_function = ACBLoss(zero_weighting, nonzero_weighting)

# Dummy target image and reconstructed image tensors (assuming B=10, C=3, H=256, W=256)
target_image = torch.rand(10, 3, 256, 256)
reconstructed_image = torch.rand(10, 3, 256, 256)

# Calculate the ACBMSE loss
loss = loss_function(reconstructed_image, target_image)
print("ACB-MSE Loss:", loss)

Methodology and Equations

  1. Two masks are created from the target (label) image:
  • zero_mask: A boolean mask where elements are True for zero-valued pixels in the target image.
  • nonzero_mask: A boolean mask where elements are True for non-zero-valued pixels in the target image.
  2. The pixel values from both the target image and the reconstructed image corresponding to the zero and non-zero masks are extracted.
  3. The mean squared error loss is calculated between the target and the input for each mask.
  4. The two loss values are multiplied by the corresponding weighting coefficients (zero_weighting and nonzero_weighting), allowing the user to adjust the balance from the default 1:1.
  5. The weighted, balanced MSE loss is returned as the final value.
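The steps above can be sketched as a standalone PyTorch module. This is an illustrative re-implementation for clarity, not the packaged `acb_mse.ACBLoss` class itself; the published package's internals may differ.

```python
import torch


class ACBMSELoss(torch.nn.Module):
    """Illustrative sketch of the ACB-MSE computation described above."""

    def __init__(self, zero_weighting: float = 1.0, nonzero_weighting: float = 1.0):
        super().__init__()
        self.zero_weighting = zero_weighting
        self.nonzero_weighting = nonzero_weighting

    def forward(self, reconstructed: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Step 1: build boolean masks from the target (label) image.
        zero_mask = target == 0
        nonzero_mask = ~zero_mask

        # Steps 2-3: extract the pixels under each mask from both tensors
        # and compute the per-class mean squared errors.
        # (Note: if a class is empty, e.g. an all-zero target, its mean is
        # NaN; a production implementation would guard against this.)
        zero_mse = ((reconstructed[zero_mask] - target[zero_mask]) ** 2).mean()
        nonzero_mse = ((reconstructed[nonzero_mask] - target[nonzero_mask]) ** 2).mean()

        # Steps 4-5: apply the user weightings and sum the balanced terms.
        return self.zero_weighting * zero_mse + self.nonzero_weighting * nonzero_mse
```

With both weightings left at 1.0 this reduces to the unweighted class-balanced form.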

The function relies on knowing the indices of all hits and non-hits in the true label image, which are then compared to the values at the corresponding indices in the recovered image. ACB-MSE is therefore unsuitable for unsupervised learning tasks. The ACB-MSE loss function is given by:

$$ \text{Loss} = A\left(\frac{1}{N_h}\sum_{i=1}^{N_h}(y_i - \hat{y}_i)^2\right) + B\left(\frac{1}{N_n}\sum_{i=1}^{N_n}(y_i - \hat{y}_i)^2\right) $$

where $y_i$ is the true value of the $i$-th pixel in the class, $\hat{y}_i$ is the predicted value of the $i$-th pixel in the class, and $N_h$ and $N_n$ are the total numbers of pixels in the 'hits' and 'no hits' classes respectively (the scheme extends naturally to any number of classes). Because the mean square of each class is taken separately, when the per-class errors are summed back together each is automatically scaled by the inverse of its class frequency, normalising the class balance to 1:1. The additional coefficients $A$ and $B$ allow the user to fine-tune this balance manually.
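The inverse-frequency normalisation can be checked numerically: because squared errors are averaged within each class before the terms are summed, a class's contribution is independent of how many pixels it contains. A small sketch (pixel counts and error values are arbitrary choices for illustration):

```python
import torch


def per_class_term(errors: torch.Tensor) -> torch.Tensor:
    # One term of the loss: (1/N) * sum(error_i^2) over a single class.
    return (errors ** 2).mean()


# A "hits" class of 10 pixels, each missed by 0.5 ...
small = torch.full((10,), 0.5)
# ... and the same per-pixel error over a class of 1000 pixels.
large = torch.full((1000,), 0.5)

# Both classes contribute the same term (0.25), so a class's weight in
# the total loss no longer depends on its frequency.
print(per_class_term(small).item(), per_class_term(large).item())  # 0.25 0.25
```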

Benefits

The ACB-MSE loss function was designed for data taken from particle detectors, which often have a majority of 'pixels' that are unlit and a very sparse pattern of lit pixels. In this scenario ACB-MSE provides two main benefits: it addresses the class imbalance between lit and unlit pixels whilst also stabilising the loss gradient during training. The additional coefficients $A$ and $B$ (exposed as the zero_weighting and nonzero_weighting arguments) allow the user to set a custom balance between the classes.

Variable Class Size - Training Stability

Fluctuations in the number of hit pixels across images during training can disrupt loss stability. ACB-MSE remedies this by dynamically adjusting loss function weights to reflect class frequencies in the target.

(Figure: loss value versus number of hit pixels for ACB-MSE, MSE, and MAE at a fixed 50% signal recovery.)

The plot above shows how each loss function (ACB-MSE, MSE, and MAE) behaves as the number of hits in the true signal grows. Two dummy images were created: the first contains a simulated signal, and the second reproduces 50% of that signal, simulating a 50% signal recovery by the network. To generate the plot, the first image was filled in two-pixel increments with the second image following at a constant 50% recovery, and the loss was calculated for the pair at each iteration. The MSE and MAE losses vary as the signal size increases even though the recovery percentage is fixed at 50%, whereas the ACB-MSE loss stays constant regardless of the frequency of the signal class.
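The experiment behind the plot can be reproduced in a few lines. This is a sketch using the per-class averaging described above against a plain-MSE baseline; the image size matches the 11,264-pixel figure quoted later, while the pixel values are arbitrary:

```python
import torch


def acb_mse(pred, target, A=1.0, B=1.0):
    # Per-class MSE, as in the equation above (hits = non-zero pixels).
    hits, zeros = target != 0, target == 0
    return (A * ((pred[hits] - target[hits]) ** 2).mean()
            + B * ((pred[zeros] - target[zeros]) ** 2).mean())


losses_acb, losses_mse = [], []
for n_hits in range(2, 101, 2):
    target = torch.zeros(11264)
    target[:n_hits] = 1.0            # simulated signal of n_hits lit pixels
    pred = torch.zeros(11264)
    pred[:n_hits // 2] = 1.0         # fixed 50% recovery of that signal
    losses_acb.append(acb_mse(pred, target).item())
    losses_mse.append(torch.mean((pred - target) ** 2).item())

# ACB-MSE is flat across signal sizes; plain MSE grows with them.
print(losses_acb[0], losses_acb[-1])   # 0.5 0.5
print(losses_mse[0], losses_mse[-1])
```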

Class Imbalance - Local Minima

Class imbalance is an issue that arises when the interesting features are contained in the minority class. In the case of the DEEPCLEAN3D data, the input images contained 11,264 total pixels, of which only around 200 were hits. Guessing that all pixels are non-hits (zero-valued) therefore yields a very respectable reconstruction loss and is a simple transfer function for the network to learn; this local minimum proved hard for the network to escape. Class balancing based on class frequency is a simple solution that shifts the loss landscape, making it less favorable for the network to guess all pixels as non-hits. This enabled the DEEPCLEAN3D network to escape the local minimum and begin to learn a useful transfer function for the input features.
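The effect can be illustrated with the figures quoted above (11,264 pixels, roughly 200 hits): an all-zero prediction scores well under plain MSE but poorly once the hits class is averaged separately. A sketch, reusing the per-class averaging described earlier:

```python
import torch

target = torch.zeros(11264)
target[:200] = 1.0                  # ~200 hit pixels among 11,264
all_zero_guess = torch.zeros(11264)

# Plain MSE: the all-zero guess looks excellent, since the ~200 missed
# hits are diluted across all 11,264 pixels.
plain = torch.mean((all_zero_guess - target) ** 2)

# ACB-MSE-style per-class averaging: missing every hit costs a full
# unit of loss for the hits class, regardless of how rare hits are.
hits, zeros = target != 0, target == 0
acb = (((all_zero_guess[hits] - target[hits]) ** 2).mean()
       + ((all_zero_guess[zeros] - target[zeros]) ** 2).mean())

print(plain.item())   # ~0.0178 (200 / 11264)
print(acb.item())     # 1.0
```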

License

This project is licensed under the MIT License - see the LICENSE.md file for details.

Contributions

Contributions to this codebase are welcome! If you encounter any issues or have suggestions for improvements please open an issue or a pull request on the GitHub repository.

Contact

For any inquiries, feel free to reach out to me at adillwmaa@gmail.com.

