
A Gradient-based Prompt Optimizer for Text Generation

Project description

GReaTer


GReaTer: Gradients over Reasoning Makes Smaller Language Models Strong Prompt Optimizers

Sarkar Snigdha Sarathi Das, Ryo Kamoi, Bo Pang, Yusen Zhang, Caiming Xiong, Rui Zhang

Overview


GReaTer has three key components:

  • The language model f_LLM generates candidate tokens by conditioning on input samples.
  • f_LLM uses the task input and the current prompt to generate reasoning and extract the logits of the final answer.
  • The logits are used to calculate a loss, whose gradient is computed over the generated reasoning with respect to the candidate tokens. These gradients determine which candidate token replaces the token at the current position of the prompt.
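The candidate-selection step in the last bullet can be sketched in plain Python. This is a toy illustration under two assumptions, not the package's implementation. First, candidates are taken to be the tokens that appear in the top-k proposals of at least `min_votes` input samples (our guess at what the `intersect_q` setting under Usage controls). Second, the candidate is chosen with the first-order rule common to gradient-based token search: the gradient of the loss with respect to a token's one-hot entry estimates how the loss would change if that token were swapped in. `propose_candidates`, `select_token`, and the token ids are hypothetical.

```python
from collections import Counter


def propose_candidates(per_sample_topk, min_votes):
    """Keep tokens that appear in the top-k lists of at least `min_votes` samples.

    Hypothetical stand-in for a candidate intersection step.
    """
    votes = Counter(tok for topk in per_sample_topk for tok in set(topk))
    return sorted(tok for tok, n in votes.items() if n >= min_votes)


def select_token(candidates, grad_wrt_onehot):
    """Pick the candidate whose one-hot gradient is smallest.

    To first order, swapping in token t changes the loss by roughly
    grad[t] - grad[current], so the most negative gradient promises
    the largest decrease.
    """
    return min(candidates, key=lambda tok: grad_wrt_onehot[tok])


# Toy example: token ids proposed by three input samples for one prompt position.
per_sample_topk = [[3, 7, 9], [7, 9, 2], [9, 7, 5]]
candidates = propose_candidates(per_sample_topk, min_votes=2)  # [7, 9]

# Toy gradient of the loss w.r.t. the one-hot vector at that position.
grad = {2: 0.3, 3: -0.8, 5: 0.1, 7: 0.4, 9: -0.1}
best = select_token(candidates, grad)  # 9, because token 3 is not a candidate
```

Filtering before ranking matters here: token 3 has the most negative gradient overall, but only one sample proposed it, so it never reaches the ranking step.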

Installation

Installation via pip is not available yet, so do not run the command below. For now, refer to example.py to learn how to use GReaTer.

pip install GReaTer

Usage

  1. create an input dataset for optimization

    from GReaTer import GreaterDataSet
    
    # There are two ways to create a dataset
    # 1. Load a pre-defined dataset from a JSONL file
    dataset1 = GreaterDataSet(data_path="./data/boolean_expressions.jsonl")
    
    # 2. Create a dataset from scratch
    # custom_inputs is a list of dictionaries; each dictionary should contain a question, a prompt, and an answer
    dataset2 = GreaterDataSet(custom_inputs=[
        {
            "question": "((-1 + 2 + 9 * 5) - (-2 + -4 + -4 * -7)) =", 
            "prompt": "Use logical reasoning and think step by step.", 
            "answer": "24"
        },
        {
            "question": "((-9 * -5 - 6 + -2) - (-8 - -6 * -3 * 1)) =",
            "prompt": "Use logical reasoning and think step by step.",
            "answer": "63"
        },
        {
            "question": "((3 * -3 * 6 + -5) - (-2 + -7 - 7 - -7)) =",
            "prompt": "Use logical reasoning and think step by step.",
            "answer": "-50"
        }
    ])
    
  2. define the optimizer config; for details, please refer to our documentation page

    import torch.nn.functional as F
    
    # optimizer config
    optimize_config = {
        "intersect_q": 5,
        "candidates_topk": 10,
        "loss_function": F.cross_entropy,
        "perplexity_loss": True,
        "perplexity_lambda": 0.2,
        "generate_config": {
            "temperature": 0.6,
            "max_new_tokens": 1024
        }
    }
    
  3. load the model and tokenizer to initialize the optimizer

    # So far we support Llama-3 and Gemma-2 family models
    # Load the model and tokenizer with transformers; the Auto* classes pick the
    # matching tokenizer (the SentencePiece-based LlamaTokenizer does not fit Llama-3)
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from GReaTer import GreaterOptimizer
    
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    
    # initialize the optimizer with the model, tokenizer, and optimize config
    optimizer = GreaterOptimizer(
        model=model, tokenizer=tokenizer, optimize_config=optimize_config
    )
    
  4. optimize the prompt

    # optimize the prompt; the optimizer returns a dict containing the original question and a list of optimized prompts
    outputs = optimizer.optimize(
        inputs=dataset1, 
        # this answer extractor is applied to every prompt in the dataset
        p_extractor="\nNext, only give the exact answer, no extra words or any punctuation:",
        rounds=80
    )
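Step 1 can also start from your own JSONL file. The sketch below assumes, without confirmation from the package, that each line of such a file is one JSON object with the same `question`, `prompt`, and `answer` keys used by `custom_inputs`. Before writing, it recomputes each arithmetic answer with a small `ast`-based evaluator, since a wrong label would silently corrupt the loss. The helper names and the `my_arithmetic.jsonl` file name are hypothetical.

```python
import ast
import json
import operator

# Binary operators allowed in the arithmetic questions.
_BINOPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul}


def _eval(node):
    """Safely evaluate an integer expression built from +, -, * and unary minus."""
    if isinstance(node, ast.Constant) and isinstance(node.value, int):
        return node.value
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return -_eval(node.operand)
    if isinstance(node, ast.BinOp) and type(node.op) in _BINOPS:
        return _BINOPS[type(node.op)](_eval(node.left), _eval(node.right))
    raise ValueError(f"unsupported expression: {ast.dump(node)}")


def check_answer(question):
    """Strip the trailing '=' and compute the expected answer as a string."""
    expr = question.rstrip().rstrip("=").strip()
    return str(_eval(ast.parse(expr, mode="eval").body))


def write_jsonl(records, path):
    """Write one JSON object per line after verifying every answer."""
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            expected = check_answer(rec["question"])
            if rec["answer"] != expected:
                raise ValueError(f"{rec['question']} should be {expected}, got {rec['answer']}")
            f.write(json.dumps(rec) + "\n")


records = [
    {"question": "((-1 + 2 + 9 * 5) - (-2 + -4 + -4 * -7)) =",
     "prompt": "Use logical reasoning and think step by step.", "answer": "24"},
    {"question": "((-9 * -5 - 6 + -2) - (-8 - -6 * -3 * 1)) =",
     "prompt": "Use logical reasoning and think step by step.", "answer": "63"},
    {"question": "((3 * -3 * 6 + -5) - (-2 + -7 - 7 - -7)) =",
     "prompt": "Use logical reasoning and think step by step.", "answer": "-50"},
]
write_jsonl(records, "my_arithmetic.jsonl")
```

If the package does read this layout, the file could then be loaded with `GreaterDataSet(data_path="my_arithmetic.jsonl")`.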
    
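The `loss_function`, `perplexity_loss`, and `perplexity_lambda` entries in step 2 suggest a task loss plus a weighted fluency penalty on the prompt. The sketch below assumes the total loss is the cross-entropy of the gold answer token under the extracted answer logits, plus `perplexity_lambda` times the mean negative log-likelihood (the log-perplexity) of the prompt tokens. This combination is our reading of the config keys, not code taken from the package. All names are hypothetical, and it uses plain Python instead of `torch` so it runs anywhere.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one target id under raw logits (log-sum-exp for stability)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def prompt_log_perplexity(prompt_token_logprobs):
    """Mean negative log-likelihood of the prompt tokens, i.e. log(perplexity)."""
    return -sum(prompt_token_logprobs) / len(prompt_token_logprobs)


def total_loss(answer_logits, answer_id, prompt_token_logprobs,
               perplexity_loss=True, perplexity_lambda=0.2):
    """Task loss plus an optional, weighted penalty on unnatural prompts."""
    loss = cross_entropy(answer_logits, answer_id)
    if perplexity_loss:
        loss += perplexity_lambda * prompt_log_perplexity(prompt_token_logprobs)
    return loss


# Toy numbers: two answer classes with equal logits and a 2-token prompt.
loss = total_loss([0.0, 0.0], 0, [-1.0, -3.0])  # log(2) + 0.2 * 2.0, about 1.093
```

Under this formulation, a larger `perplexity_lambda` weights prompt fluency more heavily relative to answer correctness.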

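Step 4 runs the search for `rounds` iterations. As a rough mental model, the sketch below assumes each round revisits one prompt position in turn and swaps in the candidate with the smallest gradient. It uses a toy loss that is linear in the one-hot prompt encoding, so the gradient at each position is exact and the loop reaches the minimum once every position has been visited; the real optimizer instead computes gradients over the model's generated reasoning. The position schedule and all names here are hypothetical.

```python
def optimize_prompt(prompt, candidates_fn, grad_fn, rounds):
    """Greedy coordinate search over prompt tokens, one position per round."""
    prompt = list(prompt)
    for r in range(rounds):
        pos = r % len(prompt)                 # cycle through positions
        candidates = candidates_fn(prompt, pos)
        grad = grad_fn(prompt, pos)           # token id -> dLoss/d(one-hot entry)
        prompt[pos] = min(candidates, key=lambda tok: grad[tok])
    return prompt


# Toy linear loss: loss(prompt) = sum(W[pos][prompt[pos]]), so the gradient with
# respect to the one-hot vector at `pos` is simply the row W[pos].
W = [
    [0.5, -1.0, 0.2, 0.0],
    [0.3, 0.1, -0.4, 0.9],
]


def toy_loss(prompt):
    return sum(W[pos][tok] for pos, tok in enumerate(prompt))


def toy_grad(prompt, pos):
    return dict(enumerate(W[pos]))


def all_tokens(prompt, pos):
    # Every vocabulary id is a candidate in this toy setting.
    return range(len(W[pos]))


start = [0, 3]
best = optimize_prompt(start, all_tokens, toy_grad, rounds=4)  # [1, 2]
```

With a real model the loss is not linear in the tokens, so the gradient only ranks candidates approximately and later rounds can keep improving a position that was already visited.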
Citation

@article{das2024greater,
  title={GReaTer: Gradients over Reasoning Makes Smaller Language Models Strong Prompt Optimizers},
  author={Das, Sarkar Snigdha Sarathi and Kamoi, Ryo and Pang, Bo and Zhang, Yusen and Xiong, Caiming and Zhang, Rui},
  journal={arXiv preprint arXiv:2412.09722},
  year={2024}
}



Download files

Download the file for your platform.

Source Distribution

greaterprompt-0.1.1.tar.gz (130.0 kB)

Built Distribution


greaterprompt-0.1.1-py2.py3-none-any.whl (13.0 kB)

File details

Details for the file greaterprompt-0.1.1.tar.gz.

File metadata

  • Download URL: greaterprompt-0.1.1.tar.gz
  • Upload date:
  • Size: 130.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.2

File hashes

Hashes for greaterprompt-0.1.1.tar.gz
Algorithm Hash digest
SHA256 7f4269da3a5d515a0bb8067d79bb3e2078fdc7db655f2079e14bd1cada429254
MD5 22e0a874a300269f29746faa59b60386
BLAKE2b-256 b552cc8342d4aaff600a8aa973940e5698e2ef0ce4f02058edf53df152a7ed8b


File details

Details for the file greaterprompt-0.1.1-py2.py3-none-any.whl.

File metadata

  • Download URL: greaterprompt-0.1.1-py2.py3-none-any.whl
  • Upload date:
  • Size: 13.0 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.2

File hashes

Hashes for greaterprompt-0.1.1-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 1a2916e308ce7b9968a5af32eaadfe8be2bdaed40b5b3c74c39be233de3cb9df
MD5 44813f40c756e101de846519a2049e7d
BLAKE2b-256 6ff7b46d8e4126ee7288697b917bc64de4dfd82e0f9331aa1848f3f1f82b8738

