
Minimum Bayes risk decoding for Hugging Face Transformers

Project description

mbr 🔥

mbr adds Sampling-based Minimum Bayes Risk decoding to Hugging Face transformers. Originally proposed by Eikema & Aziz (2022), this technique is a risk-minimizing algorithm for generating text with a language model.

Pronounced: ember /ˈɛm.bɚ/

Installation

pip install mbr

Requirements:

  • Python >= 3.9
  • PyTorch

Usage

The main components of mbr are:

  • mbr.MBRGenerationMixin: overrides a model's generate method to add MBR decoding.
  • mbr.MBRConfig: specifies the parameters of MBR decoding, e.g., the number of samples to generate and the metric to optimize.

1. Load a Hugging Face transformers model

Models need to inherit from MBRGenerationMixin for MBR decoding to work. Here are two ways to achieve this, using the Llama model as an example:

Variant A:

from transformers import LlamaForCausalLM

from mbr import MBRGenerationMixin

class MBRLlamaForCausalLM(MBRGenerationMixin, LlamaForCausalLM):
    pass

Then, you can use MBRLlamaForCausalLM as you would use LlamaForCausalLM:

model = MBRLlamaForCausalLM.from_pretrained(...)

Variant B:

from transformers import LlamaForCausalLM

from mbr import MBR

model = MBR(LlamaForCausalLM).from_pretrained(...)

2. Configure MBR decoding

Create an MBRConfig object to pass to the model's generate method:

from mbr import MBRConfig

mbr_config = MBRConfig(
    num_samples=10,
    metric="chrf",
)

3. Generate text as usual

Call the model's generate method directly, or use the Pipeline API. Make sure to pass the mbr_config, as well as the model's tokenizer.

from transformers import pipeline

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
output = generator("Hello,", mbr_config=mbr_config, tokenizer=tokenizer)
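Alternatively, here is a minimal sketch of calling generate directly instead of using a pipeline; the checkpoint name "mymodel" and the prompt are placeholders:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mymodel")
inputs = tokenizer("Hello,", return_tensors="pt")
# Pass both the MBR config and the tokenizer to generate
output_ids = model.generate(**inputs, mbr_config=mbr_config, tokenizer=tokenizer)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))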

How MBR decoding works

Research papers such as Eikema & Aziz (2022), among many others, provide a description of sampling-based Minimum Bayes Risk decoding.

In practice, MBR decoding is most commonly implemented as follows (using machine translation as an example):

  • Instead of searching for the single most probable output sequence (e.g., using beam search), generate a number of samples.
  • Score each sample against the other samples using a metric (e.g., BLEU).
  • Return the sample with the highest score. Intuitively, this can be seen as returning the median of all samples (see the sketch below).
[Illustration of MBR decoding]
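
The procedure can be summarized in a short, framework-free sketch; sample() and utility() are hypothetical stand-ins for the model's sampler and a pairwise metric such as BLEU:

def mbr_decode(sample, utility, num_samples=10):
    # Generate a pool of samples; they double as the references.
    samples = [sample() for _ in range(num_samples)]

    # Expected utility of a candidate = its average score against all references.
    def expected_utility(candidate):
        return sum(utility(candidate, reference) for reference in samples) / len(samples)

    # Return the sample with the highest expected utility (lowest expected risk).
    return max(samples, key=expected_utility)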

The terminology around MBR decoding varies:

Term used in this codebase | Alternative terms
samples                    | candidates, hypotheses
references                 | pseudo-references, evidence
metric score               | expected utility; (negative) expected risk, error

Details

Configuring the sampling

The generation of the samples can be customized by passing a generation_config to the generate method or to the pipeline call:

from transformers import GenerationConfig

generation_config = GenerationConfig.from_pretrained("mymodel",
    do_sample=True,
    num_beams=1,
    epsilon_cutoff=0.02,
)
model.generate(..., generation_config=generation_config)

Separate set of references

By default, the samples themselves are used as references (or a subset of the samples, if num_references is smaller than num_samples).
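
For example, a minimal sketch of a config that scores 10 samples against a subset of 5 of them:

from mbr import MBRConfig

mbr_config = MBRConfig(
    num_samples=10,
    num_references=5,
    metric="chrf",
)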

You could also sample the reference set independently, using a custom generation config for the references:

from transformers import GenerationConfig

references_config = GenerationConfig.from_pretrained("mymodel",
    do_sample=True,
    num_beams=1,
    top_p=0.9,
)
model.generate(..., references_config=references_config)

Choosing a metric

By default, mbr integrates metrics via the Hugging Face Evaluate library.

A full list of metrics can be found in the Hugging Face Evaluate documentation. Typical choices are ChrF ("chrf") and COMET ("comet").

In the MBR config, you can either specify the metric's name (e.g., "chrf", "comet") or pass an evaluate.Metric object directly.

Since different metrics output differently structured dicts, you need to specify the metric_output_field that should be used as the metric score.

from evaluate import load

metric = load("chrf")
mbr_config = MBRConfig(
    metric=metric,
    metric_output_field="score",  # the ChrF metric returns a dict with a "score" field
    ...
)

Customizing the metric computation

Internally, mbr will call the metric's compute method to calculate the metric score for each sample.

By default, mbr will call compute separately for each sample–reference pair. Since this requires many compute calls, it can make sense to optimize the metric computation. Different metrics will require different optimization strategies. To override the default way of calling the metric, define a MetricRunner class and pass it to the generate method:

from typing import Tuple

import torch

from mbr import MetricRunner

class MyMetricRunner(MetricRunner):

    def __call__(self,
                 input_ids: torch.LongTensor,
                 sample_ids: Tuple[torch.LongTensor],
                 reference_ids: Tuple[torch.LongTensor],
                 ) -> torch.FloatTensor:
        ...  # TODO: implement your efficient metric computation here

model.generate(..., metric_runner=MyMetricRunner())

For COMET, an optimized implementation is already provided in CometMetricRunner:

from mbr.metrics.comet import CometMetricRunner

mbr_config = MBRConfig(
    ...,
    metric="comet",
    metric_output_field="mean_score",
)

metric_runner = CometMetricRunner(mbr_config, tokenizer)
model.generate(..., metric_runner=metric_runner)

Optimizations

MBR decoding is notoriously slow. mbr implements some optimizations:

  • Cached encoder outputs: For encoder-decoder models, the encoder outputs are computed only once and reused during sampling.
  • Cached metric: The metric is computed only once for each unique sample–reference pair, since there will be duplicate samples and references (see the sketch after this list).
  • Optimized COMET metric: Inspired by Amrhein & Sennrich (2022), sequence embeddings are cached and reused for all pairwise comparisons.
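
As a minimal illustration of the cached-metric idea (not the library's internal implementation), duplicate pairs can be handled with a memoized pairwise scoring function:

from functools import lru_cache

from evaluate import load

chrf = load("chrf")

@lru_cache(maxsize=None)  # duplicate sample–reference pairs hit the cache
def pairwise_chrf(sample: str, reference: str) -> float:
    return chrf.compute(predictions=[sample], references=[[reference]])["score"]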

Example scripts

The experiments directory contains code for reproducing experiments from research papers on MBR decoding.

Related projects

Changelog

  • v0.2.0
    • Breaking change: Rename MBRGenerationConfig to MBRConfig
    • Breaking change: MetricRunner now returns a MetricOutput dict instead of the raw tensor of scores.
    • Make the size of the metric cache configurable via MBRConfig.metric_cache_size
    • Allow the number of references to be larger than the number of samples (when the references are generated separately from the samples).
    • Remove GenerationConfig as parent class of MBRConfig

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mbr-0.2.0.tar.gz (1.1 MB)

Uploaded Source

Built Distribution

mbr-0.2.0-py3-none-any.whl (22.8 kB)

Uploaded Python 3

File details

Details for the file mbr-0.2.0.tar.gz.

File metadata

  • Download URL: mbr-0.2.0.tar.gz
  • Size: 1.1 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.6

File hashes

Hashes for mbr-0.2.0.tar.gz

Algorithm   | Hash digest
SHA256      | a50b95802d5ea787521d36e6a2c510736984f7f9d947d3fab3eec95f7610f535
MD5         | d12f3d3325aeced4944dae3d9b359bec
BLAKE2b-256 | 92ad229333c301931d3bd934461b5c72d3bab6cb09a6cadc0fc3e8b1b2dc174c

File details

Details for the file mbr-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: mbr-0.2.0-py3-none-any.whl
  • Size: 22.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.6

File hashes

Hashes for mbr-0.2.0-py3-none-any.whl

Algorithm   | Hash digest
SHA256      | efc6b03e6ba5a3f4fc5fd89b8947028b40d812f6236c5946e7ce3e7c9967ffc7
MD5         | 5e97d55dede3c45ceecd102c5b9c34c9
BLAKE2b-256 | f184f563aeae2413a19a6bb17d19e1092d5aa43031a36d2a7c7237b55743c472
