Uncertainty-weighted multi-task loss for PyTorch (Kendall et al. 2018)

Project description

mtl_uncertainty_loss

Uncertainty-weighted multi-task loss for PyTorch, based on Kendall et al. (2018): "Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics" (arXiv:1705.07115)
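The README does not write out the objective. For reference, the form Kendall et al. (2018) derive for regression-style task losses is shown below, with T tasks, task loss L_i and learned noise scale σ_i. Writing s_i = log σ_i matches the "log σ per task" parameters listed under Features. The paper uses 1/σ_i² rather than 1/(2σ_i²) for classification losses, so which variant this package applies is an assumption here, not something the README states.

```latex
\mathcal{L}_{\text{total}}
  = \sum_{i=1}^{T} \left( \frac{1}{2\sigma_i^{2}}\,\mathcal{L}_i + \log \sigma_i \right)
  = \sum_{i=1}^{T} \left( \tfrac{1}{2}\, e^{-2 s_i}\,\mathcal{L}_i + s_i \right),
  \qquad s_i = \log \sigma_i
```

The log σ_i term penalises large uncertainties, so the optimizer cannot shrink every weight to zero by letting σ_i grow without bound.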

Features

  • Arbitrary number of tasks
  • Learnable uncertainty parameters (log σ per task)
  • PyTorch nn.Module interface: log_sigmas are registered parameters, so they are updated during training once loss_fn.parameters() is included in the optimizer
  • Current σ values can be monitored via get_sigmas()
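The features above can be illustrated with a minimal sketch. UncertaintyWeightedLossSketch is a hypothetical stand-in written for this example, not the package's source, and it assumes the regression-form weighting L_i/(2σ_i²) + log σ_i. It shows one learnable log σ per task, registration as an nn.Module parameter, and a get_sigmas() helper.

```python
import torch
import torch.nn as nn


class UncertaintyWeightedLossSketch(nn.Module):
    """Hypothetical re-implementation for illustration; not the package's source.

    Holds one learnable log(sigma_i) per task and combines scalar task losses as
    sum_i L_i / (2 * sigma_i**2) + log(sigma_i)  (assumed regression form).
    """

    def __init__(self, num_tasks: int):
        super().__init__()
        # log(sigma) = 0 means sigma = 1, so every task starts with weight 1/2
        self.log_sigmas = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        losses = torch.stack(list(losses))  # shape: (num_tasks,)
        precision = torch.exp(-2.0 * self.log_sigmas)  # 1 / sigma^2
        return (0.5 * precision * losses + self.log_sigmas).sum()

    def get_sigmas(self) -> torch.Tensor:
        # Report sigma itself (not log sigma), detached from the graph
        return self.log_sigmas.detach().exp()


# log_sigmas are ordinary parameters: they only change if the optimizer
# is given loss_fn.parameters() (normally alongside the model's parameters).
loss_fn = UncertaintyWeightedLossSketch(num_tasks=3)
opt = torch.optim.SGD(loss_fn.parameters(), lr=0.1)

task_losses = [torch.tensor(4.0), torch.tensor(1.0), torch.tensor(0.25)]
total = loss_fn(task_losses)
opt.zero_grad()
total.backward()
opt.step()
sigmas = loss_fn.get_sigmas()
print(total.item(), sigmas)
```

Starting from σ = 1, the gradient with respect to log σ_i is 1 − L_i, so one SGD step moves log σ_i by lr·(L_i − 1): the task with loss 4.0 gets a larger σ (a smaller weight) and the task with loss 0.25 a smaller one.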

Installation

pip install mtl_uncertainty_loss

Usage

from mtl_uncertainty_loss import UncertaintyWeightedMultiTaskLoss

loss_fn = UncertaintyWeightedMultiTaskLoss(num_tasks=3)

# Pass already-reduced (scalar) per-task losses
total_loss = loss_fn([loss1, loss2, loss3])
total_loss.backward()

print(loss_fn.get_sigmas())  # tensor of current σ values
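One way to read the values printed by get_sigmas() is to solve for the σ that minimises a single task's term while its loss is held fixed. This sketch assumes the regression-form term 0.5·exp(−2s)·L + s with s = log σ, which is an assumption about this package; optimal_sigma and sigma_by_gradient_descent are hypothetical helpers written for illustration.

```python
import math

# Assumed per-task term (Kendall et al. 2018, regression form), with s = log(sigma):
#   f(s) = 0.5 * exp(-2 s) * L + s
# f'(s) = -exp(-2 s) * L + 1 = 0  gives  s* = 0.5 * ln(L), i.e. sigma* = sqrt(L).


def optimal_sigma(task_loss: float) -> float:
    """Closed-form sigma minimising the term for a fixed, positive task loss."""
    return math.sqrt(task_loss)


def sigma_by_gradient_descent(task_loss: float, lr: float = 0.1, steps: int = 500) -> float:
    """Plain gradient descent on s = log(sigma), mimicking what the optimizer does."""
    s = 0.0  # sigma = 1 at initialisation
    for _ in range(steps):
        grad = -math.exp(-2.0 * s) * task_loss + 1.0
        s -= lr * grad
    return math.exp(s)


for L in (4.0, 1.0, 0.25):
    sigma = optimal_sigma(L)
    weight = 1.0 / (2.0 * sigma**2)  # effective weight on this task's loss
    print(f"L={L:<5} sigma*={sigma:.3f} weight={weight:.3f} gd={sigma_by_gradient_descent(L):.3f}")
```

For L = 4.0, 1.0 and 0.25 this gives σ* = 2, 1 and 0.5, i.e. effective weights 1/(2L) of 0.125, 0.5 and 2: a task whose loss stays large is down-weighted, while the log σ term keeps σ from growing without bound.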

Testing

pip install -e ".[dev]"
pytest tests

License

MIT

Download files

Download the file for your platform.

Source Distribution

mtl_uncertainty_loss-0.1.0.tar.gz (3.6 kB)

Uploaded Source

Built Distribution


mtl_uncertainty_loss-0.1.0-py3-none-any.whl (3.8 kB)

Uploaded Python 3

File details

Details for the file mtl_uncertainty_loss-0.1.0.tar.gz.

File metadata

  • Download URL: mtl_uncertainty_loss-0.1.0.tar.gz
  • Upload date:
  • Size: 3.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for mtl_uncertainty_loss-0.1.0.tar.gz
Algorithm     Hash digest
SHA256        d449e4aef6a1539956954babbc53bdab82f75702d9d2a0e11cc5f1f960fa18c2
MD5           b7e0676d4fa50bbedd151385bc4c977c
BLAKE2b-256   000658217bf3b5e47b44121a72322dbc865fe3d4209dab7d76ac49d9b7e1f08a

See more details on using hashes here.
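The published digests can be checked locally before installing. This is a standard-library sketch using hashlib; the helper names sha256_of_file and verify are hypothetical, and the expected value is the SHA256 digest listed above for the source distribution.

```python
import hashlib
from pathlib import Path

# Published SHA256 digest for mtl_uncertainty_loss-0.1.0.tar.gz (from the table above)
EXPECTED_SHA256 = "d449e4aef6a1539956954babbc53bdab82f75702d9d2a0e11cc5f1f960fa18c2"


def sha256_of_file(path, chunk_size=1 << 16):
    """Hash the file in chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of_file(path) == expected.lower()
```

For example, verify("mtl_uncertainty_loss-0.1.0.tar.gz") should return True for an intact download. pip can enforce the same check itself in hash-checking mode (pip install --require-hashes, with --hash=sha256:... entries in a requirements file).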

Provenance

The following attestation bundles were made for mtl_uncertainty_loss-0.1.0.tar.gz:

Publisher: publish.yml on elna4os/mtl_uncertainty_loss

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file mtl_uncertainty_loss-0.1.0-py3-none-any.whl.

File hashes

Hashes for mtl_uncertainty_loss-0.1.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        f3f053806d710c58b9688353c03b2047aa87671a7d252d7c52454a03de03f74f
MD5           46471632140ca477fd1232bf85fe344d
BLAKE2b-256   5219db8181d3c2ee77920c3caf541bf35f8487b05ebba485a79090ef8560aef8


Provenance

The following attestation bundles were made for mtl_uncertainty_loss-0.1.0-py3-none-any.whl:

Publisher: publish.yml on elna4os/mtl_uncertainty_loss

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
