
A Gradient-based Prompt Optimizer for Text Generation


GReaTer


GReaTer: Gradients over Reasoning Makes Smaller Language Models Strong Prompt Optimizers

Sarkar Snigdha Sarathi Das, Ryo Kamoi, Bo Pang, Yusen Zhang, Caiming Xiong, Rui Zhang

Overview


GReaTer has three key components:

  • The language model f_LLM generates candidate tokens by conditioning on input samples.
  • f_LLM uses the task input and the current prompt to generate reasoning and extract the final answer logits.
  • The logits are used to calculate the loss, and the gradient of that loss, propagated through the generated reasoning, is computed with respect to the candidate tokens. These gradients determine which candidate token is selected to update the current position of the prompt.
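
The selection step in the last component can be illustrated with a toy example. This is a minimal sketch, not the package's implementation: it assumes each candidate token already has a single scalar gradient value and that the candidate with the most negative gradient is chosen, since that token is expected (to first order) to reduce the loss the most. The function name `select_token` and the numbers are hypothetical.

```python
def select_token(candidates, gradients):
    """Pick the candidate whose gradient is most negative.

    Assumption for this sketch: a more negative gradient means that
    choosing the token is expected to lower the loss more.
    """
    if len(candidates) != len(gradients):
        raise ValueError("one gradient per candidate is required")
    # Index of the smallest (most negative) gradient value.
    best_index = min(range(len(candidates)), key=lambda i: gradients[i])
    return candidates[best_index]


# Hypothetical candidates for one prompt position, with made-up gradients.
candidates = ["logical", "careful", "simple"]
gradients = [0.12, -0.35, 0.04]
chosen = select_token(candidates, gradients)
print(chosen)
```

With these made-up values the sketch selects "careful", the only candidate with a negative gradient.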

Installation

Installation via pip is not implemented yet, so do not run the command below. For now, refer to example.py to learn how to use the package.

pip install GReaTer

Usage

  1. create an input dataset for optimization

    from GReaTer import GreaterDataSet
    
    # There are two ways to create a dataset
    # 1. Load a pre-defined dataset from a JSONL file
    dataset1 = GreaterDataSet(data_path="./data/boolean_expressions.jsonl")
    
    # 2. Create a dataset from scratch
    # custom_inputs is a list of dictionaries; each dictionary is supposed to contain a question, a prompt, and an answer
    dataset2 = GreaterDataSet(custom_inputs=[
        {
            "question": "((-1 + 2 + 9 * 5) - (-2 + -4 + -4 * -7)) =", 
            "prompt": "Use logical reasoning and think step by step.", 
            "answer": "24"
        },
        {
            "question": "((-9 * -5 - 6 + -2) - (-8 - -6 * -3 * 1)) =",
            "prompt": "Use logical reasoning and think step by step.",
            "answer": "63"
        },
        {
            "question": "((3 * -3 * 6 + -5) - (-2 + -7 - 7 - -7)) =",
            "prompt": "Use logical reasoning and think step by step.",
            "answer": "-50"
        }
    ])
    
  2. define the optimize config; for details, please refer to our documentation page

    # F.cross_entropy below comes from PyTorch's functional API
    import torch.nn.functional as F
    
    # optimizer config
    optimize_config = {
        "intersect_q": 5,
        "candidates_topk": 10,
        "loss_function": F.cross_entropy,
        "perplexity_loss": True,
        "perplexity_lambda": 0.2,
        "generate_config": {
            "temperature": 0.6,
            "max_new_tokens": 1024
        }
    }
    
  3. load the model and tokenizer to initialize the optimizer

    # So far we support Llama-3 and Gemma-2 family models
    # You can use transformers to load the model and tokenizer;
    # the Auto* classes pick the matching model and tokenizer class for the checkpoint
    from transformers import AutoModelForCausalLM, AutoTokenizer
    # GreaterOptimizer is assumed to be importable from the same package as GreaterDataSet
    from GReaTer import GreaterOptimizer
    
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    
    # initialize the optimizer with the model, tokenizer, and optimize config
    optimizer = GreaterOptimizer(
        model=model, tokenizer=tokenizer, optimize_config=optimize_config
    )
    
  4. optimize the prompt

    # optimize the prompt; the optimizer returns a dict containing the original question and a list of optimized prompts
    outputs = optimizer.optimize(
        inputs=dataset1, 
        # this extractor will be applied to all prompts inside the dataset
        p_extractor="\nNext, only give the exact answer, no extra words or any punctuation:",
        rounds=80
    )
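
For step 1, a JSONL file for `data_path` can be prepared with the standard library alone. The sketch below assumes that each line of the file is one JSON object with the same "question", "prompt", and "answer" keys used for `custom_inputs`; the README does not spell out the file schema, so treat that as an assumption. The helper names `write_jsonl` and `read_jsonl` and the file name are hypothetical.

```python
import json
import os
import tempfile

# Keys assumed to be required, mirroring the custom_inputs example.
REQUIRED_KEYS = {"question", "prompt", "answer"}


def write_jsonl(records, path):
    """Write one JSON object per line after checking the assumed keys."""
    with open(path, "w", encoding="utf-8") as handle:
        for record in records:
            missing = REQUIRED_KEYS - set(record)
            if missing:
                raise ValueError(f"record is missing keys: {sorted(missing)}")
            handle.write(json.dumps(record) + "\n")


def read_jsonl(path):
    """Read the file back into a list of dictionaries, skipping blank lines."""
    with open(path, encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


records = [
    {
        "question": "((-1 + 2 + 9 * 5) - (-2 + -4 + -4 * -7)) =",
        "prompt": "Use logical reasoning and think step by step.",
        "answer": "24",
    },
]
path = os.path.join(tempfile.mkdtemp(), "custom_dataset.jsonl")
write_jsonl(records, path)
loaded = read_jsonl(path)
print(len(loaded), loaded[0]["answer"])
```

Checking the keys before writing surfaces a malformed record early, rather than part-way through an optimization run.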
    
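
The `candidates_topk` and `intersect_q` settings in step 2 are not explained in this README. One plausible reading, offered here as an assumption and not as the package's documented behavior, is that each of q input samples proposes its top-k candidate tokens and only tokens proposed for every sample are kept. A minimal sketch of that idea, with hypothetical function names and made-up scores:

```python
def topk_tokens(token_scores, k):
    """Return the k highest-scoring tokens for one input sample."""
    ranked = sorted(token_scores, key=token_scores.get, reverse=True)
    return set(ranked[:k])


def intersect_candidates(per_sample_scores, k):
    """Keep only tokens that appear in the top-k set of every sample."""
    candidate_sets = [topk_tokens(scores, k) for scores in per_sample_scores]
    if not candidate_sets:
        return set()
    return set.intersection(*candidate_sets)


# Made-up next-token scores for q = 3 hypothetical input samples.
per_sample_scores = [
    {"step": 0.5, "logic": 0.3, "math": 0.1, "fast": 0.05},
    {"step": 0.4, "math": 0.35, "logic": 0.2, "fast": 0.01},
    {"logic": 0.45, "step": 0.3, "fast": 0.2, "math": 0.02},
]
shared = intersect_candidates(per_sample_scores, k=2)
print(sorted(shared))
```

Under this reading, a larger k or a smaller q makes the shared candidate set larger, and the reverse makes it stricter.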

Citation

@article{das2024greater,
  title={GReaTer: Gradients over Reasoning Makes Smaller Language Models Strong Prompt Optimizers},
  author={Das, Sarkar Snigdha Sarathi and Kamoi, Ryo and Pang, Bo and Zhang, Yusen and Xiong, Caiming and Zhang, Rui},
  journal={arXiv preprint arXiv:2412.09722},
  year={2024}
}
