
A simulation training framework


Sparse Rewards Can Self-Train Dialogue Agents

Barrett Martin Lattimer, Varun Gangal, Ryan McDonald, Yi Yang

contact: blattimer@asapp.com

paper: https://arxiv.org/abs/2409.04617

This repo runs JOSH along with the ToolWOZ and τ-bench datasets. It also contains tools for logging training and preference-annotated episodes from user-simulator interactions, and for LoRA-driven preference tuning of small LLMs on that preference-annotated experience.

Setup

  1. Install the package in a new environment
pip install josh-train

or, from a local clone of this repo,

pip install -e .

  2. Unzip the dataset.zip file in the data folder

  3. Set up your OpenAI credentials

export OPENAI_API_KEY= # api_key
export OPENAI_ORGANIZATION= # api_org

If you're running Llama or another local model, you will also need to set HF_TOKEN in the same way. Wherever you see HF_KEY, replace it with your Hugging Face token.
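Before launching a long run, it can help to confirm that these credentials are actually visible to Python. The following is a minimal sketch and not part of josh-train: the helper name `missing_credentials` is hypothetical, and it only checks the environment variable names mentioned above.

```python
import os

# Variable names taken from the setup steps above.
OPENAI_VARS = ("OPENAI_API_KEY", "OPENAI_ORGANIZATION")
HF_VARS = ("HF_TOKEN",)


def missing_credentials(env=None, local_model=False):
    """Return the names of required variables that are unset or empty.

    Hypothetical helper for illustration; it is not part of josh-train.
    """
    env = os.environ if env is None else env
    # HF_TOKEN is only required when running Llama or another local model.
    required = OPENAI_VARS + (HF_VARS if local_model else ())
    return [name for name in required if not env.get(name)]


if __name__ == "__main__":
    missing = missing_credentials(local_model=False)
    if missing:
        print("Missing credentials:", ", ".join(missing))
    else:
        print("All credentials are set.")
```

Pass `local_model=True` to include HF_TOKEN in the check.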

Running ToolWOZ

To run ToolWOZ normally, without JOSH:

python josh_train/main.py

Increase --max_concurrency as far as your API rate limits allow.

JOSH on ToolWOZ

Enable JOSH on ToolWOZ by adding the --josh flag. To have JOSH print progress updates as it runs, also add --josh_debug.

A more involved JOSH command looks like this:

python josh_train/main.py --josh --josh_debug --max_concurrency 20 --seed 20 --task_split train --temperature 1.0 --agent_strategy react --user_mode goal --model gpt-4o-mini --end_index 10 --beam_size 8
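When sweeping over seeds or beam sizes, it can be convenient to assemble this command programmatically. The sketch below is a convenience under stated assumptions and not part of josh-train: `build_josh_command` is a hypothetical helper, and it uses only the flags that appear in the command above.

```python
import shlex
import sys


def build_josh_command(options, flags=("--josh", "--josh_debug")):
    """Build an argv list for josh_train/main.py.

    Hypothetical helper; `options` maps flag names (without dashes)
    to values, and `flags` are boolean switches passed through as-is.
    """
    argv = [sys.executable, "josh_train/main.py", *flags]
    for name, value in options.items():
        argv += [f"--{name}", str(value)]
    return argv


# Values copied from the example command above.
options = {
    "max_concurrency": 20,
    "seed": 20,
    "task_split": "train",
    "temperature": 1.0,
    "agent_strategy": "react",
    "user_mode": "goal",
    "model": "gpt-4o-mini",
    "end_index": 10,
    "beam_size": 8,
}

argv = build_josh_command(options)
print(shlex.join(argv[1:]))  # the arguments after the interpreter path
```

To launch the run, pass `argv` to `subprocess.run(argv, check=True)` from the repository root.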

Running τ-bench

We have added a clone of τ-bench to this repo with two run files: one for normal τ-bench testing and another for JOSH rollouts on τ-bench.

To run τ-bench normally:

python tau-bench-eval/run.py

JOSH on τ-bench

To run JOSH on τ-bench:

python tau-bench-eval/run.py --josh --debug

Using JOSH

This repo provides a JOSH class that is designed to be flexible and to work with a wide variety of user/agent interactions. To use JOSH yourself, you can start with the following code snippet:

from josh_train.josh import JOSH, BaseJOSHAgent, BaseRewards, BaseJOSHUser


def add_error_message(agent):
    # Record in the conversation that the agent ran out of retries.
    agent.messages.append({'role': 'assistant', 'content': 'Error: Agent ran out of retries.'})
    return agent


def step_agent(agent: BaseJOSHAgent, **kwargs):
    # Advance the agent by one step and return it with the step's result.
    pass_to_customer = agent.step(**kwargs)
    return agent, pass_to_customer


def step_user(user: BaseJOSHUser, agent: BaseJOSHAgent):
    # Let the user respond; the user reports whether the conversation is over.
    agent, end_conversation = user.step(agent)
    return agent, end_conversation


josh = JOSH(
    rewards=BaseRewards(['say hello', 'say hello', 'say hello']),
    agent_step=step_agent,
    user_step=step_user,
    add_error_message=add_error_message,
    root_agent=BaseJOSHAgent(),
    user=BaseJOSHUser(),
    debug=True,
)

# Step JOSH up to 10 times, stopping early once it reports that it is done.
for _ in range(10):
    max_reward, all_done = josh.step()
    if all_done:
        break

print(max_reward)
print(josh.training_examples)

All of these classes can be subclassed and extended for further use.
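To make the calling convention of the step functions concrete, here is a self-contained sketch with stand-in classes. It does not import josh_train and does not reproduce how JOSH searches or assigns rewards. `ToyAgent`, `ToyUser`, and `run_conversation` are hypothetical names, and the turn order and the meaning of the agent's return value are assumptions. The only thing borrowed from the snippet above is the shape of the two step functions: `agent.step(...)` returns a single value, and `user.step(agent)` returns the agent together with an end-of-conversation flag.

```python
class ToyAgent:
    """Stand-in for an agent; only mimics the step() shape used above."""

    def __init__(self):
        self.messages = []

    def step(self, **kwargs):
        # Reply to the most recent message, if there is one.
        last = self.messages[-1]["content"] if self.messages else ""
        self.messages.append({"role": "assistant", "content": f"Agent heard: {last}"})
        # Assumption: a truthy value means "hand the turn to the user".
        return True


class ToyUser:
    """Stand-in for a user simulator that reads from a fixed script."""

    def __init__(self, script):
        self.script = list(script)

    def step(self, agent):
        if not self.script:
            return agent, True  # nothing left to say: end the conversation
        agent.messages.append({"role": "user", "content": self.script.pop(0)})
        return agent, False


def step_agent(agent, **kwargs):
    pass_to_customer = agent.step(**kwargs)
    return agent, pass_to_customer


def step_user(user, agent):
    agent, end_conversation = user.step(agent)
    return agent, end_conversation


def run_conversation(agent, user, max_turns=10):
    """Alternate user and agent turns until the user ends the conversation."""
    for _ in range(max_turns):
        agent, end_conversation = step_user(user, agent)
        if end_conversation:
            break
        agent, _ = step_agent(agent)
    return agent.messages


messages = run_conversation(ToyAgent(), ToyUser(["hello", "goodbye"]))
for message in messages:
    print(message["role"], ":", message["content"])
```

A real agent or user would subclass BaseJOSHAgent or BaseJOSHUser instead; the toy classes only show what the step functions are expected to pass around.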

MT-Bench

If you want to evaluate on MT-Bench later, unzip the evaluation archive:

unzip mtbencheval.zip

Citation

Please cite if you enjoyed this work!

@article{lattimer2024sparse,
  title={Sparse Rewards Can Self-Train Dialogue Agents},
  author={Lattimer, Barrett Martin and Gangal, Varun and McDonald, Ryan and Yang, Yi},
  journal={arXiv preprint arXiv:2409.04617},
  year={2024}
}
