A fair loss function
Project description
A fair PyTorch loss function
This loss function takes fairness into account when training a PyTorch model. It works by adding a fairness measure to a regular loss value.
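As a rough illustration of what adding a fairness measure to a regular loss value can look like, here is a minimal sketch in plain Python, without PyTorch, so the arithmetic stays visible. Everything in it is an assumption made for illustration: the function names `mse` and `fair_penalised_loss`, the `weight` parameter, and the choice of penalty (the gap between the largest and smallest per-group error). It is not the formula or the implementation used by fair-loss.

```python
from statistics import mean


def mse(y_pred, y_true):
    """Mean squared error over paired predictions and targets."""
    return mean((p - t) ** 2 for p, t in zip(y_pred, y_true))


def fair_penalised_loss(groups, y_pred, y_true, weight=1.0):
    """Return (base loss + weight * penalty, per-group losses).

    Hypothetical penalty: the gap between the worst and best per-group MSE.
    This is an illustrative choice, not the fair-loss formula.
    """
    base = mse(y_pred, y_true)
    per_group = {}
    for g in set(groups):
        # Indices of the samples whose sensitive attribute equals g
        idx = [i for i, value in enumerate(groups) if value == g]
        per_group[g] = mse([y_pred[i] for i in idx], [y_true[i] for i in idx])
    penalty = max(per_group.values()) - min(per_group.values())
    return base + weight * penalty, per_group


# Toy data: group 0 is predicted perfectly, group 1 is not
groups = [0, 0, 1, 1]
y_true = [1.0, 2.0, 1.0, 2.0]
y_pred = [1.0, 2.0, 2.0, 4.0]

total, per_group = fair_penalised_loss(groups, y_pred, y_true, weight=0.5)
print(total, per_group)
```

With these toy values the base MSE is 1.25, the per-group errors are 0.0 and 2.5, and the penalised total is 1.25 + 0.5 * 2.5 = 2.5. A model that treats both groups equally badly or equally well would pay no penalty under this particular choice.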
Installation
pip install fair-loss
Example
import torch
from fair_loss import FairLoss

# A small regression model and random toy data with 5 features
model = torch.nn.Sequential(torch.nn.Linear(5, 1), torch.nn.ReLU())
data = torch.randint(0, 5, (100, 5), dtype=torch.float, requires_grad=True)
y_true = torch.randint(0, 5, (100, 1), dtype=torch.float)
y_pred = model(data)

# Let's say the sensitive attribute is in the second column (index 1)
dim = 1
# NOTE: `accuracy` is not defined in this snippet; define or import a
# suitable function before this line.
criterion = FairLoss(torch.nn.MSELoss(), data[:, dim].detach().unique(), accuracy)
loss = criterion(data[:, dim], y_pred, y_true)
loss.backward()
Documentation
See the documentation.
Download files
Source Distribution
fair_loss-0.5.tar.gz (29.8 kB)
Built Distribution
fair_loss-0.5-py3-none-any.whl (16.3 kB)
File details
Details for the file fair_loss-0.5.tar.gz.
File metadata
- Download URL: fair_loss-0.5.tar.gz
- Upload date:
- Size: 29.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/53.0.0 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.9.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | ce14885a0b4ca91d6e9a2bbed38462394761f5c460f4cd2ac4776c1ba25439c4
MD5 | 9797df98d27e094ce2004579ba6fec85
BLAKE2b-256 | e8e633fbaaa2bc3f2d0e86cc7c09edc75083ea987ec0cd2cc2797649e86173c4
File details
Details for the file fair_loss-0.5-py3-none-any.whl.
File metadata
- Download URL: fair_loss-0.5-py3-none-any.whl
- Upload date:
- Size: 16.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/53.0.0 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.9.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | 66046aa130cd303bf84c8d02e2f072f2141783c5039e8b16cbf7f597f2d3fe2b
MD5 | 1b78ce6336fabcadf73eebb89bb20ab6
BLAKE2b-256 | ad56e33ac0716abd294f6d586ecd106f026372ce252b709f27621e7c95f8d5af