
The contrastive token loss for reducing generative repetition of autoregressive neural language models.

Project description

Contrastive Token loss function for PyTorch

This repo is a clean PyTorch implementation of the contrastive token (CT) loss proposed in our paper: A Simple Contrastive Learning Objective for Alleviating Neural Text Degeneration. To reproduce the results from the paper, please check our reproduction repo.
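As a rough guide to what the loss computes, here is the per-position objective as we read it from the paper. The notation is ours and may not match the paper or the package exactly: $h_t$ is the hidden state at step $t$, $W_v$ is the output embedding of token $v$ (so $h_t^\top W_v$ is the logit of $v$), $x_t$ is the target token, and $S_N^t$ is the set of negative tokens taken from a window of the $M$ preceding tokens, excluding $x_t$:

$$
\mathcal{L}_{\mathrm{CT}}^{t} = \log\Big(1 + \sum_{x^- \in S_N^t} \exp\big(h_t^\top W_{x^-} - h_t^\top W_{x_t}\big)\Big),
\qquad S_N^t = \{x_{t-M}, \dots, x_{t-1}\} \setminus \{x_t\}
$$

Writing $s_v = h_t^\top W_v$, the identity $\log\big(1 + \sum_{x^-} e^{s_{x^-} - s_{x_t}}\big) = -\log \frac{e^{s_{x_t}}}{e^{s_{x_t}} + \sum_{x^-} e^{s_{x^-}}}$ shows that this is a cross-entropy restricted to the target token and the negatives: minimizing it raises the target's logit relative to the logits of recently seen tokens.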

Install

pip install ct-loss

Usage

You can use our CT objective when pretraining or finetuning your autoregressive language models. With CT, the resulting language models produce significantly less repetitive text, even with deterministic decoding such as greedy and beam search. Using CT takes only a few lines of code, placed around where you compute PyTorch's CrossEntropyLoss. Here is an example:

import torch

# Suppose we already have the model output logits and labels (sequences of token indices).
# For example when the batch size is 10, sequence length is 50 and vocabulary size is 1000:
logits = torch.rand(10, 50, 1000)
labels = torch.randint(0, 999, (10, 50))  # ids 0..998 (high is exclusive), so no label equals pad_id=999 below

# This is how you normally use cross-entropy for a language model:
from torch.nn import CrossEntropyLoss
ce_criterion = CrossEntropyLoss()
ce_loss = ce_criterion(logits.view(-1, 1000), labels.view(-1))

# This is how you can use our contrastive token loss:
from ct.ct_loss import ContrastiveTokenLoss
ct_criterion = ContrastiveTokenLoss(pad_id=999)  # tokens equal to pad_id are masked out and never used as negative tokens
ct_loss = ct_criterion(logits, labels)

# In our paper [1], we use CE and CT together
loss = ce_loss + ct_loss

print(ce_loss, ct_loss)

# Example output (exact values vary, since the inputs are random):
# tensor(6.9536) tensor(1.5848)
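To make the objective concrete, here is a small from-scratch sketch in plain PyTorch. It is a hypothetical re-implementation for illustration, not the code inside `ct.ct_loss.ContrastiveTokenLoss`; the function name `contrastive_token_loss_sketch`, the optional `window` argument, treating the negatives as a set of distinct token ids, and averaging over non-pad positions are all our assumptions. Like the example above, it scores `logits[:, t]` against `labels[:, t]` and draws the negatives for position t from the earlier labels in the same sequence, skipping pad tokens and the target token itself.

```python
import torch


def contrastive_token_loss_sketch(logits, labels, pad_id, window=None):
    """Hypothetical CT loss sketch for illustration; NOT the ct-loss package's code.

    logits: float tensor (batch, seq_len, vocab); labels: long tensor (batch, seq_len).
    pad_id must be a valid vocabulary index (as in the example above).
    """
    _, seq_len, _ = logits.shape
    steps = torch.arange(seq_len, device=logits.device)
    # preceding[t, s] is True when position s comes before position t ...
    preceding = steps.view(1, -1) < steps.view(-1, 1)
    if window is not None:
        # ... and, optionally, lies within the last `window` positions (assumed knob).
        preceding &= steps.view(1, -1) >= steps.view(-1, 1) - window
    # valid[b, t, s]: labels[b, s] may act as a negative for position t (pads excluded).
    valid = preceding.unsqueeze(0) & (labels != pad_id).unsqueeze(1)
    # Mark negatives in vocabulary space; repeated tokens collapse into one entry.
    token_ids = labels.unsqueeze(1).repeat(1, seq_len, 1)
    counts = torch.zeros_like(logits).scatter_add_(2, token_ids, valid.to(logits.dtype))
    # The target token at position t is never its own negative.
    counts.scatter_(2, labels.unsqueeze(-1), 0.0)
    is_negative = counts > 0

    positive = logits.gather(2, labels.unsqueeze(-1)).squeeze(-1)  # (batch, seq_len)
    negatives = logits.masked_fill(~is_negative, float("-inf"))
    # log(1 + sum exp(s_neg - s_pos)) == logsumexp([s_pos, s_neg, ...]) - s_pos
    all_scores = torch.cat([positive.unsqueeze(-1), negatives], dim=-1)
    per_token = torch.logsumexp(all_scores, dim=-1) - positive
    # Average over non-pad target positions (the averaging scheme is an assumption).
    return per_token[labels != pad_id].mean()
```

For a causal LM whose logits at position t predict token t+1, you would typically shift logits and labels by one position before calling either loss, just as for cross-entropy. Because the package may use a different window size, averaging, or duplicate handling, its values need not match this sketch.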

Cite our paper



Download files

Download the file for your platform.

Source distribution: ct_loss-0.0.3.tar.gz (4.6 kB)

Built distribution: ct_loss-0.0.3-py3-none-any.whl (5.1 kB, Python 3)
