The contrastive token loss for reducing generative repetition of autoregressive neural language models.

Project description

Contrastive Token loss function for PyTorch

This repo is the clean PyTorch implementation of the contrastive token loss proposed in our paper: A Simple Contrastive Learning Objective for Alleviating Neural Text Degeneration. To reproduce our results, please check this repo.

Install

pip install ct-loss

Usage

You can use our CT objective when pretraining or finetuning your autoregressive language models. With CT, the resulting language models produce significantly fewer repetitive generations, even with deterministic decoding such as greedy and beam search. Using the CT loss takes only a few lines of code, added where you would normally compute PyTorch's CrossEntropyLoss. Here is an example:

import torch

# Suppose we already have the model's output logits and the labels (sequences of
# token indices). For example, with batch size 10, sequence length 50 and
# vocabulary size 1000 (index 999 reserved for padding):
logits = torch.rand(10, 50, 1000)
labels = torch.randint(0, 999, (10, 50))

# This is how you normally use cross-entropy for a language model:
from torch.nn import CrossEntropyLoss
ce_criterion = CrossEntropyLoss()
ce_loss = ce_criterion(logits.view(-1, 1000), labels.view(-1))

# This is how you can use our contrastive token loss. The pad_id is needed to
# mask out tokens in a sequence that should not be used as negative tokens:
from ct.ct_loss import ContrastiveTokenLoss
ct_criterion = ContrastiveTokenLoss(pad_id=999)
ct_loss = ct_criterion(logits, labels)

# In our paper [1], we use CE and CT together
loss = ce_loss + ct_loss

print(ce_loss, ct_loss)

>>> tensor(6.9536) tensor(1.5848)
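To give an intuition for what a contrastive token objective does, here is a minimal illustrative sketch under our own simplifying assumptions: negatives for each position are the distinct tokens already seen earlier in the same sequence, and a softplus contrastive term pushes the target token's logit above each negative's. This is NOT the package's exact formulation; `ct.ct_loss.ContrastiveTokenLoss` is the authoritative implementation.

```python
import torch

def contrastive_token_loss_sketch(logits, labels, pad_id):
    """Illustrative contrastive token loss (assumption: negatives are the
    distinct tokens seen earlier in the sequence, not the package's exact rule).

    logits: (batch, seq_len, vocab), labels: (batch, seq_len) token indices.
    """
    B, T, V = logits.shape
    total, count = logits.new_zeros(()), 0
    for b in range(B):
        for t in range(1, T):
            y = int(labels[b, t])
            if y == pad_id:
                continue  # pad positions contribute nothing
            # Distinct preceding tokens, minus the target itself and padding:
            negs = {int(x) for x in labels[b, :t].tolist()} - {y, pad_id}
            if not negs:
                continue
            pos = logits[b, t, y]                 # target logit
            neg = logits[b, t, list(negs)]        # negative logits
            # softplus(z_neg - z_pos) = log(1 + exp(z_neg - z_pos)):
            # small when the target logit dominates, large otherwise.
            total = total + torch.nn.functional.softplus(neg - pos).mean()
            count += 1
    return total / max(count, 1)
```

The loss is zero-ish only when the model already ranks every target token well above all previously generated tokens, which is exactly the pressure that discourages degenerate repetition under greedy or beam-search decoding.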

Cite our paper

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

ct_loss-0.0.3.tar.gz (4.6 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

ct_loss-0.0.3-py3-none-any.whl (5.1 kB view details)

Uploaded Python 3

File details

Details for the file ct_loss-0.0.3.tar.gz.

File metadata

  • Download URL: ct_loss-0.0.3.tar.gz
  • Upload date:
  • Size: 4.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.0

File hashes

Hashes for ct_loss-0.0.3.tar.gz
Algorithm Hash digest
SHA256 c3cb2a98c37203c1c076b33312ea6fb5947300536bb11d4c8739143189c3be24
MD5 334559e36e733d21e42fc1c1bda00d04
BLAKE2b-256 2ff3475d3a13a34338cfc5f7d10402128d1c6d8e703826a82441ecbc604f81af

See more details on using hashes here.
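To verify a downloaded distribution against the published digest before installing, a standard-library sketch (the file name is taken from the listing above; the expected digest is whatever the hashes table shows):

```python
import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream a file through SHA-256 and return its hex digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Compare the result with the SHA256 value in the table above, e.g.:
#   sha256_of("ct_loss-0.0.3.tar.gz")
```

Alternatively, `pip install --require-hashes` with a pinned requirements file performs this check automatically.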

File details

Details for the file ct_loss-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: ct_loss-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 5.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.0

File hashes

Hashes for ct_loss-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 506fd18e3a40398232cd39327fb2377ead9a9681382dcb56baa6585450a0ab05
MD5 f976b02522bbee6bc19e9e0b4c080cc9
BLAKE2b-256 2e08535d46671b672415dc04415441f780b6edbcefc274f242d2a63b74d9a77e

See more details on using hashes here.
