
ULTS: Uncertainty-guided Likelihood Tree Search

Accompanying implementation of the following paper [ArXiv]:

@article{grosse2024ults,
  title={Uncertainty-Guided Optimization On Large Language Model Search Trees},
  author={Grosse, Julia and Wu, Ruotian and Rashid, Ahmad and Hennig, Philipp and Poupart, Pascal and Kristiadi, Agustinus},
  journal={arXiv preprint arXiv:TODO!},
  year={2024}
}

Setup

Requires Python >= 3.9.

  1. Install PyTorch >= 2.0 with CUDA support
  2. Install this package: pip install git+https://github.com/JuliaGrosse/ults.git@main

Usage

See the full example in examples/generate.py.

Quickstart with the Dirichlet prior

[!IMPORTANT] ULTS first checks the prior_dir directory (default: ./ults_priors) for a precomputed prior matching your choices of width (vocab size), depth (max tokens to generate), and concentration strength $\alpha$. If none exists, ULTS computes and caches the prior, which can take a while. However, this only needs to be done once for each combination of those choices; subsequent generation calls decode very quickly. A sketch of overriding these settings follows the quickstart below.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
)
model.eval()

text = "Moose is a"
model_inputs = tokenizer(text, return_tensors="pt")

-output = model.generate(
-    **model_inputs,
-    num_beams=5,
-    max_new_tokens=40,
-)
-generated_sequence = output.sequences

+import ults
+output = ults.generate(
+    model=model,
+    model_inputs=model_inputs,
+    max_tokens=40,
+)
+generated_sequence = output.sequence

generated_text = tokenizer.decode(generated_sequence[0])
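
If you want to control where the precomputed prior is cached or which concentration $\alpha$ is used, a call along the following lines should work. This is only a sketch: the keyword names prior_dir and prior_dirichlet_alpha and the example value of $\alpha$ are assumptions inferred from the note above, so check examples/generate.py for the exact argument names.

# Sketch only: `prior_dir` and `prior_dirichlet_alpha` are assumed keyword
# names (see the note above); consult examples/generate.py for the real API.
output = ults.generate(
    model=model,
    model_inputs=model_inputs,
    max_tokens=40,
    prior_dir="./my_ults_priors",   # where precomputed priors are cached
    prior_dirichlet_alpha=1e-4,     # Dirichlet concentration strength (example value)
)
generated_text = tokenizer.decode(output.sequence[0])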

Using the Empirical Prior

On top of the default Dirichlet priors (agnostic to the LLM), ULTS can also leverage empirical priors, specific to the LLM at hand. Example precomputed empirical priors, compatible with Llama-2-7b, Mistral-7B-v0.1, and Gemma-7b, are available in examples/ults_priors.

  1. First, gather samples of the LLM's softmax outputs from different time steps; here we use greedy decoding. See examples/sample_llm_outputs.py for a complete example.
import glob
import os
from typing import List

import torch

RESULT_DIR = f"./ults_priors/llm_output_samples/{DATASET_NAME}_{LLM_NAME}"
os.makedirs(RESULT_DIR, exist_ok=True)

# Samples of contexts from your dataset
contexts: List[str]

for idx, context in enumerate(contexts):
    input_ids = tokenizer(context, return_tensors="pt")["input_ids"]

    # `n_tokens` is the max. depth of the tree that you want to optimize on,
    # i.e., the max. number of tokens you want to generate with ULTS
    for d in range(n_tokens):
        with torch.no_grad():
            outputs = torch.softmax(model(input_ids).logits, dim=-1)

        # Save the last softmax output (this is our empirical sample for depth `d`)
        outputs = outputs[0, -1, :]
        torch.save(outputs, f"{RESULT_DIR}/sample_index{idx}_depth{d}.pt")

        # Continue greedy generation
        index = torch.argmax(outputs)
        input_ids = torch.cat([input_ids, index.expand(1, 1)], dim=1)

# Stack them together into a (n_samples * n_tokens, vocab_size) tensor
sample_files = glob.glob(f"{RESULT_DIR}/sample_*.pt")
samples = [torch.load(f) for f in sample_files]
torch.save(torch.vstack(samples), f"{RESULT_DIR}/all_samples.pt")
  2. Then, specify the empirical prior when calling ULTS; everything else stays the same as in examples/generate.py. A complete sketch follows after the snippet below.
output = ults.generate(
    ...
+   prior_kind="empirical",
+   prior_empirical_llm_samples=torch.load(f'{RESULT_DIR}/all_samples.pt'),
    ...
)
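
Putting both steps together, here is a minimal end-to-end sketch. It assumes the same model, model_inputs, and tokenizer as in the quickstart above, and that RESULT_DIR points to the directory populated in step 1.

import torch
import ults

# Load the stacked (n_samples * n_tokens, vocab_size) samples from step 1
llm_samples = torch.load(f"{RESULT_DIR}/all_samples.pt")

output = ults.generate(
    model=model,
    model_inputs=model_inputs,
    max_tokens=40,
    prior_kind="empirical",
    prior_empirical_llm_samples=llm_samples,
)
generated_text = tokenizer.decode(output.sequence[0])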

Caveats

  1. ULTS currently does not support batch generation.
  2. Huggingface's beam search optimizes the average (length-normalized) log-likelihood, which effectively penalizes shorter sequences. ULTS instead optimizes the total log-likelihood, so its behavior differs from Huggingface's; a sketch contrasting the two objectives follows below. There is a plan to support length normalization in ULTS, see #36.
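
To make the second caveat concrete, the following standalone sketch (not part of ULTS) scores a sequence under both objectives, reusing the model and generated_sequence from the quickstart. The helper name is hypothetical.

import torch

def total_and_average_logprob(model, input_ids):
    """Hypothetical helper: score a token-id sequence under the two objectives above."""
    with torch.no_grad():
        logits = model(input_ids).logits  # (1, seq_len, vocab_size)
    # Log-probability of each actually chosen next token
    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    targets = input_ids[:, 1:].unsqueeze(-1)
    token_log_probs = log_probs.gather(-1, targets).squeeze(-1)  # (1, seq_len - 1)
    total = token_log_probs.sum().item()      # what ULTS optimizes
    average = token_log_probs.mean().item()   # what Huggingface's beam search optimizes
    return total, average

# The ranking of two candidates of different lengths can flip between the two
# objectives, because the total log-likelihood decreases with every extra token.
total, average = total_and_average_logprob(model, generated_sequence)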

Development

This repo uses pdm as the dependency manager and the build system.

  1. Install pdm, see: https://pdm-project.org/en/latest/
  2. Run pdm install

All dependencies will then be installed by pdm. Moreover, the current repo will be installed in editable mode.

[!IMPORTANT] Before pushing your code, ensure that all tests pass and all linting and formatting issues are resolved.

  1. Run pytest and make sure all tests pass.
  2. Run make ruff and ensure that:
    1. All code is formatted correctly.
    2. There are no linting issues.

