Signed Distance Function (SDF)-based loss functions for deep learning semantic segmentation that help models miss fewer instances.

SDF Loss

Signed Distance Function (SDF) based loss functions for deep learning semantic segmentation.

Overview

This library provides PyTorch loss functions that use Signed Distance Functions to weight pixels by their distance from object boundaries. False positives and false negatives are penalized more heavily the farther they lie from the correct boundary, so a spurious blob far from any object, or a missed object, costs more than a slight misalignment along an edge. This pushes the model toward segmentations that miss fewer instances.
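To make the weighting idea concrete, here is a minimal NumPy/SciPy sketch: it builds a signed distance map from a binary mask with `scipy.ndimage.distance_transform_edt` and uses its magnitude to scale per-pixel binary cross-entropy. The helper names `signed_distance` and `distance_weighted_bce`, the sign convention (negative inside, positive outside), and the choice of |SDF| as the weight are assumptions for illustration; this is not necessarily the weighting sdf-loss implements.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt


def signed_distance(mask, normalize=False):
    """Signed Euclidean distance of every pixel to the boundary of a binary mask.

    Sign convention (an assumption of this sketch): negative inside the
    object, positive outside it.
    """
    mask = np.asarray(mask, dtype=bool)
    if mask.all() or not mask.any():
        # All-foreground or all-background masks have no boundary.
        return np.zeros(mask.shape, dtype=np.float64)
    outside = distance_transform_edt(~mask)  # background pixel -> nearest foreground
    inside = distance_transform_edt(mask)    # foreground pixel -> nearest background
    sdf = outside - inside
    if normalize:
        sdf = sdf / np.abs(sdf).max()  # rescale into [-1, 1]
    return sdf


def distance_weighted_bce(prob, mask, eps=1e-7):
    """Per-pixel BCE scaled by |SDF| of the ground truth, then averaged."""
    mask = np.asarray(mask, dtype=np.float64)
    prob = np.clip(prob, eps, 1 - eps)
    bce = -(mask * np.log(prob) + (1 - mask) * np.log(1 - prob))
    weight = np.abs(signed_distance(mask, normalize=True))
    return float((weight * bce).mean())


# Ground truth: one 4x4 square in a 16x16 image.
target = np.zeros((16, 16))
target[4:8, 4:8] = 1

# Two predictions, each with exactly one false-positive pixel.
near = target.copy()
near[8, 5] = 1    # touches the object's edge
far = target.copy()
far[15, 15] = 1   # far from the object

loss_near = distance_weighted_bce(np.where(near > 0, 0.99, 0.01), target)
loss_far = distance_weighted_bce(np.where(far > 0, 0.99, 0.01), target)
```

Both predictions contain a single wrong pixel, but `loss_far` is larger because that pixel lies about 11 pixels from the square instead of 1.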

Installation

Install directly from PyPI:

pip install sdf-loss

Or using uv:

uv add sdf-loss

For development:

git clone https://github.com/Halyjo/sdf_loss.git
cd sdf_loss
uv sync

Quick Start

import torch
from sdf_loss import DiSCoLoss

# Initialize the loss function
criterion = DiSCoLoss()

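# Your model predictions (logits) and ground truth. A single conv layer and a
# square mask stand in for a real model and dataset here.
model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)
images = torch.randn(2, 3, 64, 64)
pred_logits = model(images)           # Shape: (B, 1, H, W), raw logits
target = torch.zeros(2, 1, 64, 64)    # Shape: (B, 1, H, W), binary mask
target[:, :, 16:48, 16:48] = 1.0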

# Compute loss
loss = criterion(pred_logits, target)
loss.backward()

Available Loss Functions

DiSCoLoss (Recommended)

Distance-scaled combination loss that combines BCE, Dice, and their SDF-weighted variants.

from sdf_loss import DiSCoLoss

# Defaults shown explicitly: only the SDF-weighted losses contribute
criterion = DiSCoLoss(
    normalize=True,          # Normalize SDF to [-1, 1]
    baseloss_weight=0,       # Weight for BCE + Dice
    sdfweighted_weight=1,    # Weight for SDF-weighted losses
    clip_negatives=False     # Clip negative distances to 0
)

loss = criterion(pred_logits, target)
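Reading the parameter descriptions literally, DiSCoLoss appears to compute a weighted sum of four terms. The formula below is an assumption drawn from the parameter names, not taken from the library source; the implementation may, for instance, average rather than add the terms within each group.

```latex
\mathcal{L}_{\mathrm{DiSCo}}
  = w_{\mathrm{base}} \left( \mathcal{L}_{\mathrm{BCE}} + \mathcal{L}_{\mathrm{Dice}} \right)
  + w_{\mathrm{SDF}} \left( \mathcal{L}_{\mathrm{SDF\text{-}BCE}} + \mathcal{L}_{\mathrm{SDF\text{-}Dice}} \right)
```

Here $w_{\mathrm{base}}$ is `baseloss_weight` and $w_{\mathrm{SDF}}$ is `sdfweighted_weight`, so the defaults (0 and 1) keep only the SDF-weighted terms.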

SDFWeightedBCELoss

Binary cross-entropy loss weighted by SDF differences.

from sdf_loss import SDFWeightedBCELoss

criterion = SDFWeightedBCELoss(
    reduction="mean",        # 'mean', 'sum', or 'none'
    normalize=True,
    clip_negatives=False
)

loss = criterion(pred_logits, target)
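The description above only says that the BCE is "weighted by SDF differences", so the sketch below covers just the weighting and `reduction` steps, with the weight map passed in explicitly. `weighted_bce_with_logits` is a hypothetical helper; how sdf-loss turns SDF differences into weights is not reproduced here.

```python
import torch
import torch.nn.functional as F


def weighted_bce_with_logits(logits, target, weight, reduction="mean"):
    """BCE-with-logits where each pixel's term is scaled by a weight map.

    `weight` is assumed to be precomputed from the ground truth (for example
    from its signed distance function) and is treated as a constant.
    """
    # Element-wise BCE with the same shape as the logits, e.g. (B, 1, H, W).
    per_pixel = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    per_pixel = per_pixel * weight.detach()
    if reduction == "none":
        return per_pixel
    if reduction == "sum":
        return per_pixel.sum()
    if reduction == "mean":
        return per_pixel.mean()
    raise ValueError(f"unknown reduction: {reduction!r}")


torch.manual_seed(0)
logits = torch.randn(2, 1, 8, 8)
target = (torch.rand(2, 1, 8, 8) > 0.5).float()
weight = torch.rand(2, 1, 8, 8)  # stand-in for an SDF-derived weight map
per_pixel = weighted_bce_with_logits(logits, target, weight, reduction="none")
```

With `reduction="none"` the result keeps the input shape, which is what makes the custom weighting shown later in this README possible.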

SDFWeightedDiceLoss

Dice loss weighted by SDF differences.

from sdf_loss import SDFWeightedDiceLoss

criterion = SDFWeightedDiceLoss(
    from_logits=True,
    normalize=True,
    clip_negatives=False
)

loss = criterion(pred_logits, target)

DiceLoss

Standard Dice loss with optional custom weighting function.

from sdf_loss import DiceLoss

criterion = DiceLoss(
    from_logits=True,
    smooth=0.0
)

loss = criterion(pred_logits, target)
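For reference, the standard soft Dice loss is one minus the Dice coefficient computed on probabilities. The sketch below adds an optional per-pixel `weight`, one plausible way to build a weighted variant; the function name and signature are hypothetical, and the library's reduction (per image, per batch, or per class as in SMP) may differ.

```python
from typing import Optional

import torch


def soft_dice_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    from_logits: bool = True,
    smooth: float = 0.0,
    eps: float = 1e-7,
) -> torch.Tensor:
    """1 - Dice coefficient, computed over all pixels in the batch."""
    prob = torch.sigmoid(pred) if from_logits else pred
    if weight is None:
        weight = torch.ones_like(prob)
    intersection = (weight * prob * target).sum()
    cardinality = (weight * (prob + target)).sum()
    # With smooth=0 and an empty target and prediction, eps avoids 0/0;
    # the loss is then 1. A positive `smooth` makes that case score 0 instead.
    dice = (2 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
    return 1 - dice


mask = torch.zeros(1, 1, 8, 8)
mask[..., 2:6, 2:6] = 1.0
loss = soft_dice_loss(torch.randn(1, 1, 8, 8), mask)  # logits in, scalar out
```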

Drop-in Replacement for BCE Loss

Because DiSCoLoss takes the same inputs as torch.nn.BCEWithLogitsLoss (raw logits and a binary target of the same shape), you can swap one for the other:

# Before
criterion = torch.nn.BCEWithLogitsLoss()

# After - simple drop-in replacement
criterion = DiSCoLoss()

# Usage remains the same
loss = criterion(pred_logits, target)
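torch.nn.BCEWithLogitsLoss expects a floating-point target with the same shape as the logits, and the examples in this README use (B, 1, H, W) binary masks. If your dataset yields integer masks of shape (B, H, W), a small conversion keeps both criteria interchangeable. `to_bce_target` is a hypothetical helper, not part of sdf-loss.

```python
import torch


def to_bce_target(mask: torch.Tensor) -> torch.Tensor:
    """Turn a (B, H, W) or (B, 1, H, W) integer/bool mask into a float
    (B, 1, H, W) tensor of 0s and 1s."""
    if mask.dim() == 3:
        mask = mask.unsqueeze(1)  # add the channel dimension
    if mask.dim() != 4 or mask.shape[1] != 1:
        raise ValueError(f"expected (B, H, W) or (B, 1, H, W), got {tuple(mask.shape)}")
    return (mask != 0).float()  # any nonzero label counts as foreground


masks = torch.randint(0, 2, (4, 64, 64))  # e.g. integer labels from a dataset
target = to_bce_target(masks)             # (4, 1, 64, 64), float32
loss = torch.nn.BCEWithLogitsLoss()(torch.randn(4, 1, 64, 64), target)
```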

Key Features

  • Boundary-aware: Weights each pixel by its distance to the nearest object boundary
  • Distance-weighted: Penalizes errors more heavily the farther they lie from the correct boundary
  • PyTorch native: Fully compatible with PyTorch training loops
  • GPU compatible: Works with CUDA tensors
  • Differentiable: Full gradient flow for backpropagation
  • Flexible: Multiple loss functions and customizable parameters
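Assuming the distance weights are computed from the ground truth and treated as constants, the "differentiable" and "distance-weighted" properties fit together simply: the gradient of a weighted BCE term with respect to its logit is the weight times the usual BCE gradient, sigmoid(x) - t. The check below uses a random weight map as a stand-in for an SDF-based one; it is a sketch, not sdf-loss code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 1, 4, 4, requires_grad=True)
target = (torch.rand(1, 1, 4, 4) > 0.5).float()
weight = 3 * torch.rand(1, 1, 4, 4)  # stand-in for a distance-based weight map

# Weighted BCE, averaged over all N pixels; gradients flow only into `logits`.
per_pixel = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
loss = (weight * per_pixel).mean()
loss.backward()

# Analytic gradient: d/dx [w * BCE(x, t)] = w * (sigmoid(x) - t), divided by N for the mean.
expected = weight * (torch.sigmoid(logits.detach()) - target) / logits.numel()
```

So a pixel with twice the distance weight receives twice the gradient for the same error.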

Parameters

Common Parameters

  • normalize (bool): Normalize SDF values to [-1, 1] range. Default: True
  • clip_negatives (bool): Clip negative distance values to 0. Default: False
  • from_logits (bool): Whether input is logits (before sigmoid). Default: True
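One plausible reading of `normalize` and `clip_negatives`, written as a small NumPy sketch: normalizing divides by the largest absolute distance so values land in [-1, 1], and clipping then zeroes the negative side. The function name, the normalize-then-clip order, and the single global scale factor are assumptions; the library may, for example, scale the inside and outside regions separately, and which side is negative depends on its sign convention.

```python
import numpy as np


def apply_sdf_options(sdf, normalize=True, clip_negatives=False):
    """Apply normalize / clip_negatives style options to a signed distance map."""
    sdf = np.asarray(sdf, dtype=np.float64)
    if normalize:
        max_abs = np.abs(sdf).max()
        if max_abs > 0:  # leave an all-zero map unchanged
            sdf = sdf / max_abs
    if clip_negatives:
        sdf = np.clip(sdf, 0.0, None)  # negative distances become 0
    return sdf


distances = [-2.0, -1.0, 1.0, 4.0]
scaled = apply_sdf_options(distances)
scaled_clipped = apply_sdf_options(distances, clip_negatives=True)
```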

DiSCoLoss Specific

  • baseloss_weight (float): Weight for standard BCE + Dice losses. Default: 0
  • sdfweighted_weight (float): Weight for SDF-weighted losses. Default: 1

Examples

Basic Training Loop

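import torch
from torch.utils.data import DataLoader, TensorDataset
from sdf_loss import DiSCoLoss

# Placeholders: swap in your own segmentation model, dataset, and epoch count
model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)
all_images = torch.randn(8, 3, 64, 64)
all_masks = torch.zeros(8, 1, 64, 64)
all_masks[:, :, 16:48, 16:48] = 1.0
dataloader = DataLoader(TensorDataset(all_images, all_masks), batch_size=4, shuffle=True)
num_epochs = 2

criterion = DiSCoLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)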

for epoch in range(num_epochs):
    for images, masks in dataloader:
        optimizer.zero_grad()

        pred_logits = model(images)
        loss = criterion(pred_logits, masks)

        loss.backward()
        optimizer.step()

Combining Base and SDF-weighted Losses

# Use both base losses and SDF-weighted losses
criterion = DiSCoLoss(
    baseloss_weight=0.3,      # 30% base losses
    sdfweighted_weight=0.7    # 70% SDF-weighted losses
)

Custom Reduction

from sdf_loss import SDFWeightedBCELoss

# Get per-pixel losses for custom weighting
criterion = SDFWeightedBCELoss(reduction="none")
loss = criterion(pred_logits, target)  # Shape: (B, 1, H, W)

# Apply custom weighting (example: foreground pixels count twice as much)
custom_weights = 1.0 + target
weighted_loss = (loss * custom_weights).mean()

Requirements

  • Python >= 3.10
  • PyTorch >= 1.10.0
  • NumPy >= 1.20.0
  • SciPy >= 1.7.0
  • scikit-image >= 0.19.0

Testing

Run the test suite:

pytest tests/ -v

Citation

If you use this library in your research, please cite:

@article{your_paper,
  title={Your Paper Title},
  author={Your Name},
  journal={Your Journal},
  year={2025}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Acknowledgments

The DiceLoss implementation is inspired by segmentation-models-pytorch (SMP).


Download files

Source Distribution

sdf_loss-0.1.0.tar.gz (74.8 kB)

Uploaded Source

Built Distribution

sdf_loss-0.1.0-py3-none-any.whl (8.7 kB)

Uploaded Python 3

File details

Details for the file sdf_loss-0.1.0.tar.gz.

File metadata

  • Download URL: sdf_loss-0.1.0.tar.gz
  • Upload date:
  • Size: 74.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.4

File hashes

Hashes for sdf_loss-0.1.0.tar.gz
Algorithm Hash digest
SHA256 4c4ea5d6fc0a5fcb00cf5f7d54973d811e34b2443e5ee91dbc297c4f59475415
MD5 8cd77693e60939e37fb8e56ded1c1b63
BLAKE2b-256 1a782dcd764902f7e82a92559dec07fdf76b440fedcf1bc15854640b0958571c

File details

Details for the file sdf_loss-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: sdf_loss-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 8.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.4

File hashes

Hashes for sdf_loss-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 b08ff4a08d332e3845318b5034e9e18bf8430c4f2d0dc3c8565ea1b8253d3c08
MD5 ccb01c185b444aa94d6b1aec181ed0a5
BLAKE2b-256 1bfcd7cfd840b73a9774aef52dcb9760d32d3f2f022150e86afea63f86c2b2c0
