
A simulation training framework


Sparse Rewards Can Self-Train Dialogue Agents

Barrett Martin Lattimer, Varun Gangal, Ryan McDonald, Yi Yang

contact: blattimer@asapp.com

paper: https://arxiv.org/abs/2409.04617

This repo runs JOSH along with the ToolWOZ and τ-bench datasets. It also contains tools for logging training and preference-annotated episodes from user-simulator interactions, and for LoRA-driven preference tuning of small LLMs on that preference-annotated experience.

Setup

  1. Run the following in a new environment:

pip install -e .

  2. Unzip the dataset.zip file in the data folder.

  3. Set up your OpenAI credentials:

export OPENAI_API_KEY= # api_key
export OPENAI_ORGANIZATION= # api_org

If you're running Llama or another local model, you will also need to set HF_TOKEN in the same way. Wherever you see HF_KEY, replace it with your Hugging Face token.
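
Before launching a long run, it can help to confirm that the credentials are actually visible to Python. The following is a minimal sketch, not part of this repo: the helper name missing_env_vars and its use_local_model parameter are hypothetical, and only the variable names come from the setup steps above.

```python
import os

# Variable names are taken from the setup steps; the helper itself is hypothetical.
OPENAI_VARS = ("OPENAI_API_KEY", "OPENAI_ORGANIZATION")
HF_VARS = ("HF_TOKEN",)


def missing_env_vars(use_local_model=False, environ=None):
    """Return the names of required variables that are unset or empty."""
    env = os.environ if environ is None else environ
    required = OPENAI_VARS + (HF_VARS if use_local_model else ())
    return [name for name in required if not env.get(name)]


if __name__ == "__main__":
    missing = missing_env_vars(use_local_model=False)
    if missing:
        print("Missing environment variables:", ", ".join(missing))
    else:
        print("All required environment variables are set.")
```

Pass use_local_model=True when running Llama or another local model so that HF_TOKEN is checked as well.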

Running ToolWOZ

To run ToolWOZ normally:

python josh_train/main.py

Increase --max_concurrency as far as your API rate limits allow.

JOSH on ToolWOZ

Enable JOSH on ToolWOZ by adding the --josh flag. To have JOSH print progress updates while it runs, also add --josh_debug.

A more involved JOSH command looks like this:

python josh_train/main.py --josh --josh_debug --max_concurrency 20 --seed 20 --task_split train --temperature 1.0 --agent_strategy react --user_mode goal --model gpt-4o-mini --end_index 10 --beam_size 8
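
When running several configurations, the command above can be assembled programmatically. This is a rough sketch rather than a script shipped with the repo: build_josh_command is a hypothetical helper, the flags and default values are copied from the command above, and the alternative beam size of 4 is only an illustrative value.

```python
import shlex


def build_josh_command(beam_size=8, end_index=10, max_concurrency=20):
    """Return the ToolWOZ JOSH command from the README as an argument list."""
    args = ["python", "josh_train/main.py", "--josh", "--josh_debug"]
    # Flag names and defaults mirror the example command; nothing new is introduced.
    options = {
        "--max_concurrency": max_concurrency,
        "--seed": 20,
        "--task_split": "train",
        "--temperature": 1.0,
        "--agent_strategy": "react",
        "--user_mode": "goal",
        "--model": "gpt-4o-mini",
        "--end_index": end_index,
        "--beam_size": beam_size,
    }
    for flag, value in options.items():
        args += [flag, str(value)]
    return args


if __name__ == "__main__":
    # Print one command per beam size (4 is a hypothetical alternative to 8).
    for beam_size in (4, 8):
        print(shlex.join(build_josh_command(beam_size=beam_size)))
```

The returned list can be passed directly to subprocess.run(cmd, check=True) to launch a run from Python instead of printing it.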

Running τ-bench

We have added a clone of τ-bench to this repo with two run files: one for normal τ-bench testing and another for JOSH rollouts on τ-bench.

To run τ-bench normally:

python tau-bench-eval/run.py

JOSH on τ-bench

To run JOSH on τ-bench:

python tau-bench-eval/run.py --josh --debug

Using JOSH

This repo provides a JOSH class designed to be flexible enough for a wide variety of user/agent interactions. To use JOSH yourself, start from the following code snippet:

from josh_train.josh import JOSH, BaseJOSHAgent, BaseRewards, BaseJOSHUser


def add_error_message(agent):
    # Append an error message to the agent's history once it runs out of retries.
    agent.messages.append({'role': 'assistant', 'content': 'Error: Agent ran out of retries.'})
    return agent


def step_agent(agent: BaseJOSHAgent, **kwargs):
    # Take one agent step and return the agent with its pass_to_customer flag.
    pass_to_customer = agent.step(**kwargs)
    return agent, pass_to_customer


def step_user(user: BaseJOSHUser, agent: BaseJOSHAgent):
    # Take one user step and return the agent with the end_conversation flag.
    agent, end_conversation = user.step(agent)
    return agent, end_conversation


josh = JOSH(
    rewards=BaseRewards(['say hello', 'say hello', 'say hello']),
    agent_step=step_agent,
    user_step=step_user,
    add_error_message=add_error_message,
    root_agent=BaseJOSHAgent(),
    user=BaseJOSHUser(),
    debug=True,
)

# Run up to 10 JOSH steps, stopping early once all_done is set.
for _ in range(10):
    max_reward, all_done = josh.step()
    if all_done:
        break

print(max_reward)
print(josh.training_examples)

All classes can be subclassed and extended for further use.
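
To make the calling convention in step_agent and step_user concrete, here is a self-contained toy that does not import josh_train. EchoAgent, ScriptedUser, and run_dialogue are hypothetical stand-ins, not the real BaseJOSHAgent or BaseJOSHUser, and the loop performs no search and computes no rewards. It only assumes what the snippet above shows: the agent keeps a messages list, agent.step(**kwargs) returns a pass_to_customer flag (taken here to mean the agent's turn is over), and user.step(agent) returns the agent plus an end_conversation flag.

```python
class EchoAgent:
    """Stand-in agent (hypothetical): keeps a message list and replies on step()."""

    def __init__(self):
        self.messages = []

    def step(self, **kwargs):
        self.messages.append({"role": "assistant", "content": "hello"})
        return True  # pass_to_customer: assumed to mean the agent's turn is over


class ScriptedUser:
    """Stand-in user simulator (hypothetical) that ends after max_turns turns."""

    def __init__(self, max_turns=3):
        self.turns = 0
        self.max_turns = max_turns

    def step(self, agent):
        self.turns += 1
        agent.messages.append({"role": "user", "content": f"turn {self.turns}"})
        end_conversation = self.turns >= self.max_turns
        return agent, end_conversation


def step_agent(agent, **kwargs):
    pass_to_customer = agent.step(**kwargs)
    return agent, pass_to_customer


def step_user(user, agent):
    agent, end_conversation = user.step(agent)
    return agent, end_conversation


def run_dialogue(agent, user, max_steps=10):
    """Alternate agent and user turns until the user ends the conversation."""
    for _ in range(max_steps):
        agent, pass_to_customer = step_agent(agent)
        if not pass_to_customer:
            continue  # the agent keeps the turn
        agent, end_conversation = step_user(user, agent)
        if end_conversation:
            break
    return agent.messages


if __name__ == "__main__":
    for message in run_dialogue(EchoAgent(), ScriptedUser(max_turns=3)):
        print(message["role"], ":", message["content"])
```

A custom agent or user built on the real base classes would presumably need to honor the same return values, but check josh_train/josh.py for the actual interface before relying on this sketch.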

MT-Bench

If you want to evaluate on MT-Bench later, first unzip the evaluation archive:

unzip mtbencheval.zip

Citation

Please cite if you enjoyed this work!

@article{lattimer2024sparse,
  title={Sparse Rewards Can Self-Train Dialogue Agents},
  author={Lattimer, Barrett Martin and Gangal, Varun and McDonald, Ryan and Yang, Yi},
  journal={arXiv preprint arXiv:2409.04617},
  year={2024}
}
