Production-ready RL framework for training multi-turn conversational AI agents using GRPO and GSPO

StateSet Agents

Reinforcement‑learning framework for multi‑turn conversational AI agents.

Python 3.8+ | License: BUSL-1.1

StateSet Agents is a production‑oriented RL stack for training and serving LLM‑backed agents that improve through multi‑turn interaction. The library provides:

  • Async‑first agent APIs (MultiTurnAgent, ToolAgent) with Hugging Face and stub backends.
  • Environments for conversational and task‑oriented episodes.
  • Trajectories and value/advantage utilities tailored to dialogue.
  • Composable reward functions (heuristic, domain, multi‑objective, neural).
  • A family of group‑based policy‑optimization trainers (GRPO, GSPO, GEPO, DAPO, VAPO) plus PPO and RLAIF.
  • Optional performance layers (vLLM generation, Rust acceleration, distributed training, HPO, FastAPI service).

If you want a framework that treats conversations as first‑class RL episodes (rather than single turns), this is it.


Why group‑based optimization?

Traditional RLHF with PPO scores each sampled response on its own and relies on a learned value function as the baseline. In long conversations this leads to high‑variance updates and brittle behavior.
StateSet Agents implements group‑relative methods:

  • GRPO (Group Relative Policy Optimization): sample a group of trajectories per prompt, compute advantages relative to the group baseline, then apply clipped policy‑gradient updates.
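  • GSPO (Group Sequence Policy Optimization): a more stable sequence‑level variant from Alibaba's Qwen team. It computes importance ratios and clipping per whole response rather than per token, which avoids the token‑level collapse seen on long outputs and MoE models.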

The result is steadier learning for dialogue tasks.
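To make the group-relative idea concrete, here is a small pure-Python sketch of the GRPO math, not the library's trainer code: each trajectory's reward is normalized against the mean and standard deviation of its own group, and the result feeds a PPO-style clipped term. The function names are illustrative, and using the population standard deviation is an assumption (implementations differ on this detail).

```python
import math
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-8) -> List[float]:
    """GRPO-style advantages for the G trajectories sampled from one prompt.

    A_i = (r_i - mean(r)) / (std(r) + eps). The group itself is the baseline,
    so no learned value function is required.
    """
    g = len(rewards)
    mean = sum(rewards) / g
    # Population standard deviation (some implementations use the sample std).
    std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / g)
    return [(r - mean) / (std + eps) for r in rewards]


def clipped_surrogate(ratio: float, advantage: float, clip_eps: float = 0.2) -> float:
    """PPO-style clipped objective term for one sample (to be maximized)."""
    clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    return min(ratio * advantage, clipped_ratio * advantage)


if __name__ == "__main__":
    # Four conversations sampled for the same scenario, scored by some reward function.
    rewards = [0.9, 0.4, 0.4, 0.1]
    print([round(a, 3) for a in group_relative_advantages(rewards)])
```

Trajectories above the group mean get positive advantages and those below get negative ones, whatever the absolute reward scale; a group with identical rewards yields all-zero advantages.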


Core concepts

  • Agent: wraps a causal LM and exposes initialize() and generate_response().
    • MultiTurnAgent handles conversation history and state.
    • ToolAgent adds function/tool calling.
  • Environment: defines episode reset/step logic and optional reward hooks.
    • ConversationEnvironment ships with scenario‑driven multi‑turn conversations.
    • TaskEnvironment is for goal‑oriented tasks.
  • Trajectory: a multi‑turn record of turns, rewards, and metadata (MultiTurnTrajectory).
  • Rewards: RewardFunction subclasses and factories; combined via CompositeReward or multi‑objective reward models.
  • Training: trainers in stateset_agents.training implement GRPO‑family updates, GAE/value heads, KL regularization, LoRA support, and optional distributed/vLLM execution.
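The Training bullet mentions GAE. As a generic reference for that estimator (not the stateset_agents implementation), the sketch below computes Generalized Advantage Estimation over one conversation, assuming one scalar reward and one value estimate per assistant turn. `compute_gae` is a hypothetical helper name, and the gamma/lam defaults are common choices rather than the library's.

```python
from typing import List


def compute_gae(
    rewards: List[float],
    values: List[float],
    gamma: float = 0.99,
    lam: float = 0.95,
) -> List[float]:
    """Generalized Advantage Estimation over one episode.

    rewards[t] is the reward after turn t; values[t] is the value estimate for
    the state before turn t. values needs len(rewards) + 1 entries; the last is
    the bootstrap value (0.0 for a finished conversation).
    """
    if len(values) != len(rewards) + 1:
        raise ValueError("values needs one more entry than rewards")
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards: A_t = delta_t + gamma * lam * A_{t+1}
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages
```

With lam=1 and a zero value function this reduces to the discounted reward-to-go; with lam=0 it is the one-step TD error.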

Installation

Core (lightweight, stub‑ready)

pip install stateset-agents

Training / real models

pip install "stateset-agents[training]"

Optional extras

pip install "stateset-agents[trl]"        # TRL GRPO integration + bitsandbytes
pip install "stateset-agents[vllm]"       # vLLM generation backend
pip install "stateset-agents[hpo]"        # Optuna/Ray Tune HPO
pip install "stateset-agents[api]"        # FastAPI service
pip install "stateset-agents[distributed]" # DeepSpeed / multi‑GPU helpers
pip install "stateset-agents[full]"       # Most extras in one go

Quick start

1) Stub hello world (no downloads)

Runs without Torch/transformers and is ideal for CI or prototyping.

import asyncio
from stateset_agents import MultiTurnAgent
from stateset_agents.core.agent import AgentConfig

async def main():
    agent = MultiTurnAgent(AgentConfig(model_name="stub://demo"))
    await agent.initialize()
    reply = await agent.generate_response([{"role": "user", "content": "Hi!"}])
    print(reply)

asyncio.run(main())

2) Chat with a real model

import asyncio
from stateset_agents import MultiTurnAgent
from stateset_agents.core.agent import AgentConfig

async def main():
    agent = MultiTurnAgent(
        AgentConfig(model_name="gpt2", max_new_tokens=128, temperature=0.7)
    )
    await agent.initialize()
    messages = [{"role": "user", "content": "What is GRPO?"}]
    print(await agent.generate_response(messages))

asyncio.run(main())

Train a multi‑turn agent with GRPO

The high‑level train(...) helper chooses single‑turn vs multi‑turn GRPO automatically.

import asyncio
from stateset_agents import (
    MultiTurnAgent,
    ConversationEnvironment,
    CompositeReward,
    HelpfulnessReward,
    SafetyReward,
    train,
)
from stateset_agents.core.agent import AgentConfig

async def main():
    # 1) Agent
    agent = MultiTurnAgent(AgentConfig(model_name="gpt2"))
    await agent.initialize()

    # 2) Environment
    scenarios = [
        {
            "id": "refund",
            "topic": "refunds",
            "context": "User wants a refund for a delayed order.",
            "user_responses": [
                "My order is late.",
                "I'd like a refund.",
                "Thanks for your help.",
            ],
        }
    ]
    env = ConversationEnvironment(scenarios=scenarios, max_turns=6)

    # 3) Reward
    reward_fn = CompositeReward(
        [HelpfulnessReward(weight=0.7), SafetyReward(weight=0.3)]
    )

    # 4) Train
    trained_agent = await train(
        agent=agent,
        environment=env,
        reward_fn=reward_fn,
        num_episodes=50,
        profile="balanced",
        save_path="./outputs/refund_agent",
    )

    # 5) Try the trained model
    resp = await trained_agent.generate_response(
        [{"role": "user", "content": "My order was delayed, what can you do?"}]
    )
    print(resp)

asyncio.run(main())

More end‑to‑end scripts live in examples/complete_grpo_training.py and examples/production_ready_customer_service.py.


Other training algorithms

All algorithms are available under stateset_agents.training when training deps are installed:

  • GSPO: stable sequence‑level GRPO variant (GSPOTrainer, GSPOConfig, train_with_gspo)
  • GEPO: expectation‑based group optimization for heterogeneous/distributed setups
  • DAPO: decoupled clip + dynamic sampling for reasoning‑heavy tasks
  • VAPO: value‑augmented group optimization (strong for math/reasoning)
  • PPO baseline: standard PPO trainer for comparison
  • RLAIF: RL from AI feedback via judge/reward models
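The two ideas in the DAPO entry can be illustrated in plain Python. This is a sketch of the technique, not the library's trainer: the clip range becomes asymmetric, [1 - eps_low, 1 + eps_high], giving low-probability tokens more room to grow, and groups whose rewards are all identical are dropped because their group-relative advantages are all zero and contribute no gradient. The helper names and default epsilons are illustrative assumptions.

```python
from typing import List, Sequence


def decoupled_clip_term(
    ratio: float, advantage: float, eps_low: float = 0.2, eps_high: float = 0.28
) -> float:
    """Clipped objective with an asymmetric range [1 - eps_low, 1 + eps_high]."""
    clipped_ratio = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
    return min(ratio * advantage, clipped_ratio * advantage)


def dynamic_sampling_filter(groups: Sequence[Sequence[float]], tol: float = 1e-12) -> List[int]:
    """Indices of reward groups worth training on.

    A group whose rewards are all equal yields all-zero group-relative
    advantages, so it adds no learning signal and is dropped. (DAPO keeps
    sampling until the batch is full; that loop is omitted here.)
    """
    return [i for i, rewards in enumerate(groups) if max(rewards) - min(rewards) > tol]
```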

Minimal GSPO sketch:

from stateset_agents.training import get_config_for_task, GSPOConfig, train_with_gspo
from stateset_agents.rewards.multi_objective_reward import create_customer_service_reward

async def run_gspo(agent, env):
    # agent and env are built as in the GRPO example above;
    # call with: trained_agent = await run_gspo(agent, env)
    base_cfg = get_config_for_task("customer_service", model_name="gpt2")
    gspo_cfg = GSPOConfig.from_training_config(base_cfg, num_outer_iterations=5)

    return await train_with_gspo(
        config=gspo_cfg,
        agent=agent,
        environment=env,
        reward_model=create_customer_service_reward(),
    )

See docs/GSPO_GUIDE.md, docs/ADVANCED_RL_ALGORITHMS.md, and examples/train_with_gspo.py for full configs.
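The key change GSPO makes relative to GRPO is the importance ratio: one length-normalized ratio per response instead of one per token. A minimal sketch of that computation, assuming per-token log-probabilities are already available as plain lists (the function names are hypothetical, not the GSPOTrainer API):

```python
import math
from typing import Sequence


def gspo_sequence_ratio(new_logps: Sequence[float], old_logps: Sequence[float]) -> float:
    """Length-normalized sequence importance ratio.

    s = (pi_new(y|x) / pi_old(y|x)) ** (1 / |y|)
      = exp(mean_t(log pi_new(y_t) - log pi_old(y_t)))
    """
    if len(new_logps) != len(old_logps) or not new_logps:
        raise ValueError("need matching, non-empty per-token log-prob lists")
    mean_diff = sum(n - o for n, o in zip(new_logps, old_logps)) / len(new_logps)
    return math.exp(mean_diff)


def gspo_objective(
    ratios: Sequence[float], advantages: Sequence[float], clip_eps: float
) -> float:
    """Group-averaged clipped objective with one ratio per whole response."""
    total = 0.0
    for s, a in zip(ratios, advantages):
        clipped = min(max(s, 1.0 - clip_eps), 1.0 + clip_eps)
        total += min(s * a, clipped * a)
    return total / len(ratios)
```

Because the ratio is a geometric mean over tokens, it stays close to 1, so the clip range is usually much tighter than the token-level 0.2 common in GRPO; treat `clip_eps` as a value to tune rather than a fixed default.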


Hyperparameter optimization (HPO)

Install with stateset-agents[hpo], then:

from stateset_agents.training import TrainingConfig, TrainingProfile
from stateset_agents.training.hpo import quick_hpo

async def run_hpo(agent, env, reward_fn):
    # agent, env and reward_fn are built as in the GRPO example above
    base_cfg = TrainingConfig.from_profile(
        TrainingProfile.BALANCED, num_episodes=100
    )

    summary = await quick_hpo(
        agent=agent,
        environment=env,
        reward_function=reward_fn,
        base_config=base_cfg,
        n_trials=30,
    )
    print(summary.best_params)
    return summary

See docs/HPO_GUIDE.md and examples/hpo_training_example.py.


Custom rewards

Use the decorator for quick experiments:

from stateset_agents.core.reward import reward_function

@reward_function(weight=0.5)
async def politeness_reward(turns, context=None) -> float:
    return 1.0 if any("please" in t.content.lower() for t in turns) else 0.0

Combine with built‑ins via CompositeReward.
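The exact rule CompositeReward uses to combine components is not described here. The sketch below assumes a normalized weighted average and uses hypothetical stand-ins (Turn, WeightedReward, composite_score) so it runs without the library; `politeness` mirrors the decorator example above, and `brevity` is an invented second component.

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, List, Optional, Sequence


@dataclass
class Turn:
    """Hypothetical stand-in for a conversation turn (role + text)."""
    role: str
    content: str


RewardFn = Callable[[Sequence[Turn], Optional[dict]], Awaitable[float]]


@dataclass
class WeightedReward:
    """Hypothetical pairing of a reward function with its weight."""
    fn: RewardFn
    weight: float


async def composite_score(
    components: List[WeightedReward], turns: Sequence[Turn], context: Optional[dict] = None
) -> float:
    """Normalized weighted average of component rewards (an assumed combination rule)."""
    total_weight = sum(c.weight for c in components)
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive number")
    # Component rewards are independent, so evaluate them concurrently.
    scores = await asyncio.gather(*(c.fn(turns, context) for c in components))
    return sum(c.weight * s for c, s in zip(components, scores)) / total_weight


async def politeness(turns: Sequence[Turn], context: Optional[dict] = None) -> float:
    return 1.0 if any("please" in t.content.lower() for t in turns) else 0.0


async def brevity(turns: Sequence[Turn], context: Optional[dict] = None) -> float:
    # 1.0 when every assistant turn stays under 50 words.
    replies = [t for t in turns if t.role == "assistant"]
    return 1.0 if all(len(t.content.split()) < 50 for t in replies) else 0.0
```

With equal weights, a short conversation that never says "please" scores 0.5.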


Custom environments

Subclass Environment for task‑specific dynamics:

from stateset_agents.core.environment import Environment, EnvironmentState
from stateset_agents.core.trajectory import ConversationTurn

class MyEnv(Environment):
    async def reset(self, scenario=None) -> EnvironmentState:
        # Build and return the initial state for a new episode.
        ...

    async def step(
        self, state: EnvironmentState, action: ConversationTurn
    ):
        # Apply the agent's turn and advance the episode.
        ...
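For a feel of what reset/step might do, here is a self-contained toy that replays a scenario's scripted user_responses, similar in spirit to the refund scenario in the GRPO example. ToyState and ScriptedUserEnv are hypothetical, and the (state, next_user_message) return shape of step is an assumption, not the Environment base-class contract.

```python
import asyncio
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class ToyState:
    """Hypothetical stand-in for EnvironmentState."""
    scenario: dict
    turn: int = 0
    history: List[dict] = field(default_factory=list)
    done: bool = False


class ScriptedUserEnv:
    """Replays a scenario's scripted user_responses, one per agent turn."""

    def __init__(self, max_turns: int = 6):
        self.max_turns = max_turns

    async def reset(self, scenario: dict) -> ToyState:
        state = ToyState(scenario=scenario)
        # Open the episode with the first scripted user message.
        state.history.append({"role": "user", "content": scenario["user_responses"][0]})
        return state

    async def step(self, state: ToyState, assistant_text: str) -> Tuple[ToyState, Optional[str]]:
        state.history.append({"role": "assistant", "content": assistant_text})
        state.turn += 1
        script = state.scenario["user_responses"]
        # End when the script runs out or the turn budget is spent.
        if state.turn >= len(script) or state.turn >= self.max_turns:
            state.done = True
            return state, None
        reply = script[state.turn]
        state.history.append({"role": "user", "content": reply})
        return state, reply


async def demo() -> ToyState:
    env = ScriptedUserEnv(max_turns=6)
    scenario = {
        "id": "refund",
        "user_responses": ["My order is late.", "I'd like a refund.", "Thanks for your help."],
    }
    state = await env.reset(scenario)
    while not state.done:
        state, _ = await env.step(state, "Placeholder assistant reply.")
    return state


if __name__ == "__main__":
    print(len(asyncio.run(demo()).history))
```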

Checkpoints

  • train(..., save_path="...") saves an agent checkpoint.
  • Load later:
import asyncio
from stateset_agents.core.agent import load_agent_from_checkpoint

async def main():
    agent = await load_agent_from_checkpoint("./outputs/refund_agent")

asyncio.run(main())

CLI

The CLI is a thin wrapper around the Python API:

stateset-agents version
stateset-agents doctor
stateset-agents train --stub
stateset-agents train --config ./config.yaml --dry-run false --save ./outputs/ckpt
stateset-agents evaluate --checkpoint ./outputs/ckpt --message "Hello"
stateset-agents serve --host 0.0.0.0 --port 8001

For complex runs prefer the Python API and the examples folder.


Examples and docs

Good starting points:

  • examples/hello_world.py – stub mode walkthrough
  • examples/quick_start.py – basic agent + environment
  • examples/complete_grpo_training.py – end‑to‑end GRPO training
  • examples/train_with_gspo.py – GSPO + GSPO‑token training
  • examples/train_with_trl_grpo.py – Hugging Face TRL GRPO integration

Key docs:

  • docs/USAGE_GUIDE.md
  • docs/RL_FRAMEWORK_GUIDE.md
  • docs/GSPO_GUIDE.md
  • docs/HPO_GUIDE.md
  • docs/CLI_REFERENCE.md
  • docs/ARCHITECTURE.md

Contributing

See CONTRIBUTING.md. Please run pytest -q and format with black/isort before opening a PR.


License

Business Source License 1.1. Non‑production use is permitted now; on 2029‑09‑03 the license converts to Apache 2.0. See LICENSE.

Download files

Source Distribution

stateset_agents-0.6.0.tar.gz (530.2 kB)

Built Distribution

stateset_agents-0.6.0-py3-none-any.whl (638.0 kB)

File details

Details for the file stateset_agents-0.6.0.tar.gz.

File metadata

  • Download URL: stateset_agents-0.6.0.tar.gz
  • Size: 530.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.16

File hashes

Hashes for stateset_agents-0.6.0.tar.gz

  • SHA256: 663050502f56664b3ed97e88fe73d50083a156e0bea10895efda61131824d9eb
  • MD5: e86acc3c9c95a117206c7942ba87f43e
  • BLAKE2b-256: e35b974172b696bca9355821c3737da7545263ef901250ec56b2e8659707e296

File details

Details for the file stateset_agents-0.6.0-py3-none-any.whl.

File metadata

  • Download URL: stateset_agents-0.6.0-py3-none-any.whl
  • Size: 638.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.16

File hashes

Hashes for stateset_agents-0.6.0-py3-none-any.whl

  • SHA256: 503b72cf42c57ad6bc5838a9ae1e048b8399b3e353e23597b975fdfa7eaab419
  • MD5: 5cc76950d5bb0d9b04d3e1526e9ec4bb
  • BLAKE2b-256: 3767be6595ab9e499836933d753f1767415fd9f79c0b9ba7f5904babd42291c3
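To check a downloaded file against the SHA256 digests listed above, a streaming hashlib sketch works for files of any size; `sha256_of` and `verify` are illustrative helper names. pip can also enforce hashes itself through hash-checking mode (`--hash=sha256:...` entries in a requirements file, used with `--require-hashes`).

```python
import hashlib
from pathlib import Path

# SHA256 of stateset_agents-0.6.0.tar.gz, copied from the hashes listed above.
EXPECTED_SDIST_SHA256 = "663050502f56664b3ed97e88fe73d50083a156e0bea10895efda61131824d9eb"


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 of a file, read in chunks so large files stay out of memory."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected: str) -> bool:
    """True when the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()
```

After downloading the sdist, `verify("stateset_agents-0.6.0.tar.gz", EXPECTED_SDIST_SHA256)` should return True.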
