
A package for sampling from intractable distributions with LLMs.

Project description

QUEST: Quality-Aware Metropolis-Hastings Sampling for Machine Translation

Gonçalo Faria, Sweta Agrawal, António Farinhas, Ricardo Rei, José G. C. de Souza, André Martins

Paper: https://arxiv.org/abs/2406.00049

TL;DR: This paper presents a method to generate diverse and high-quality machine translations by sampling from a Gibbs distribution using the Metropolis-Hastings algorithm.

Abstract:

An important challenge in machine translation (MT) is to generate high-quality and diverse translations. Prior work has shown that the estimated likelihood from the MT model correlates poorly with translation quality. In contrast, quality evaluation metrics (such as COMET or BLEURT) exhibit high correlations with human judgments, which has motivated their use as rerankers (such as quality-aware and minimum Bayes risk decoding). However, relying on a single translation with high estimated quality increases the chances of "gaming the metric". In this paper, we address the problem of sampling a set of high-quality and diverse translations. We provide a simple and effective way to avoid over-reliance on noisy quality estimates by using them as the energy function of a Gibbs distribution. Instead of looking for a mode in the distribution, we generate multiple samples from high-density areas through the Metropolis-Hastings algorithm, a simple Markov chain Monte Carlo approach. The results show that our proposed method leads to high-quality and diverse outputs across multiple language pairs (English ↔ {German, Russian}) with two strong decoder-only LLMs (ALMA-7B, Tower-7B).
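
A minimal, self-contained sketch of the idea (not the package's API): candidate outputs are scored by a quality metric that plays the role of the energy in a Gibbs distribution, and a Metropolis-Hastings accept/reject step decides whether each new proposal replaces the current sample. The quality() and propose() functions below are toy placeholders standing in for a learned metric (e.g. COMET) and for drawing a fresh draft from the MT model.

    import math
    import random

    def quality(candidate: str) -> float:
        # Toy stand-in for a learned quality metric such as COMET.
        return -abs(len(candidate) - 20) / 20.0

    def propose(current: str) -> str:
        # Toy stand-in proposal; QUEST would instead draw a new draft from the MT model.
        return current + random.choice([" indeed", " truly", ""])

    def metropolis_hastings(initial: str, steps: int = 100, temperature: float = 0.1):
        # Target: a Gibbs distribution p(y) proportional to exp(quality(y) / temperature).
        # The proposal-correction term of full Metropolis-Hastings is omitted for clarity.
        samples = [initial]
        current = initial
        for _ in range(steps):
            candidate = propose(current)
            accept_prob = min(1.0, math.exp((quality(candidate) - quality(current)) / temperature))
            if random.random() < accept_prob:
                current = candidate
            samples.append(current)
        return samples

    print(metropolis_hastings("This is a draft translation.")[-5:])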


Documentation

TBD


Quick Start Examples

Install

Install using pip (recommended):

pip install quest-decoding

Install using pip (from GitHub):

pip install git+https://github.com/deep-spin/quest-decoding.git

Sentiment Steering
    from langchain.prompts import PromptTemplate
    from quest import Quest
    from quest import RewardModel
    from quest import VLLM


    template = PromptTemplate.from_template(
        "I received the following comment on X: {tweet}. How should I respond?\n"
    )  # a prompt template you define - useful for tasks like translation.
    
    test_input_data = [{
        "tweet": "You should refrain from commenting on this matter."
    }]

    model = VLLM(
        model_path="haoranxu/ALMA-7B",
        prompt_template=template,
    )

    reward = RewardModel("lvwerra/distilbert-imdb")  # sentiment model from HF. 
    
    chain = Quest(
        input_data=test_input_data,
        model=model,
        reward=reward,
    )
    
    chain_outputs = chain.run(
        steps=10,
        use_tqdm=True,
    )
    
    print(chain_outputs.samples)
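
Machine Translation (sketch)

The paper's main use case is translation. The sketch below is a hypothetical adaptation of the example above: it reuses only the names already shown (PromptTemplate, VLLM, Quest, RewardModel) and swaps in a translation prompt and input. The sentiment checkpoint is kept purely as a placeholder reward; the paper's energy function is a translation-quality metric (e.g. COMET), so consult the repository for the reward class that wraps such metrics.

    from langchain.prompts import PromptTemplate
    from quest import Quest, RewardModel, VLLM

    template = PromptTemplate.from_template(
        "Translate the following English sentence into German:\n{source}\n"
    )

    input_data = [{"source": "The weather in Lisbon is lovely today."}]

    model = VLLM(
        model_path="haoranxu/ALMA-7B",
        prompt_template=template,
    )

    # Placeholder reward; replace with a translation-quality reward for real MT runs.
    reward = RewardModel("lvwerra/distilbert-imdb")

    chain_outputs = Quest(
        input_data=input_data,
        model=model,
        reward=reward,
    ).run(steps=10, use_tqdm=True)

    print(chain_outputs.samples)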
        

Contact

For bugs and feature requests, please visit GitHub Issues. For business inquiries or professional support requests, please send an e-mail.


Citation

@misc{faria2024quest,
      title={QUEST: Quality-Aware Metropolis-Hastings Sampling for Machine Translation}, 
      author={Gonçalo R. A. Faria and Sweta Agrawal and António Farinhas and Ricardo Rei and José G. C. de Souza and André F. T. Martins},
      year={2024},
      eprint={2406.00049},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

quest_decoding-1.0.14.tar.gz (31.2 kB)

Uploaded Source

Built Distribution

quest_decoding-1.0.14-py3-none-any.whl (51.1 kB)

Uploaded Python 3

File details

Details for the file quest_decoding-1.0.14.tar.gz.

File metadata

  • Download URL: quest_decoding-1.0.14.tar.gz
  • Upload date:
  • Size: 31.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.8.19

File hashes

Hashes for quest_decoding-1.0.14.tar.gz
Algorithm Hash digest
SHA256 f2635638357c92d46467b9be96bd356d5092ef41aca3a45188f821c329076ef4
MD5 22971a51fdbab4d17b256cb011bb9ed2
BLAKE2b-256 d94bca79e595dc88d1031298dc5f0164c9d84294d0d9835b81b9691d514be632


File details

Details for the file quest_decoding-1.0.14-py3-none-any.whl.

File metadata

File hashes

Hashes for quest_decoding-1.0.14-py3-none-any.whl
Algorithm Hash digest
SHA256 ecb3e6f707a657f6c3d4bfe23f703c30dc4d1af44575b58c2ae22f4b78eb16f6
MD5 87dc771fe6d915f744af024586a333aa
BLAKE2b-256 2e7ad3d93649d0a1710b0f1b00840270ca6e31132cf3ff312a6afabcdec22d52

