TensorFlow implementation of focal loss.
TensorFlow implementation of focal loss [1]: a loss function that generalizes binary and multiclass cross-entropy by down-weighting well-classified examples, so that training concentrates on hard-to-classify ones.
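The per-example binary form of focal loss in [1] is FL(p_t) = -(1 - p_t)^γ log(p_t), where p_t is the predicted probability of the true class. The plain-Python sketch below illustrates two of its properties: γ = 0 recovers binary cross-entropy, and the factor (1 - p_t)^γ shrinks the loss of confident, correct predictions much more than that of hard ones. The helper names are hypothetical, the sketch does not use TensorFlow or the focal_loss API, and the α-balancing weight from [1] is omitted.

```python
import math


def binary_focal_loss_sketch(y_true, p, gamma=2.0, eps=1e-7):
    """Per-example binary focal loss, FL(p_t) = -(1 - p_t)**gamma * log(p_t).

    y_true: 0 or 1 label; p: predicted probability of the positive class.
    Hypothetical helper for illustration only, not the focal_loss API.
    """
    p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
    p_t = p if y_true == 1 else 1.0 - p  # probability assigned to the true class
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def binary_cross_entropy(y_true, p, eps=1e-7):
    """Per-example binary cross-entropy, for comparison."""
    p = min(max(p, eps), 1.0 - eps)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1.0 - p))


# With gamma = 0 the modulating factor (1 - p_t)**gamma is 1, recovering cross-entropy.
same_as_ce = math.isclose(
    binary_focal_loss_sketch(1, 0.7, gamma=0), binary_cross_entropy(1, 0.7)
)

# Focal loss / cross-entropy equals (1 - p_t)**gamma:
easy_ratio = binary_focal_loss_sketch(1, 0.9) / binary_cross_entropy(1, 0.9)  # about 0.01
hard_ratio = binary_focal_loss_sketch(1, 0.1) / binary_cross_entropy(1, 0.1)  # about 0.81
```

With γ = 2, the easy example (p_t = 0.9) keeps roughly 1% of its cross-entropy loss, while the hard one (p_t = 0.1) keeps roughly 81%.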
The focal_loss package provides functions and classes that can be used as off-the-shelf replacements for tf.keras.losses functions and classes, respectively.
    # Typical tf.keras API usage
    import tensorflow as tf
    from focal_loss import BinaryFocalLoss

    model = tf.keras.Model(...)
    model.compile(
        optimizer=...,
        loss=BinaryFocalLoss(gamma=2),  # Used here like a tf.keras loss
        metrics=...,
    )
    history = model.fit(...)
The focal_loss package includes the functions
- binary_focal_loss
- sparse_categorical_focal_loss
and wrapper classes
- BinaryFocalLoss (use like tf.keras.losses.BinaryCrossentropy)
- SparseCategoricalFocalLoss (use like tf.keras.losses.SparseCategoricalCrossentropy)
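As a rough illustration of how a loss function and its wrapper class fit together, the following plain-Python sketch computes a per-example multiclass focal loss from an integer class label (the "sparse" label format used by tf.keras.losses.SparseCategoricalCrossentropy) and wraps it in a callable class that stores γ. Both names are hypothetical; the sketch works on lists of probabilities rather than tensors and does not reproduce the package's actual implementation or signatures.

```python
import math


def sparse_categorical_focal_loss_sketch(y_true, probs, gamma=2.0, eps=1e-7):
    """Per-example multiclass focal loss for an integer label.

    y_true: index of the true class; probs: predicted class probabilities
    (assumed to sum to 1). Hypothetical helper, not the focal_loss API.
    """
    p_t = min(max(probs[y_true], eps), 1.0)  # probability of the true class, clipped
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


class SparseCategoricalFocalLossSketch:
    """Callable wrapper that stores gamma, mirroring a function/class pairing.

    Hypothetical illustration only; a Keras-compatible wrapper would
    subclass tf.keras.losses.Loss, which this sketch does not.
    """

    def __init__(self, gamma=2.0):
        self.gamma = gamma

    def __call__(self, y_true, probs_batch):
        # Average the per-example losses over the batch, like Keras' default reduction.
        losses = [
            sparse_categorical_focal_loss_sketch(y, p, gamma=self.gamma)
            for y, p in zip(y_true, probs_batch)
        ]
        return sum(losses) / len(losses)


# Usage: a batch of two examples over three classes.
loss_fn = SparseCategoricalFocalLossSketch(gamma=2)
batch_loss = loss_fn([0, 2], [[0.8, 0.1, 0.1], [0.2, 0.3, 0.5]])
```

Storing γ in the object lets the loss be passed around as a single callable, which matches how BinaryFocalLoss(gamma=2) is handed to model.compile in the usage example.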
Documentation is available at Read the Docs.
Installation
The focal_loss package can be installed using the pip utility. For the latest version, install directly from the package’s GitHub page:
pip install git+https://github.com/artemmavrin/focal-loss.git
Alternatively, install a recent release from the Python Package Index (PyPI):
pip install focal-loss
Note. To install the project for development (e.g., to make changes to the source code), clone the project repository from GitHub and run make dev:

    git clone https://github.com/artemmavrin/focal-loss.git
    cd focal-loss
    # Optional but recommended: create and activate a new environment first
    make dev
This will additionally install the requirements needed to run tests, check code coverage, and produce documentation.
References
[1] T. Lin, P. Goyal, R. Girshick, K. He, and P. Dollár. Focal loss for dense object detection. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018. (DOI) (arXiv preprint)
Hashes for focal_loss-0.0.7-py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | 2dad215281050019eb0ffdb4903335de6ef614f5576b08d2961bf1f2493eba42
MD5 | 0b6aeb213f7a4d1eb30fa4eef8047c86
BLAKE2-256 | 3aa1e362a8b955a417a6b37ae31088ecc3ad0fc31b50265c0924a969487eacf6
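To check a downloaded wheel against the SHA256 digest listed for focal_loss-0.0.7-py3-none-any.whl, a short standard-library sketch using hashlib follows. The helper names are hypothetical, and the usage assumes the wheel has already been saved to the current directory (for example with pip download focal-loss==0.0.7 --no-deps).

```python
import hashlib

# SHA256 digest listed for focal_loss-0.0.7-py3-none-any.whl
EXPECTED_SHA256 = "2dad215281050019eb0ffdb4903335de6ef614f5576b08d2961bf1f2493eba42"


def sha256_of_chunks(chunks):
    """Hex SHA256 digest of an iterable of byte strings."""
    digest = hashlib.sha256()
    for chunk in chunks:
        digest.update(chunk)
    return digest.hexdigest()


def file_sha256(path, chunk_size=1 << 16):
    """Hex SHA256 digest of a file, read in fixed-size chunks to bound memory use."""
    with open(path, "rb") as f:
        return sha256_of_chunks(iter(lambda: f.read(chunk_size), b""))


def wheel_matches(path, expected=EXPECTED_SHA256):
    """True if the file at `path` has the expected SHA256 digest."""
    return file_sha256(path) == expected.lower()


# Usage (assumes the wheel is in the current directory):
# wheel_matches("focal_loss-0.0.7-py3-none-any.whl")
```

pip can also enforce this itself in hash-checking mode: pin focal-loss==0.0.7 --hash=sha256:<digest> in a requirements file and install with --require-hashes, which then requires a hash for every package being installed.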