A simulation training framework

Sparse Rewards Can Self-Train Dialogue Agents

Barrett Martin Lattimer, Varun Gangal, Ryan McDonald, Yi Yang

contact: blattimer@asapp.com

paper: https://arxiv.org/abs/2409.04617

This repo runs JOSH on the ToolWOZ and τ-bench datasets. It also includes tools for logging training and preference-annotated episodes from user-simulator interactions, and for LoRA-driven preference tuning of small LLMs on that preference-annotated experience.

Setup

  1. Run one of the following in a new environment. To install from PyPI:
pip install josh-train

or, for an editable install from a clone of this repo:

pip install -e .
  2. Unzip the dataset.zip file in the data folder

  3. Set up your OpenAI credentials

export OPENAI_API_KEY= # api_key
export OPENAI_ORGANIZATION= # api_org

If you are running Llama or another local model, also set HF_TOKEN in the same way. Wherever HF_KEY appears, replace it with your Hugging Face token.
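Before launching a run, it can help to confirm that the credentials are visible to Python. The sketch below is not part of josh-train: `missing_credentials` is a hypothetical helper that only inspects the environment variable names mentioned in the setup steps above.

```python
import os

# Variable names taken from the setup steps above.
# HF_TOKEN only matters when running Llama or another local model.
SETUP_VARS = ("OPENAI_API_KEY", "OPENAI_ORGANIZATION")
LOCAL_MODEL_VARS = ("HF_TOKEN",)


def missing_credentials(env=None, include_local=False):
    """Return the names of expected variables that are unset or empty."""
    env = os.environ if env is None else env
    names = SETUP_VARS + (LOCAL_MODEL_VARS if include_local else ())
    return [name for name in names if not env.get(name)]


if __name__ == "__main__":
    missing = missing_credentials()
    if missing:
        print("Missing environment variables:", ", ".join(missing))
    else:
        print("All OpenAI credentials are set.")
```

Pass `include_local=True` to also check HF_TOKEN when you plan to run a local model.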

Running ToolWOZ

Run ToolWOZ without JOSH using the following command:

python josh_train/main.py

Increase --max_concurrency as far as your API rate limits allow.

JOSH on ToolWOZ

Enable JOSH on ToolWOZ by adding the --josh flag. To print progress updates while JOSH runs, also add --josh_debug.

A more involved JOSH command looks like this:

python josh_train/main.py --josh --josh_debug --max_concurrency 20 --seed 20 --task_split train --temperature 1.0 --agent_strategy react --user_mode goal --model gpt-4o-mini --end_index 10 --beam_size 8
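When sweeping over seeds or beam sizes, it may be convenient to assemble this command in Python rather than by hand. The following is a minimal sketch using only the flags that appear in the command above; `build_josh_command` is a hypothetical helper and is not provided by josh-train.

```python
import shlex
import sys


def build_josh_command(options, flags=("josh", "josh_debug"), script="josh_train/main.py"):
    """Assemble an argv list for main.py from boolean flags and key/value options."""
    argv = [sys.executable, script]
    # Boolean switches are passed without a value.
    argv += [f"--{flag}" for flag in flags]
    # Every other option is passed as "--name value".
    for name, value in options.items():
        argv += [f"--{name}", str(value)]
    return argv


# The same settings as the example command above.
options = {
    "max_concurrency": 20,
    "seed": 20,
    "task_split": "train",
    "temperature": 1.0,
    "agent_strategy": "react",
    "user_mode": "goal",
    "model": "gpt-4o-mini",
    "end_index": 10,
    "beam_size": 8,
}

command = build_josh_command(options)
# Show the arguments after the interpreter path for inspection.
print(shlex.join(command[1:]))
```

To launch the run, pass `command` to `subprocess.run(command, check=True)` from the repository root.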

Running τ-bench

This repo includes a clone of τ-bench with two run files: one for standard τ-bench testing and one for JOSH rollouts on τ-bench.

To run τ-bench without JOSH:

python tau-bench-eval/run.py

JOSH on τ-bench

To run JOSH on τ-bench:

python tau-bench-eval/run.py --josh --debug

Using JOSH

The JOSH class in this repo is designed to be flexible and to work with a wide variety of user/agent interactions. To use JOSH yourself, start from the following snippet:

from josh_train.josh import JOSH, BaseJOSHAgent, BaseRewards, BaseJOSHUser


def add_error_message(agent):
    # Record on the agent's message history that it ran out of retries.
    agent.messages.append({'role': 'assistant', 'content': 'Error: Agent ran out of retries.'})
    return agent


def step_agent(agent: BaseJOSHAgent, **kwargs):
    # Advance the agent by one step and return it with its pass_to_customer result.
    pass_to_customer = agent.step(**kwargs)
    return agent, pass_to_customer


def step_user(user: BaseJOSHUser, agent: BaseJOSHAgent):
    # Advance the user simulator by one step and return the agent with end_conversation.
    agent, end_conversation = user.step(agent)
    return agent, end_conversation


josh = JOSH(
    rewards=BaseRewards(['say hello', 'say hello', 'say hello']),
    agent_step=step_agent,
    user_step=step_user,
    add_error_message=add_error_message,
    root_agent=BaseJOSHAgent(),
    user=BaseJOSHUser(),
    debug=True,
)

# Step JOSH at most 10 times, stopping early once it reports all_done.
for _ in range(10):
    max_reward, all_done = josh.step()
    if all_done:
        break

print(max_reward)
print(josh.training_examples)

All of these classes can be subclassed and extended for further use.
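To illustrate the call shapes that the step functions in the snippet rely on, here is a toy sketch that runs without josh-train installed. `ScriptedAgent`, `ScriptedUser`, and `run_conversation` are hypothetical stand-ins written for this example; they only mirror `agent.step(**kwargs)`, `user.step(agent)`, and the `agent.messages` list seen above. The turn order (user first) is an assumption, and whether real subclasses of BaseJOSHAgent or BaseJOSHUser need further attributes should be checked in `josh_train.josh`.

```python
class ScriptedAgent:
    """Toy stand-in exposing the same step() shape used by step_agent."""

    def __init__(self):
        self.messages = []

    def step(self, **kwargs):
        # Reply to the latest user message; a real agent would call an LLM here.
        last_user = next(
            (m["content"] for m in reversed(self.messages) if m["role"] == "user"), ""
        )
        self.messages.append({"role": "assistant", "content": f"hello, you said: {last_user}"})
        return True  # assumed meaning: hand the turn back to the user


class ScriptedUser:
    """Toy stand-in exposing the same step(agent) shape used by step_user."""

    def __init__(self, lines):
        self.lines = list(lines)

    def step(self, agent):
        # Speak the next scripted line, or end the conversation when none remain.
        if not self.lines:
            return agent, True
        agent.messages.append({"role": "user", "content": self.lines.pop(0)})
        return agent, False


def run_conversation(agent, user, max_turns=10):
    """Alternate user and agent turns until the user ends the conversation."""
    for _ in range(max_turns):
        agent, end_conversation = user.step(agent)
        if end_conversation:
            break
        agent.step()
    return agent.messages


messages = run_conversation(ScriptedAgent(), ScriptedUser(["hi", "bye"]))
for message in messages:
    print(f"{message['role']}: {message['content']}")
```

A custom agent or user built on the library's base classes would replace the scripted logic with model calls while keeping the same return shapes.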

MT-Bench

To evaluate on MT-Bench later, first unzip the evaluation archive:

unzip mtbencheval.zip

Citation

Please cite our paper if you enjoyed this work!

@article{lattimer2024sparse,
  title={Sparse Rewards Can Self-Train Dialogue Agents},
  author={Lattimer, Barrett Martin and Gangal, Varun and McDonald, Ryan and Yang, Yi},
  journal={arXiv preprint arXiv:2409.04617},
  year={2024}
}
