A simulation training framework
Sparse Rewards Can Self-Train Dialogue Agents
Barrett Martin Lattimer, Varun Gangal, Ryan McDonald, Yi Yang
contact: blattimer@asapp.com
paper: https://arxiv.org/abs/2409.04617
This repo runs JOSH and the ToolWOZ and τ-bench datasets. It also contains tools for logging training and preference-annotated episodes from user-simulator interactions, and for LoRA-driven preference tuning of small LLMs on that preference-annotated experience.
Setup
- Run the following in a new env
pip install -e .
- Unzip the dataset.zip file in the data folder
- Set up your OpenAI credentials
export OPENAI_API_KEY= # api_key
export OPENAI_ORGANIZATION= # api_org
If you're running Llama or another local model, you will also need to set HF_TOKEN in the same way. Wherever you see HF_KEY, replace it with your Hugging Face token.
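Before launching a run, it can help to confirm that the variables from the steps above are actually set. The following is a minimal sketch, not part of the repo: the helper name `missing_vars` and the two tuples are hypothetical, and only the variable names come from the setup instructions.

```python
import os

# Variable names taken from the setup steps above.
REQUIRED = ("OPENAI_API_KEY", "OPENAI_ORGANIZATION")
# Only needed when running Llama or another local model.
LOCAL_MODEL_ONLY = ("HF_TOKEN",)


def missing_vars(names, env=None):
    """Return the names that are unset or empty in the given mapping.

    Hypothetical helper for illustration; defaults to the process environment.
    """
    env = os.environ if env is None else env
    return [name for name in names if not env.get(name)]


for name in missing_vars(REQUIRED):
    print(f"missing required variable: {name}")
for name in missing_vars(LOCAL_MODEL_ONLY):
    print(f"not set (only needed for local models): {name}")
```

An empty value counts as missing here, which catches the case where the export line was pasted without filling in the key.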
Running ToolWOZ
You can run ToolWOZ normally by doing the following
python josh_train/main.py
Increase --max_concurrency as far as your API rate limits allow
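To illustrate what a concurrency cap means in practice, here is a small sketch that runs several simulated conversations while never exceeding a fixed number in flight. It is an assumption-laden illustration, not how josh_train implements --max_concurrency: `run_bounded`, `PeakCounter`, and `fake_conversation` are hypothetical names, and the sleep stands in for an API call.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


def run_bounded(items, worker, max_concurrency):
    """Run worker over items with at most max_concurrency calls in flight."""
    with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
        # pool.map preserves input order in its results.
        return list(pool.map(worker, items))


class PeakCounter:
    """Context manager that records the highest number of simultaneous entries."""

    def __init__(self):
        self._lock = threading.Lock()
        self.active = 0
        self.peak = 0

    def __enter__(self):
        with self._lock:
            self.active += 1
            self.peak = max(self.peak, self.active)

    def __exit__(self, *exc):
        with self._lock:
            self.active -= 1


counter = PeakCounter()


def fake_conversation(index):
    """Hypothetical stand-in for one simulated conversation (no API call)."""
    with counter:
        time.sleep(0.01)
    return index * 2


results = run_bounded(range(8), fake_conversation, max_concurrency=3)
print(results, "peak concurrency:", counter.peak)
```

A higher cap finishes sooner but sends more simultaneous requests, which is why the right value depends on your rate limits.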
JOSH on ToolWOZ
Enable JOSH on ToolWOZ by adding the --josh flag. Add --josh_debug as well to print progress updates while JOSH runs
A more involved JOSH command looks like the following
python josh_train/main.py --josh --josh_debug --max_concurrency 20 --seed 20 --task_split train --temperature 1.0 --agent_strategy react --user_mode goal --model gpt-4o-mini --end_index 10 --beam_size 8
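If you launch many runs with different settings, assembling the command from a dictionary can be less error-prone than editing a long line by hand. The sketch below rebuilds the command shown above; `build_command` is a hypothetical helper, and only the script path, flags, and values that appear in that command are used.

```python
import shlex
import sys


def build_command(script, options, switches=()):
    """Build an argv list: interpreter, script, boolean switches, then --name value pairs."""
    cmd = [sys.executable, script]
    cmd += [f"--{name}" for name in switches]
    for name, value in options.items():
        cmd += [f"--{name}", str(value)]
    return cmd


# Flags and values copied from the example command above.
options = {
    "max_concurrency": 20,
    "seed": 20,
    "task_split": "train",
    "temperature": 1.0,
    "agent_strategy": "react",
    "user_mode": "goal",
    "model": "gpt-4o-mini",
    "end_index": 10,
    "beam_size": 8,
}

cmd = build_command("josh_train/main.py", options, switches=("josh", "josh_debug"))
print(shlex.join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the repo root; the sketch only prints it so that nothing is executed by accident.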
Running τ-bench
We have added a clone of τ-bench to this repo with two run files, one for normal τ-bench testing and another for JOSH rollouts on τ-bench
To run τ-bench normally you can do
python tau-bench-eval/run.py
JOSH on τ-bench
To run JOSH on τ-bench you can do
python tau-bench-eval/run.py --josh --debug
Using JOSH
This repo provides a JOSH class designed to be flexible enough for a wide variety of user/agent interactions. To use JOSH yourself, you can start with the following code snippet
from josh_train.josh import JOSH, BaseJOSHAgent, BaseRewards, BaseJOSHUser


# Appends an error message to the agent's history when it runs out of retries
def add_error_message(agent):
    agent.messages.append({'role': 'assistant', 'content': 'Error: Agent ran out of retries.'})
    return agent


# Advances the agent by one step and reports whether to pass the turn to the user
def step_agent(agent: BaseJOSHAgent, **kwargs):
    pass_to_customer = agent.step(**kwargs)
    return agent, pass_to_customer


# Advances the user by one step and reports whether the conversation has ended
def step_user(user: BaseJOSHUser, agent: BaseJOSHAgent):
    agent, end_conversation = user.step(agent)
    return agent, end_conversation


josh = JOSH(
    rewards=BaseRewards(['say hello', 'say hello', 'say hello']),
    agent_step=step_agent,
    user_step=step_user,
    add_error_message=add_error_message,
    root_agent=BaseJOSHAgent(),
    user=BaseJOSHUser(),
    debug=True,
)

# Run up to 10 JOSH steps, stopping early once all_done is reported
for _ in range(10):
    max_reward, all_done = josh.step()
    if all_done:
        break

print(max_reward)
print(josh.training_examples)
All of these classes can be subclassed and extended for your own use.
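As a rough picture of the interface the snippet above relies on, the sketch below defines stand-in agent and user classes with the same method shapes: an agent with a `messages` list and a `step(**kwargs)` method, and a user whose `step(agent)` returns the agent and an end-of-conversation flag. These stand-ins are hypothetical and do not import or subclass the real josh_train classes, which may require more than is shown here. The turn order in `run_dialogue` (user first) is also an assumption.

```python
class GreeterAgent:
    """Stand-in agent: keeps a message history and replies once per step."""

    def __init__(self):
        self.messages = []

    def step(self, **kwargs):
        self.messages.append({"role": "assistant", "content": "hello"})
        return True  # pass_to_customer: hand the turn back to the user


class ScriptedUser:
    """Stand-in user: plays back fixed lines, then ends the conversation."""

    def __init__(self, lines):
        self.lines = list(lines)

    def step(self, agent):
        if not self.lines:
            return agent, True  # nothing left to say, so end the conversation
        agent.messages.append({"role": "user", "content": self.lines.pop(0)})
        return agent, False


def run_dialogue(agent, user, max_turns=10):
    """Alternate user and agent steps until the user ends or turns run out."""
    for _ in range(max_turns):
        agent, end_conversation = user.step(agent)
        if end_conversation:
            break
        agent.step()
    return agent.messages


messages = run_dialogue(GreeterAgent(), ScriptedUser(["hi", "thanks"]))
for message in messages:
    print(message["role"], ":", message["content"])
```

In a real extension you would derive from BaseJOSHAgent and BaseJOSHUser instead and pass instances to JOSH as root_agent and user, as in the snippet above.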
MT-Bench
If you want to evaluate on MT-Bench later, unzip the evaluation archive
unzip mtbencheval.zip
Citation
Please cite if you enjoyed this work!
@article{lattimer2024sparse,
  title={Sparse Rewards Can Self-Train Dialogue Agents},
  author={Lattimer, Barrett Martin and Gangal, Varun and McDonald, Ryan and Yang, Yi},
  journal={arXiv preprint arXiv:2409.04617},
  year={2024}
}