
Self-adjusting Dice Loss implementation


Self-adjusting Dice Loss

This is an unofficial PyTorch implementation of the loss proposed in the paper "Dice Loss for Data-imbalanced NLP Tasks".

Usage

Installation

pip install sadice

Text classification example

import torch
from sadice import SelfAdjDiceLoss

criterion = SelfAdjDiceLoss()
# (batch_size, num_classes)
logits = torch.rand(128, 10, requires_grad=True)
targets = torch.randint(0, 10, size=(128,))  # one gold class index per example

loss = criterion(logits, targets)
loss.backward()
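For intuition about what the criterion computes for a single example, here is a framework-free sketch in plain Python. It assumes the self-adjusting form 1 - (2 * (1 - p)^alpha * p + gamma) / ((1 - p)^alpha * p + 1 + gamma), where p is the softmax probability of the gold class and alpha and gamma are hyperparameters. This is our reading of the paper, not code taken from sadice; the function name self_adj_dice_loss and its defaults are hypothetical, so check the package source before relying on details.

```python
import math
from typing import Sequence


def self_adj_dice_loss(logits: Sequence[float], target: int,
                       alpha: float = 1.0, gamma: float = 1.0) -> float:
    """Per-example self-adjusting Dice loss (sketch; the formula is an assumption).

    p is the softmax probability of the gold class. The factor (1 - p) ** alpha
    shrinks the contribution of examples whose p is already close to 1;
    gamma is a smoothing term added to the numerator and the denominator.
    """
    # Numerically stable softmax: subtract the max logit before exponentiating.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    p = exps[target] / sum(exps)

    # Self-adjusting factor applied to the gold-class probability.
    weighted_p = ((1.0 - p) ** alpha) * p

    # 1 - DSC for the gold label (y = 1), smoothed by gamma.
    return 1.0 - (2.0 * weighted_p + gamma) / (weighted_p + 1.0 + gamma)


# Uniform logits over 10 classes, as in the (batch_size, num_classes) example.
print(self_adj_dice_loss([0.0] * 10, target=3))
```

With uniform logits over ten classes, p = 0.1, the weighted probability is 0.09, and the loss is 1 - 1.18 / 2.09, roughly 0.435.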

NER example

import torch
from sadice import SelfAdjDiceLoss

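criterion = SelfAdjDiceLoss(reduction="none")  # keep one loss per token
# (batch_size, num_tokens, num_classes)
logits = torch.rand(128, 40, 10, requires_grad=True)
targets = torch.randint(0, 10, size=(128, 40))

# Flatten to (batch_size * num_tokens, num_classes) and (batch_size * num_tokens,)
loss = criterion(logits.view(-1, 10), targets.view(-1))
# Average over the 40 tokens of each sentence, then over the batch
loss = loss.reshape(-1, 40).mean(-1).mean()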
loss.backward()
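The NER example treats all 40 positions of every sentence as real tokens. Real NER batches are usually padded, so per-token losses obtained with reduction="none" may need masking before the per-sentence average. The plain-Python sketch below shows the same "average per sentence, then over the batch" reduction with a padding mask; masked_sentence_mean and the 1/0 mask convention are hypothetical and not part of sadice. In PyTorch the same idea would be expressed by multiplying the reshaped loss by a float mask and dividing by the per-row mask sum.

```python
from typing import List, Sequence


def masked_sentence_mean(token_losses: Sequence[Sequence[float]],
                         mask: Sequence[Sequence[int]]) -> float:
    """Average per-token losses within each sentence, then across sentences.

    token_losses: (batch_size, num_tokens) losses, e.g. reduction="none"
        output reshaped per sentence.
    mask: hypothetical padding mask, 1 for real tokens and 0 for padding.
    """
    sentence_means: List[float] = []
    for losses, row_mask in zip(token_losses, mask):
        n_real = sum(row_mask)
        if n_real == 0:
            continue  # a fully padded row contributes nothing
        masked_total = sum(loss * m for loss, m in zip(losses, row_mask))
        sentence_means.append(masked_total / n_real)
    if not sentence_means:
        raise ValueError("mask contains no real tokens")
    return sum(sentence_means) / len(sentence_means)


# The second token of sentence 0 is padding and is ignored.
batch_losses = [[0.4, 0.9], [0.2, 0.6]]
batch_mask = [[1, 0], [1, 1]]
print(masked_sentence_mean(batch_losses, batch_mask))
```

When every mask entry is 1, this reduces to the unmasked `reshape(-1, num_tokens).mean(-1).mean()` used above.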

Download files

Source distribution: sadice-0.1.3.tar.gz (6.4 kB)

Built distribution: sadice-0.1.3-py3-none-any.whl (6.6 kB, Python 3)
