A contrastive token loss for reducing generative repetition in autoregressive neural language models.
Project description
Contrastive Token loss function for PyTorch
This repo is the clean PyTorch implementation of the contrastive token (CT) loss proposed in our paper: A Simple Contrastive Learning Objective for Alleviating Neural Text Degeneration. To reproduce our results, please check this repo.
Install
pip install ct-loss
Usage
You can use our CT objective when pretraining or finetuning your autoregressive language models.
With CT, the resulting language models generate significantly less repetition, even under deterministic decoding such as greedy and beam search.
It takes only a few lines of code to use the CT loss, right around where you would normally compute PyTorch's CrossEntropyLoss.
Here is an example:
```python
import torch

# Suppose we already have the model output logits and labels (sequences of token indices).
# For example, when the batch size is 10, the sequence length is 50, and the vocabulary size is 1000:
logits = torch.rand(10, 50, 1000)
labels = torch.randint(0, 999, (10, 50))  # token id 999 is reserved as the pad id below

# This is how you normally use cross-entropy for a language model:
from torch.nn import CrossEntropyLoss
ce_criterion = CrossEntropyLoss()
ce_loss = ce_criterion(logits.view(-1, 1000), labels.view(-1))

# This is how you can use our contrastive token loss:
from ct.ct_loss import ContrastiveTokenLoss
ct_criterion = ContrastiveTokenLoss(pad_id=999)  # pad tokens are needed to mask out
                                                 # tokens that should not be used as negatives
ct_loss = ct_criterion(logits, labels)

# In our paper [1], we use CE and CT together:
loss = ce_loss + ct_loss

print(ce_loss, ct_loss)
# tensor(6.9536) tensor(1.5848)  (values vary, since the inputs above are random)
```
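For intuition, here is a minimal, self-contained sketch in plain PyTorch of what a contrastive token objective can look like: tokens that already appeared earlier in the sequence (the most likely repetition candidates) are treated as negatives, and positions where a negative token scores above the gold next token are penalized. This is an illustrative assumption, not the package's actual implementation; the function name and the hinge-style formulation are ours, so refer to the paper and repo for the exact objective.

```python
import torch
import torch.nn.functional as F

def contrastive_token_loss_sketch(logits, labels, pad_id):
    """Illustrative sketch of a contrastive token objective.

    Negatives at position t are the tokens that occurred in labels[:, :t].
    The loss is the average hinge penalty max(0, logit_neg - logit_gold),
    so it is zero when every gold token outscores all earlier-seen tokens.
    NOT the ct-loss package's exact formulation.
    """
    batch, seq_len, vocab = logits.shape

    gold_onehot = F.one_hot(labels, vocab).bool()            # (B, T, V)
    # occurred[b, t, v] = True if token v appears in labels[b, :t+1]
    occurred = gold_onehot.long().cumsum(dim=1).bool()

    # Negatives are tokens seen strictly before position t.
    neg_mask = torch.zeros_like(occurred)
    neg_mask[:, 1:] = occurred[:, :-1]
    neg_mask &= ~gold_onehot                                 # never contrast against the gold token
    neg_mask[..., pad_id] = False                            # the pad token is never a negative
    neg_mask &= (labels != pad_id).unsqueeze(-1)             # skip positions whose label is padding

    gold = logits.gather(2, labels.unsqueeze(-1))            # (B, T, 1), gold-token logits
    margin = (logits - gold).clamp_min(0.0)                  # hinge on each negative's margin
    return (margin * neg_mask).sum() / neg_mask.sum().clamp(min=1)
```

Note the role of `pad_id` here mirrors the `ContrastiveTokenLoss(pad_id=...)` argument above: padded positions must be excluded both as prediction targets and as negative candidates.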
Cite our paper
Download files
Download the file for your platform.
Source Distribution
Built Distribution
File details
Details for the file ct_loss-0.0.3.tar.gz.
File metadata
- Download URL: ct_loss-0.0.3.tar.gz
- Upload date:
- Size: 4.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.9.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c3cb2a98c37203c1c076b33312ea6fb5947300536bb11d4c8739143189c3be24 |
| MD5 | 334559e36e733d21e42fc1c1bda00d04 |
| BLAKE2b-256 | 2ff3475d3a13a34338cfc5f7d10402128d1c6d8e703826a82441ecbc604f81af |
File details
Details for the file ct_loss-0.0.3-py3-none-any.whl.
File metadata
- Download URL: ct_loss-0.0.3-py3-none-any.whl
- Upload date:
- Size: 5.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.9.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 506fd18e3a40398232cd39327fb2377ead9a9681382dcb56baa6585450a0ab05 |
| MD5 | f976b02522bbee6bc19e9e0b4c080cc9 |
| BLAKE2b-256 | 2e08535d46671b672415dc04415441f780b6edbcefc274f242d2a63b74d9a77e |