
A simulation training framework


Sparse Rewards Can Self-Train Dialogue Agents

Barrett Martin Lattimer, Varun Gangal, Ryan McDonald, Yi Yang

contact: blattimer@asapp.com

paper: https://arxiv.org/abs/2409.04617

This repo runs JOSH on the ToolWOZ and τ-bench datasets. It also contains tools for logging training and preference-annotated episodes from user-simulator interactions, and for LoRA-driven preference tuning of small LLMs on that preference-annotated experience.

Setup

  1. Run the following in a new environment:
pip install josh-train

or, to install from a local clone of this repo:

pip install -e .
  2. Unzip the dataset.zip file in the data folder.

  3. Set up your OpenAI credentials:

export OPENAI_API_KEY= # api_key
export OPENAI_ORGANIZATION= # api_org

If you're running Llama or another local model, you will also need to set HF_TOKEN in the same way. Wherever you see HF_KEY, replace it with your Hugging Face token.
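Before launching a run, a small check like the one below can confirm that the credentials are visible to Python. This is a minimal sketch and is not part of josh-train: the helper name `missing_credentials` is hypothetical, and it only checks the variable names given in the setup steps (OPENAI_API_KEY and OPENAI_ORGANIZATION, plus HF_TOKEN when a local model is used).

```python
import os

# Variable names taken from the setup instructions. HF_TOKEN is only
# needed when running Llama or another local Hugging Face model.
REQUIRED = ["OPENAI_API_KEY", "OPENAI_ORGANIZATION"]
LOCAL_MODEL = ["HF_TOKEN"]


def missing_credentials(env=None, local_model=False):
    """Return the names of required variables that are unset or empty.

    Hypothetical helper for illustration; josh-train does not ship it.
    """
    env = os.environ if env is None else env
    names = REQUIRED + (LOCAL_MODEL if local_model else [])
    return [name for name in names if not env.get(name)]


if __name__ == "__main__":
    missing = missing_credentials(local_model=False)
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("Credentials look set.")
```

Pass `local_model=True` when you plan to run a Hugging Face model so that HF_TOKEN is checked as well.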

Running ToolWOZ

To run ToolWOZ without JOSH:

python josh_train/main.py

Increase --max_concurrency as far as your API rate limits allow.
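The --max_concurrency flag presumably caps how many requests are in flight at once. We have not checked how main.py implements it; the sketch below only illustrates the general idea of bounded concurrency with `asyncio.Semaphore`, using a hypothetical `fake_api_call` coroutine in place of a real model request.

```python
import asyncio


async def fake_api_call(i, state):
    """Stand-in for a model request (hypothetical); records peak parallelism."""
    state["active"] += 1
    state["peak"] = max(state["peak"], state["active"])
    await asyncio.sleep(0.01)  # simulate network latency
    state["active"] -= 1
    return i


async def run_all(n_tasks, max_concurrency):
    """Run n_tasks calls with at most max_concurrency of them in flight."""
    sem = asyncio.Semaphore(max_concurrency)
    state = {"active": 0, "peak": 0}

    async def limited(i):
        async with sem:  # waits while max_concurrency calls are already active
            return await fake_api_call(i, state)

    # gather returns results in submission order, regardless of finish order
    results = await asyncio.gather(*(limited(i) for i in range(n_tasks)))
    return results, state["peak"]


if __name__ == "__main__":
    results, peak = asyncio.run(run_all(n_tasks=50, max_concurrency=20))
    print(len(results), "calls, peak concurrency", peak)
```

Raising the limit increases throughput only until the provider starts rejecting requests for exceeding your rate limit, which is why the value should track your API quota.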

JOSH on ToolWOZ

Enable JOSH on ToolWOZ by adding the --josh flag. Add --josh_debug as well to make JOSH print progress updates while it runs.

A more involved JOSH command looks like this:

python josh_train/main.py --josh --josh_debug --max_concurrency 20 --seed 20 --task_split train --temperature 1.0 --agent_strategy react --user_mode goal --model gpt-4o-mini --end_index 10 --beam_size 8
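To repeat the command above over several seeds, one option is to build the argument list in Python and hand it to `subprocess.run`. This is a hypothetical convenience sketch, not something the repo provides: `build_command` and `run_seeds` are illustrative names, and the script uses only the flags and values from the example command, leaving their meaning to main.py.

```python
import subprocess
import sys

# Flags and values copied from the example JOSH command above.
BASE_ARGS = {
    "--max_concurrency": "20",
    "--task_split": "train",
    "--temperature": "1.0",
    "--agent_strategy": "react",
    "--user_mode": "goal",
    "--model": "gpt-4o-mini",
    "--end_index": "10",
    "--beam_size": "8",
}


def build_command(seed, debug=True):
    """Return the argv list for one JOSH run on ToolWOZ with the given seed."""
    cmd = [sys.executable, "josh_train/main.py", "--josh"]
    if debug:
        cmd.append("--josh_debug")
    for flag, value in BASE_ARGS.items():
        cmd += [flag, value]
    cmd += ["--seed", str(seed)]
    return cmd


def run_seeds(seeds, dry_run=True):
    """Print each command; launch it only when dry_run is False."""
    for seed in seeds:
        cmd = build_command(seed)
        print(" ".join(cmd))
        if not dry_run:
            # Requires a checkout of the repo and the credentials set above.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_seeds([20, 21, 22], dry_run=True)
```

Passing an argument list (rather than a single shell string) avoids shell quoting issues, and the `dry_run` default lets you inspect the commands before spending API credits.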

Running τ-bench

We have added a clone of τ-bench to this repo with two run files: one for standard τ-bench testing and another for JOSH rollouts on τ-bench.

To run τ-bench normally:

python tau-bench-eval/run.py

JOSH on τ-bench

To run JOSH on τ-bench:

python tau-bench-eval/run.py --josh --debug

Using JOSH

This repo provides a JOSH class designed to be flexible enough for a wide variety of user/agent interactions. To use JOSH yourself, start from the following snippet:

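from josh_train.josh import JOSH, BaseJOSHAgent, BaseRewards, BaseJOSHUser


# Fallback hook: appends an error turn to the agent's history
# (e.g. when the agent runs out of retries) and returns the agent.
def add_error_message(agent):
    agent.messages.append({'role': 'assistant', 'content': 'Error: Agent ran out of retries.'})
    return agent


# Advance the agent by one step. As its name suggests, pass_to_customer
# flags whether the turn now goes to the simulated user.
def step_agent(agent: BaseJOSHAgent, **kwargs):
    pass_to_customer = agent.step(**kwargs)
    return agent, pass_to_customer


# Let the simulated user respond to the agent; end_conversation
# signals that the dialogue is over.
def step_user(user: BaseJOSHUser, agent: BaseJOSHAgent):
    agent, end_conversation = user.step(agent)
    return agent, end_conversation


josh = JOSH(
    rewards=BaseRewards(['say hello', 'say hello', 'say hello']),
    agent_step=step_agent,
    user_step=step_user,
    add_error_message=add_error_message,
    root_agent=BaseJOSHAgent(),
    user=BaseJOSHUser(),
    debug=True,
)

# Run up to 10 JOSH steps, stopping early once JOSH reports all_done.
for _ in range(10):
    max_reward, all_done = josh.step()
    if all_done:
        break

print(max_reward)
print(josh.training_examples)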

All of these classes can be subclassed and extended for other use cases.
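The callbacks in the snippet above follow a simple contract: step_agent returns (agent, pass_to_customer) and step_user returns (agent, end_conversation). The self-contained toy below does not import josh_train; `ToyAgent`, `ToyUser`, and `run_dialogue` are hypothetical stand-ins, and the single-path loop only shows how those return values could drive one conversation. It is not JOSH's own procedure, which (judging by the --beam_size option) explores several branches and scores them with rewards.

```python
from dataclasses import dataclass, field


@dataclass
class ToyAgent:
    """Hypothetical agent that says a scripted line on every step."""
    messages: list = field(default_factory=list)

    def step(self, **kwargs):
        self.messages.append({"role": "assistant", "content": "hello"})
        return True  # True means: hand the turn to the user (pass_to_customer)


@dataclass
class ToyUser:
    """Hypothetical user simulator that ends the chat after max_turns replies."""
    max_turns: int = 2
    turns: int = 0

    def step(self, agent):
        self.turns += 1
        agent.messages.append({"role": "user", "content": f"reply {self.turns}"})
        return agent, self.turns >= self.max_turns


# Callbacks with the same shapes as in the JOSH snippet above.
def step_agent(agent, **kwargs):
    pass_to_customer = agent.step(**kwargs)
    return agent, pass_to_customer


def step_user(user, agent):
    agent, end_conversation = user.step(agent)
    return agent, end_conversation


def run_dialogue(agent, user, max_agent_steps=10):
    """Single-path driver loop for illustration only: no search, no rewards."""
    for _ in range(max_agent_steps):
        agent, pass_to_customer = step_agent(agent)
        if pass_to_customer:
            agent, end_conversation = step_user(user, agent)
            if end_conversation:
                break
    return agent


if __name__ == "__main__":
    final = run_dialogue(ToyAgent(), ToyUser(max_turns=2))
    for m in final.messages:
        print(m["role"], ":", m["content"])
```

Swapping `ToyAgent` and `ToyUser` for your own classes is a cheap way to check that their `step` methods return the shapes the JOSH callbacks expect before wiring them into JOSH itself.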

MT-Bench

If you want to evaluate on MT-Bench later, unzip the evaluation files:

unzip mtbencheval.zip

Citation

Please cite if you enjoyed this work!

@article{lattimer2024sparse,
  title={Sparse Rewards Can Self-Train Dialogue Agents},
  author={Lattimer, Barrett Martin and Gangal, Varun and McDonald, Ryan and Yang, Yi},
  journal={arXiv preprint arXiv:2409.04617},
  year={2024}
}



Download files


Source Distribution

josh_train-0.1.3.tar.gz (25.2 MB view details)

Uploaded Source

Built Distribution


josh_train-0.1.3-py3-none-any.whl (25.1 MB view details)

Uploaded Python 3

File details

Details for the file josh_train-0.1.3.tar.gz.

File metadata

  • Download URL: josh_train-0.1.3.tar.gz
  • Upload date:
  • Size: 25.2 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.6

File hashes

Hashes for josh_train-0.1.3.tar.gz
Algorithm Hash digest
SHA256 980324de5f8df0fb1a8c9df58bce01e5b2f0fb5f3f25e7f583c5091884c34769
MD5 00fbf2bf4b9cfc865f0b48fc45db06d5
BLAKE2b-256 203adf55a1dd40c91c8f441a898adce8c41e2a29295dcdeecdc1c6cc7cbda288


File details

Details for the file josh_train-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: josh_train-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 25.1 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.6

File hashes

Hashes for josh_train-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 4cd05b7beeb48aebf3b3e4ab48ad9f6680a71287619de1f36e7b8757dfe59d7a
MD5 39c38611531eeb39ea4f6faa25157d0e
BLAKE2b-256 b7b27b207414373a33bc0d78a852d191e238f4b9c05703e7cccc6c9e948e6d02

