Signed Distance Function (SDF)-based loss functions for deep learning semantic segmentation, designed to help models miss fewer object instances.
SDF Loss
Signed Distance Function (SDF)-based loss functions for deep learning semantic segmentation.
Overview
This library provides PyTorch loss functions that use Signed Distance Functions (SDFs) to weight pixels by their distance from object boundaries. False positives and false negatives that lie farther from the correct boundary receive heavier penalties, which is intended to produce more accurate segmentations with fewer missed instances.
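To illustrate the idea, the sketch below computes a signed distance map for a 2-D binary mask with SciPy's Euclidean distance transform. The helper name signed_distance and the sign convention (positive outside the object, negative inside) are assumptions for illustration; sdf_loss may use a different convention internally.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt


def signed_distance(mask: np.ndarray) -> np.ndarray:
    """Signed Euclidean distance map of a binary mask (hypothetical helper).

    Assumed convention: background pixels get positive values, foreground
    pixels get negative values, and magnitudes grow with distance from the
    object boundary.
    """
    fg = mask.astype(bool)
    if not fg.any() or fg.all():
        # No boundary exists; return zeros instead of undefined distances.
        return np.zeros(mask.shape, dtype=float)
    outside = distance_transform_edt(~fg)  # background pixel -> nearest foreground pixel
    inside = distance_transform_edt(fg)    # foreground pixel -> nearest background pixel
    return outside - inside


mask = np.zeros((7, 7), dtype=np.uint8)
mask[2:5, 2:5] = 1  # a 3x3 square object
sdf = signed_distance(mask)
print(sdf[3])  # middle row, values: 2, 1, -1, -2, -1, 1, 2
```

Pixels deep inside or far outside the object get large magnitudes, which is what lets an SDF-weighted loss penalize distant errors more than errors right at the boundary.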
Installation
Install directly from PyPI:
```bash
pip install sdf-loss
```
Or using uv:
```bash
uv add sdf-loss
```
For development:
```bash
git clone https://github.com/Halyjo/sdf_loss.git
cd sdf_loss
uv sync
```
Quick Start
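```python
import torch
from sdf_loss import DiSCoLoss

# Initialize the loss function
criterion = DiSCoLoss()

# Stand-in model and data for illustration; replace with your own
model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)
images = torch.randn(4, 3, 64, 64)
target = torch.zeros(4, 1, 64, 64)  # Shape: (B, 1, H, W), binary mask
target[:, :, 16:48, 16:48] = 1.0

# Model predictions (logits)
pred_logits = model(images)  # Shape: (B, 1, H, W)

# Compute loss
loss = criterion(pred_logits, target)
loss.backward()
```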
Available Loss Functions
DiSCoLoss (Recommended)
Distance-scaled combination loss that combines BCE, Dice, and their SDF-weighted variants.
```python
from sdf_loss import DiSCoLoss

# Default: only SDF-weighted losses
criterion = DiSCoLoss(
    normalize=True,        # Normalize SDF to [-1, 1]
    baseloss_weight=0,     # Weight for BCE + Dice
    sdfweighted_weight=1,  # Weight for SDF-weighted losses
    clip_negatives=False,  # Clip negative distance values to 0
)
loss = criterion(pred_logits, target)
```
SDFWeightedBCELoss
Binary cross-entropy loss weighted by SDF differences.
```python
from sdf_loss import SDFWeightedBCELoss

criterion = SDFWeightedBCELoss(
    reduction="mean",  # 'mean', 'sum', or 'none'
    normalize=True,
    clip_negatives=False,
)
loss = criterion(pred_logits, target)
```
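The description above does not spell out how the SDF differences become pixel weights. One plausible reading, sketched below in NumPy, scales each pixel's BCE by |SDF(prediction) - SDF(target)|, where SDF(prediction) would be computed from the thresholded prediction. The function names and this weighting rule are assumptions, not the library's documented implementation; the SDF maps are passed in precomputed to keep the sketch short.

```python
import numpy as np


def bce_with_logits(logits: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Per-pixel binary cross-entropy on logits, in a numerically stable form."""
    return np.maximum(logits, 0) - logits * target + np.log1p(np.exp(-np.abs(logits)))


def sdf_weighted_bce(logits, target, sdf_pred, sdf_target, reduction="mean"):
    """Hypothetical SDF-weighted BCE: each pixel's BCE is scaled by |sdf_pred - sdf_target|."""
    weights = np.abs(sdf_pred - sdf_target)
    loss = weights * bce_with_logits(logits, target)
    if reduction == "mean":
        return float(loss.mean())
    if reduction == "sum":
        return float(loss.sum())
    return loss  # reduction == "none": per-pixel losses


# 1-D example: one true foreground pixel, two false positives at increasing distance
logits = np.array([3.0, -2.0, 0.5, 4.0])
target = np.array([1.0, 0.0, 0.0, 0.0])
sdf_target = np.array([-1.0, 1.0, 2.0, 3.0])
sdf_pred = np.array([-1.0, 1.0, -1.0, -2.0])  # SDF of the thresholded prediction [1, 0, 1, 1]
per_pixel = sdf_weighted_bce(logits, target, sdf_pred, sdf_target, reduction="none")
print(per_pixel)  # weights are 0, 0, 3, 5
```

Under this reading, correctly classified pixels whose distances match contribute nothing, and the false positive farther from the true boundary gets the larger weight.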
SDFWeightedDiceLoss
Dice loss weighted by SDF differences.
```python
from sdf_loss import SDFWeightedDiceLoss

criterion = SDFWeightedDiceLoss(
    from_logits=True,
    normalize=True,
    clip_negatives=False,
)
loss = criterion(pred_logits, target)
```
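A weighted Dice loss can be written by inserting a per-pixel weight map into every sum of the soft Dice formula. The NumPy sketch below takes probabilities (the from_logits=False case) and an explicit weight map; how sdf_loss derives its weights from SDF differences is not documented here, so the function name weighted_soft_dice_loss and the weights themselves are assumptions for illustration.

```python
import numpy as np


def weighted_soft_dice_loss(probs, target, weights, smooth=0.0, eps=1e-7):
    """Soft Dice loss with a per-pixel weight map (hypothetical sketch).

    With all weights equal to 1 this reduces to the standard soft Dice loss.
    """
    intersection = np.sum(weights * probs * target)
    cardinality = np.sum(weights * probs) + np.sum(weights * target)
    dice = (2.0 * intersection + smooth) / max(cardinality + smooth, eps)
    return float(1.0 - dice)


probs = np.array([1.0, 1.0, 0.0, 0.0])
target = np.array([1.0, 0.0, 1.0, 0.0])
print(weighted_soft_dice_loss(probs, target, np.ones(4)))                     # 0.5
print(weighted_soft_dice_loss(probs, target, np.array([2.0, 1.0, 1.0, 1.0])))  # about 0.333
```

Raising the weight of the correctly predicted first pixel increases its share of the overlap term, so the loss drops from 0.5 to about 0.333.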
DiceLoss
Standard Dice loss with optional custom weighting function.
```python
from sdf_loss import DiceLoss

criterion = DiceLoss(
    from_logits=True,
    smooth=0.0,
)
loss = criterion(pred_logits, target)
```
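For reference, the common soft Dice formulation (the style used in segmentation-models-pytorch) is 1 - (2·Σpt + smooth) / (Σp + Σt + smooth), where p are sigmoid probabilities and t the binary target. The NumPy sketch below follows that formula over all pixels; the function name dice_loss and the eps guard are illustrative assumptions, not this library's exact code.

```python
import numpy as np


def dice_loss(logits, target, from_logits=True, smooth=0.0, eps=1e-7):
    """Soft Dice loss: 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)."""
    probs = 1.0 / (1.0 + np.exp(-logits)) if from_logits else logits  # sigmoid if needed
    intersection = np.sum(probs * target)
    cardinality = np.sum(probs) + np.sum(target)
    dice = (2.0 * intersection + smooth) / max(cardinality + smooth, eps)  # eps avoids 0/0
    return float(1.0 - dice)


target = np.array([1.0, 0.0, 1.0, 0.0])
print(dice_loss(np.array([1.0, 1.0, 0.0, 0.0]), target, from_logits=False))  # 0.5 (half overlap)
print(dice_loss(np.array([10.0, -10.0, 10.0, -10.0]), target))               # close to 0
```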
Drop-in Replacement for BCE Loss
DiSCoLoss takes the same (logits, target) arguments as torch.nn.BCEWithLogitsLoss, so it can replace it directly:
```python
import torch
from sdf_loss import DiSCoLoss

# Before
criterion = torch.nn.BCEWithLogitsLoss()

# After: drop-in replacement
criterion = DiSCoLoss()

# Usage remains the same
loss = criterion(pred_logits, target)
```
Key Features
- Boundary-aware: Weights each pixel by its signed distance to the object boundary
- Distance-weighted: Penalizes errors more heavily the farther they are from the correct boundary
- PyTorch native: Fully compatible with PyTorch training loops
- GPU compatible: Works with CUDA tensors
- Differentiable: Full gradient flow for backpropagation
- Flexible: Multiple loss functions and customizable parameters
Parameters
Common Parameters
- normalize (bool): Normalize SDF values to the [-1, 1] range. Default: True
- clip_negatives (bool): Clip negative distance values to 0. Default: False
- from_logits (bool): Whether the input is logits (before sigmoid). Default: True
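One plausible reading of normalize and clip_negatives is sketched below: dividing by the largest absolute value so the map lies in [-1, 1], then replacing negative values with 0. The helper postprocess_sdf is hypothetical, and both the order of the two steps and the normalization scope (the library might normalize per image or per sign) are assumptions.

```python
import numpy as np


def postprocess_sdf(sdf: np.ndarray, normalize: bool = True, clip_negatives: bool = False) -> np.ndarray:
    """Hypothetical post-processing that mirrors the documented parameter names."""
    out = sdf.astype(float)
    if normalize:
        scale = np.abs(out).max()
        if scale > 0:
            out = out / scale  # values now lie in [-1, 1]
    if clip_negatives:
        out = np.clip(out, 0.0, None)  # negative distances become 0
    return out


sdf = np.array([2.0, 1.0, -1.0, -2.0, -1.0, 1.0, 4.0])
print(postprocess_sdf(sdf))                       # values: 0.5, 0.25, -0.25, -0.5, -0.25, 0.25, 1.0
print(postprocess_sdf(sdf, clip_negatives=True))  # values: 0.5, 0.25, 0, 0, 0, 0.25, 1.0
```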
DiSCoLoss Specific
- baseloss_weight (float): Weight for the standard BCE + Dice losses. Default: 0
- sdfweighted_weight (float): Weight for the SDF-weighted losses. Default: 1
Examples
Basic Training Loop
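```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from sdf_loss import DiSCoLoss

# Stand-ins for illustration; substitute your own model and dataset
model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)
images = torch.randn(16, 3, 64, 64)
masks = torch.zeros(16, 1, 64, 64)
masks[:, :, 16:48, 16:48] = 1.0  # binary masks, shape (N, 1, H, W)
dataloader = DataLoader(TensorDataset(images, masks), batch_size=4, shuffle=True)
num_epochs = 2

criterion = DiSCoLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    for batch_images, batch_masks in dataloader:
        optimizer.zero_grad()
        pred_logits = model(batch_images)  # raw logits, shape (B, 1, H, W)
        loss = criterion(pred_logits, batch_masks)
        loss.backward()
        optimizer.step()
```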
Combining Base and SDF-weighted Losses
```python
# Use both base losses and SDF-weighted losses
criterion = DiSCoLoss(
    baseloss_weight=0.3,     # 30% base losses
    sdfweighted_weight=0.7,  # 70% SDF-weighted losses
)
```
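The parameter descriptions suggest that DiSCoLoss forms a weighted sum of two groups of losses: the base BCE + Dice pair and their SDF-weighted counterparts. The plain-Python sketch below spells out that reading with the 0.3/0.7 weights from the example above. The function name combined_loss, the component values, and the exact form of the sum (for instance, whether BCE and Dice are added or averaged) are assumptions, not measured results or the library's documented implementation.

```python
def combined_loss(bce, dice, sdf_bce, sdf_dice, baseloss_weight=0.0, sdfweighted_weight=1.0):
    """Hypothetical weighted sum of base and SDF-weighted component losses."""
    base = bce + dice                  # standard BCE + Dice
    sdf_weighted = sdf_bce + sdf_dice  # SDF-weighted BCE + Dice
    return baseloss_weight * base + sdfweighted_weight * sdf_weighted


# Illustrative component values only
total = combined_loss(bce=0.5, dice=0.4, sdf_bce=0.2, sdf_dice=0.1,
                      baseloss_weight=0.3, sdfweighted_weight=0.7)
print(round(total, 4))  # 0.48 = 0.3 * 0.9 + 0.7 * 0.3
```

With the defaults (baseloss_weight=0, sdfweighted_weight=1), only the SDF-weighted terms contribute, matching the "only SDF-weighted losses" default described for DiSCoLoss.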
Custom Reduction
```python
from sdf_loss import SDFWeightedBCELoss

# Get per-pixel losses for custom weighting
criterion = SDFWeightedBCELoss(reduction="none")
loss = criterion(pred_logits, target)  # Shape: (B, 1, H, W)

# Apply custom weighting; this example doubles the weight of foreground pixels
# (1 for background, 2 for foreground). Replace with your own scheme.
custom_weights = 1.0 + target.float()
weighted_loss = (loss * custom_weights).mean()
```
Requirements
- Python >= 3.10
- PyTorch >= 1.10.0
- NumPy >= 1.20.0
- SciPy >= 1.7.0
- scikit-image >= 0.19.0
Testing
Run the test suite:
```bash
pytest tests/ -v
```
Citation
If you use this library in your research, please cite:
@article{your_paper,
title={Your Paper Title},
author={Your Name},
journal={Your Journal},
year={2025}
}
License
This project is licensed under the MIT License - see the LICENSE file for details.
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
Acknowledgments
The DiceLoss implementation is inspired by segmentation-models-pytorch (SMP).
File details
Details for the file sdf_loss-0.1.0.tar.gz.
File metadata
- Download URL: sdf_loss-0.1.0.tar.gz
- Upload date:
- Size: 74.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4c4ea5d6fc0a5fcb00cf5f7d54973d811e34b2443e5ee91dbc297c4f59475415 |
| MD5 | 8cd77693e60939e37fb8e56ded1c1b63 |
| BLAKE2b-256 | 1a782dcd764902f7e82a92559dec07fdf76b440fedcf1bc15854640b0958571c |
File details
Details for the file sdf_loss-0.1.0-py3-none-any.whl.
File metadata
- Download URL: sdf_loss-0.1.0-py3-none-any.whl
- Upload date:
- Size: 8.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b08ff4a08d332e3845318b5034e9e18bf8430c4f2d0dc3c8565ea1b8253d3c08 |
| MD5 | ccb01c185b444aa94d6b1aec181ed0a5 |
| BLAKE2b-256 | 1bfcd7cfd840b73a9774aef52dcb9760d32d3f2f022150e86afea63f86c2b2c0 |