
Code for cut cross entropy, a memory efficient implementation of linear-cross-entropy loss.

Cut Your Losses in Large-Vocabulary Language Models

This software project accompanies the research paper: Cut Your Losses in Large-Vocabulary Language Models, Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, and Philipp Krähenbühl.

As language models grow ever larger, so do their vocabularies. This has shifted the memory footprint of LLMs during training disproportionately to one single layer: the cross-entropy in the loss computation. Cross-entropy builds up a logit matrix with entries for each pair of input tokens and vocabulary items and, for small models, consumes an order of magnitude more memory than the rest of the LLM combined. We propose Cut Cross-Entropy (CCE), a method that computes the cross-entropy loss without materializing the logits for all tokens into global memory. Rather, CCE only computes the logit for the correct token and evaluates the log-sum-exp over all logits on the fly. We implement a custom kernel that performs the matrix multiplications and the log-sum-exp reduction over the vocabulary in flash memory, making global memory consumption for the cross-entropy computation negligible. This has a dramatic effect. Taking the Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss computation from 24 GB to 1 MB, and the total training-time memory consumption of the classifier head from 28 GB to 1 GB. To improve the throughput of CCE, we leverage the inherent sparsity of softmax and propose to skip elements of the gradient computation that have a negligible (i.e., below numerical precision) contribution to the gradient. Experiments demonstrate that the dramatic reduction in memory consumption is accomplished without sacrificing training speed or convergence.
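CCE relies on the fact that, for an embedding e_i with label y_i and classifier rows c_v, the per-token loss is ℓ_i = log Σ_v exp(c_v · e_i) − c_{y_i} · e_i. Only the correct-token logit and a log-sum-exp are needed. The sketch below illustrates that identity in plain PyTorch. It is not the library's Triton kernel. It assumes 2-D inputs, a (V, D) classifier, no ignored tokens, and mean reduction. It walks the vocabulary in chunks, so only one (N, chunk_size) block of logits exists at a time.

```python
import torch


def chunked_linear_cross_entropy(e, c, targets, chunk_size=1024):
    """Mean cross-entropy of e @ c.T without materializing the full (N, V) logits.

    e: (N, D) embeddings, c: (V, D) classifier weights, targets: (N,) int64 labels.
    """
    # Logit of the correct token only: one dot product per row.
    correct_logit = (e * c[targets]).sum(dim=-1)

    # Log-sum-exp over the vocabulary, accumulated one chunk of classifier rows at a time.
    lse = None
    for start in range(0, c.shape[0], chunk_size):
        block = e @ c[start:start + chunk_size].T  # (N, <= chunk_size) logits
        block_lse = torch.logsumexp(block, dim=-1)
        lse = block_lse if lse is None else torch.logaddexp(lse, block_lse)

    return (lse - correct_logit).mean()
```

On small inputs this should match torch.nn.functional.cross_entropy(e @ c.T, targets) up to floating-point error. It covers only the forward pass, however. Under autograd, each chunk of logits would still be saved for the backward pass, so this sketch alone does not deliver CCE's training-time savings.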

Getting started

Requirements

  1. Python 3.10+
  2. PyTorch 2.4+
  3. Triton 3.0+
  4. Ampere (or newer) GPU

Note: On operating systems that Triton does not support (e.g., macOS), we include a highly optimized version of linear-cross-entropy built with torch.compile. This implementation is the default on macOS.
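If you want to pass the implementation explicitly, for example on an older GPU, a small helper can encode the requirements listed above. The function below is a hypothetical user-side sketch and not part of cut-cross-entropy. It assumes that "Ampere or newer" means CUDA compute capability 8.0 or higher.

```python
def choose_impl(platform: str, has_triton: bool, has_cuda: bool,
                capability: tuple[int, int]) -> str:
    """Hypothetical helper: pick an `impl` value for linear_cross_entropy.

    The Triton-based "cce" kernel needs Triton, a CUDA GPU with compute
    capability >= 8.0 (Ampere or newer), and an OS that Triton supports.
    Everything else falls back to "torch_compile".
    """
    if platform == "darwin":  # macOS: Triton is not supported
        return "torch_compile"
    if has_triton and has_cuda and capability >= (8, 0):
        return "cce"
    return "torch_compile"
```

You could gather the inputs with sys.platform, importlib.util.find_spec("triton") is not None, torch.cuda.is_available(), and torch.cuda.get_device_capability(), then pass the result as impl=....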

Basic usage

Installation

pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git"

Usage

from cut_cross_entropy import linear_cross_entropy

embeddings = model.compute_embedding(inputs)
classifier = model.get_classifier_weights()

loss = linear_cross_entropy(embeddings, classifier, labels)
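A naive reference can clarify what this call computes and serves as a sanity check on small inputs. It materializes the full logit matrix, which is exactly what CCE avoids. The sketch assumes the classifier has shape (V, D), mean reduction, and -100 as the ignore index (the value used with PyTorch and CCE elsewhere in this README).

```python
import torch
import torch.nn.functional as F


def reference_linear_cross_entropy(embeddings, classifier, labels, ignore_index=-100):
    """Naive reference that builds every logit; only suitable for small tests.

    embeddings: (..., D), classifier: (V, D), labels: (...) int64 with ignore_index
    marking tokens that contribute no loss.
    """
    logits = embeddings @ classifier.T  # (..., V): the large matrix CCE never stores
    return F.cross_entropy(
        logits.flatten(0, -2).float(),  # (N, V), upcast for a stable comparison
        labels.flatten(),               # (N,)
        ignore_index=ignore_index,
    )
```

Compare its output against linear_cross_entropy with a tolerance suited to your dtype, for example using torch.testing.assert_close.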

In causal language modeling, the embeddings and labels usually need to be shifted by one position so that the model predicts the next token.

from cut_cross_entropy import linear_cross_entropy

embeddings = model.compute_embedding(inputs)
classifier = model.get_classifier_weights()

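# Drop the last position; flattening the non-contiguous slice makes an (N, D) copy.
shift_embeddings = embeddings[..., :-1, :].flatten(0, -2)
# Drop the first label and flatten so labels match the (N, D) embeddings.
shift_labels = labels[..., 1:].flatten()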

manual_shift_loss = linear_cross_entropy(shift_embeddings, classifier, shift_labels)

Instead, pass shift=True to perform this computation without allocating the shift_embeddings matrix.

from cut_cross_entropy import linear_cross_entropy

embeddings = model.compute_embedding(inputs)
classifier = model.get_classifier_weights()

# This is the same as manual_shift_loss above
auto_shift_loss = linear_cross_entropy(embeddings, classifier, labels, shift=True)
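With a loss function that has no shift option, one equivalent approach is to shift the small integer label tensor instead of the embeddings and mark the final position as ignored. This is a sketch and not necessarily how CCE implements shift=True. It assumes -100 as the ignore index and mean reduction, so ignored positions drop out of the average.

```python
import torch


def shift_labels_for_next_token(labels, ignore_index=-100):
    """Align (B, T) labels with unshifted (B, T, D) embeddings for next-token loss.

    Position t gets target labels[..., t + 1]. The last position has no next
    token, so it is set to ignore_index. Only the label tensor is copied.
    """
    pad = torch.full_like(labels[..., :1], ignore_index)
    return torch.cat([labels[..., 1:], pad], dim=-1)
```

Because the ignored final position is excluded from the mean, the loss over the full embeddings with these targets equals the manually shifted loss.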

We also provide a highly optimized implementation of linear-cross-entropy loss using torch.compile. It is a good option when speed is the primary goal and the model's vocabulary is relatively small compared to its hidden dimension; when |V| >> D, CCE both saves memory and is faster. This option also works on the CPU and on older GPUs, making it useful for testing.

from cut_cross_entropy import linear_cross_entropy

embeddings = model.compute_embedding(inputs)
classifier = model.get_classifier_weights()

# Other arguments are passed as in the examples above.
loss = linear_cross_entropy(embeddings, classifier, labels, impl="torch_compile")
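A little arithmetic makes the |V| >> D rule of thumb concrete. The logit matrix is |V| / D times larger than the embeddings that produce it, regardless of how many tokens are in the batch. The helpers below are illustrative only. The Gemma 2 (2B) sizes (D = 2304, |V| = 256,000) come from that model's published configuration rather than this README, and the 8192-token batch is hypothetical.

```python
def logit_bytes(num_tokens: int, vocab_size: int, bytes_per_elem: int = 2) -> int:
    """Memory for the full (num_tokens, vocab_size) logit matrix; 2 bytes = bf16."""
    return num_tokens * vocab_size * bytes_per_elem


def embedding_bytes(num_tokens: int, hidden_dim: int, bytes_per_elem: int = 2) -> int:
    """Memory for the (num_tokens, hidden_dim) embeddings fed to the classifier."""
    return num_tokens * hidden_dim * bytes_per_elem


# Assumed Gemma 2 (2B) sizes (not from this README) and a hypothetical batch size.
D, V, N = 2304, 256_000, 8192

logits_gib = logit_bytes(N, V) / 2**30      # 3.90625 GiB
embeds_mib = embedding_bytes(N, D) / 2**20  # 36.0 MiB
ratio = logit_bytes(N, V) / embedding_bytes(N, D)  # equals V / D, about 111
```

These figures cover a single bf16 logit matrix. The 24 GB quoted earlier comes from the authors' own benchmark setup, which this sketch does not attempt to reproduce.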

Transformers Integration

Installation

Install cut-cross-entropy with transformers dependencies

pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git"

Usage

If you are using transformers, you can patch transformers to use CCE directly. Note that logits will no longer be returned (None will be returned instead).

from cut_cross_entropy.transformers import cce_patch

cce_patch("llama")

# or

model = ...
model = cce_patch(model)

We currently support the Llama, Phi3, Mistral, and Gemma2 families of models.

cce_patch takes two options. The first selects the linear-cross-entropy implementation to use: currently either "cce" or "torch_compile".

The second is the loss reduction. We support "mean", "sum", and "none", which mirror their PyTorch counterparts. "mean" is the default and is what the transformers Trainer API expects. However, "none" in particular enables efficient computation of quantities derived from the per-token loss.

For example, the following efficiently computes the perplexity of a batch of sequences:

import torch
import transformers

from cut_cross_entropy.transformers import cce_patch


model = transformers.AutoModelForCausalLM.from_pretrained(...)

model = cce_patch(model, reduction="none")

# input_ids and attention_mask are (B, T) tensors, e.g. from the model's tokenizer.
labels = input_ids.clone()
labels[attention_mask == 0] = -100  # -100 is the ignore index for PyTorch and CCE.

outputs = model(input_ids, attention_mask, labels=labels)

# A (B, T - 1) tensor because reduction="none"; T - 1 because the first input token has no loss.
loss = outputs[0]

ppl = torch.exp(
    # [:, 1:] because the first token has no loss
    loss.sum(1) / (labels[:, 1:] != -100).count_nonzero(dim=1)
).mean()  # Average perplexity over the batch
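The perplexity formula can be checked on toy data without a model. In the sketch below, unreduced losses from torch.nn.functional.cross_entropy stand in for what the patched model returns with reduction="none". This assumes, as the formula above does, that ignored positions contribute a loss of zero. The helper name batch_perplexity is hypothetical.

```python
import torch


def batch_perplexity(token_losses, labels, ignore_index=-100):
    """Mean per-sequence perplexity from unreduced next-token losses.

    token_losses: (B, T - 1) losses, zero where the target is ignored.
    labels: (B, T) labels with ignore_index marking padding.
    """
    n_valid = (labels[:, 1:] != ignore_index).count_nonzero(dim=1)
    return torch.exp(token_losses.sum(1) / n_valid).mean()
```

The same per-token losses also show how the reductions relate. "sum" is token_losses.sum(), and "mean" divides that sum by the number of non-ignored tokens, matching PyTorch's behavior with ignore_index.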

Training and reproducing the benchmark results

We provide a training script in training/train.py.

Installation

pip install "cut-cross-entropy[all] @ git+https://github.com/apple/ml-cross-entropy.git"

Training

Use scripts/train.sh to train a full model.

Benchmarking

The benchmark script can be run via python -m benchmark.

Expected output with an A100 SXM4, PyTorch 2.4.1, and CUDA 12.4:

          method        kind  runtime_ms  op_mem_mb test_data
0            cce     loss-fw        46.4        1.1    gemma2
1  torch_compile     loss-fw        49.9     4000.1    gemma2
2       baseline     loss-fw        81.9    24000.0    gemma2
3            cce     loss-bw        89.3     1163.0    gemma2
4  torch_compile     loss-bw        92.3    12000.0    gemma2
5       baseline     loss-bw       122.4    16000.0    gemma2
6            cce  loss-fw-bw       134.8     1164.0    gemma2
7  torch_compile  loss-fw-bw       144.0    16000.1    gemma2
8       baseline  loss-fw-bw       208.8    28000.0    gemma2
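To compare methods, the combined forward and backward (loss-fw-bw) rows can be expressed as ratios against the baseline. The snippet below only restates numbers from the table above for the A100 setup described. Ratios on other hardware or with other shapes will differ.

```python
# loss-fw-bw rows from the table above: (runtime_ms, op_mem_mb).
FW_BW = {
    "cce": (134.8, 1164.0),
    "torch_compile": (144.0, 16000.1),
    "baseline": (208.8, 28000.0),
}


def vs_baseline(method: str, table: dict = FW_BW) -> tuple[float, float]:
    """Return (speedup, memory reduction) of `method` relative to the baseline row."""
    runtime, mem = table[method]
    base_runtime, base_mem = table["baseline"]
    return base_runtime / runtime, base_mem / mem


speedup, mem_reduction = vs_baseline("cce")  # roughly 1.55x faster, 24x less memory
```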

Development

If the dependencies are installed locally, cut-cross-entropy works without a pip install as long as Python is run from the root of the GitHub repository.

To install directly from the GitHub repository, either use an (editable) install or set PYTHONPATH, e.g.

pip install -e ".[dev]"

# or
pip install ".[dev]"

# or
export PYTHONPATH=/path/to/ml-cross-entropy:${PYTHONPATH}

Citation

@article{wijmans2024cut,
  author       = {Erik Wijmans and
                  Brody Huval and
                  Alexander Hertzberg and
                  Vladlen Koltun and
                  Philipp Kr\"ahenb\"uhl},
  title        = {Cut Your Losses in Large-Vocabulary Language Models},
  journal      = {arXiv},
  year         = {2024},
  url          = {https://arxiv.org/abs/2411.09009},
}

License

This sample code is released under the LICENSE terms.

Acknowledgements

Our codebase is built using multiple open-source contributions; please see Acknowledgements for more details.

Please check the paper for a complete list of references and datasets used in this work.

Download files

Source Distribution

cut_cross_entropy-24.12.1.tar.gz (22.7 kB)

Built Distribution

cut_cross_entropy-24.12.1-py3-none-any.whl (22.5 kB)

File details

Details for the file cut_cross_entropy-24.12.1.tar.gz.

File metadata

  • Download URL: cut_cross_entropy-24.12.1.tar.gz
  • Size: 22.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.4

File hashes

Hashes for cut_cross_entropy-24.12.1.tar.gz
Algorithm Hash digest
SHA256 8843bd8d564ebf006dcb95c8b170c1e02e8ab5d99fc748c74c8eaa10d5eee47f
MD5 9aff1cb1ddee53534d1e094583554c38
BLAKE2b-256 63c3794c3ed23916c50f0e06cccde1377cf886fbf0bc3baf4097a7fa8daa45a6

File details

Details for the file cut_cross_entropy-24.12.1-py3-none-any.whl.

File hashes

Hashes for cut_cross_entropy-24.12.1-py3-none-any.whl
Algorithm Hash digest
SHA256 6ae7250f6faa36ba9acc13cd4c672d03ea2b53640f4af4093d52430081853917
MD5 4af250f43c07fd68b195f72c9a7da2e6
BLAKE2b-256 dff18ee1c2865b741bc491dd349a0e01616a70fae8e3b53844533a8db98a6f9f
