A package for sampling from intractable distributions with LLMs.
Project description
QUEST: Quality-Aware Metropolis-Hastings Sampling for Machine Translation
Gonçalo Faria, Sweta Agrawal, António Farinhas, Ricardo Rei, José G. C. de Souza, André Martins
Paper: arxiv link goes here
TL;DR: This paper presents a method to generate diverse and high-quality machine translations by sampling from a Gibbs distribution using the Metropolis-Hastings algorithm.
Abstract:
An important challenge in machine translation (MT) is to generate high-quality and diverse translations. Prior work has shown that the estimated likelihood from the MT model correlates poorly with translation quality. In contrast, quality evaluation metrics (such as COMET or BLEURT) exhibit high correlations with human judgments, which has motivated their use as rerankers (such as quality-aware and minimum Bayes risk decoding). However, relying on a single translation with high estimated quality increases the chances of "gaming the metric". In this paper, we address the problem of sampling a set of high-quality and diverse translations. We provide a simple and effective way to avoid over-reliance on noisy quality estimates by using them as the energy function of a Gibbs distribution. Instead of looking for a mode in the distribution, we generate multiple samples from high-density areas through the Metropolis-Hastings algorithm, a simple Markov chain Monte Carlo approach. The results show that our proposed method leads to high-quality and diverse outputs across multiple language pairs (English↔{German, Russian}) with two strong decoder-only LLMs (ALMA-7B, Tower-7B).
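To make the idea concrete, here is a minimal, self-contained Python sketch of Metropolis-Hastings sampling from a Gibbs distribution whose energy is a (noisy) quality score. This illustrates the abstract, not the library's internals: the uniform independence proposal, the temperature, and the toy candidates/quality stand-ins (for an LLM's hypothesis space and a COMET-style metric) are assumptions made for brevity; QUEST itself proposes candidates with the LLM rather than from a fixed list.

import math
import random

def metropolis_hastings(candidates, quality, steps=1000, temperature=0.5):
    # Sample from p(y) proportional to exp(quality(y) / temperature).
    current = random.choice(candidates)
    samples = []
    for _ in range(steps):
        proposal = random.choice(candidates)  # toy uniform independence proposal
        # For a Gibbs distribution with energy -quality(y), the acceptance
        # probability reduces to exp((quality(y') - quality(y)) / T), capped at 1.
        accept = min(1.0, math.exp((quality(proposal) - quality(current)) / temperature))
        if random.random() < accept:
            current = proposal
        samples.append(current)
    return samples

# Toy usage: three candidate "translations" with made-up quality scores.
scores = {"guten Morgen": 0.9, "gut Morgen": 0.4, "morgen gut": 0.1}
chain = metropolis_hastings(list(scores), scores.get, steps=200)
print(chain[-5:])  # high-quality candidates dominate, yet diversity remains

Raising the temperature flattens the distribution and yields more diverse samples; lowering it concentrates mass on the highest-quality candidates.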
Documentation
TBD
Quick Start Examples
Install
Install using pip (recommended):
pip install quest-decoding
Install using pip (from GitHub):
pip install git+https://github.com/deep-spin/quest-decoding.git
Sentiment Steering
from langchain.prompts import PromptTemplate

from quest import Quest
from quest import RewardModel
from quest import VLLM

# A prompt template you define - useful for tasks like translation.
template = PromptTemplate.from_template(
    "I received the following comment on X: {tweet}. How should I respond?:\n"
)

test_input_data = [
    {"tweet": "You should refrain from commenting on this matter."}
]

model = VLLM(
    model_path="haoranxu/ALMA-7B",
    prompt_template=template,
)

reward = RewardModel("lvwerra/distilbert-imdb")  # sentiment model from HF.

chain = Quest(
    input_data=test_input_data,
    model=model,
    reward=reward,
)

chain_outputs = chain.run(
    steps=10,
    use_tqdm=True,
)

print(chain_outputs.samples)
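Machine Translation (hypothetical sketch)
The paper's own setting is machine translation, and the same chain applies there. The example below is a sketch, not a documented recipe: it reuses only the API shown above, the prompt and input are invented, and the reward is a placeholder, since the paper's energy comes from learned quality metrics (COMET/BLEURT-style) rather than a sentiment classifier.

from langchain.prompts import PromptTemplate

from quest import Quest
from quest import RewardModel
from quest import VLLM

# Hypothetical translation prompt, in the style of the sentiment example.
template = PromptTemplate.from_template(
    "Translate the following English sentence into German:\n{source}\n"
)

input_data = [{"source": "The weather in Lisbon is lovely this time of year."}]

model = VLLM(
    model_path="haoranxu/ALMA-7B",  # MT-tuned LLM also used in the paper
    prompt_template=template,
)

# Placeholder reward: a real MT run would wrap a quality-estimation metric
# (COMET-style) here instead of a sentiment classifier.
reward = RewardModel("lvwerra/distilbert-imdb")

chain = Quest(input_data=input_data, model=model, reward=reward)
chain_outputs = chain.run(steps=50, use_tqdm=True)  # longer chains explore more
print(chain_outputs.samples)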
Contact
For bugs and feature requests, please visit GitHub Issues. For business inquiries or professional support requests, please send an e-mail.
Citation
@inproceedings{questdecoding,
  title={QUEST: Quality-Aware Metropolis-Hastings Sampling for Machine Translation},
  author={Gonçalo Faria and Sweta Agrawal and António Farinhas and Ricardo Rei and José G. C. de Souza and André Martins},
  booktitle={},
  year={2024},
  url={arxiv link goes here}
}
Project details
Release history
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
File details
Details for the file quest-decoding-1.0.8.tar.gz.
File metadata
- Download URL: quest-decoding-1.0.8.tar.gz
- Upload date:
- Size: 18.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.10.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 5ebcd4b3f9e183adeaff17529731788728739a684e17046aaabad5c5f631d360
MD5 | 35013e71aa84dfa64d2c100a093bf67a
BLAKE2b-256 | 8436f9644650bd0463143adc9319506ae0936df1fa77ffa4f00b2352cc6bf83d
File details
Details for the file quest_decoding-1.0.8-py3-none-any.whl.
File metadata
- Download URL: quest_decoding-1.0.8-py3-none-any.whl
- Upload date:
- Size: 21.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.10.13
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6aca470b83d0cc48db29df0de484cf972fbe3419c37e0da3b24d871475b252a3
MD5 | 91c21cfcff96d12068757d6b450a9614
BLAKE2b-256 | ff7fe581a76daac4b6ccaff12cc76a0a26ae89bdc29ff5bf0454632ee3109b09