A lightweight post-training framework for LLMs

oxRL

Post-train any model in under 10 lines of code.

A lightweight post-training framework for LLMs, VLMs, and VLAs that prioritizes developer speed. It scales to billions of parameters with DeepSpeed, vLLM, and Ray.


Usage (Python API)

Post-train any model in under 10 lines of code. oxRL auto-detects your hardware and can auto-prepare common datasets.

from oxrl import Trainer

# 1. Initialize with your model
trainer = Trainer(model="Qwen/Qwen2.5-0.5B-Instruct")

# 2. Start training (auto-downloads and preps dataset)
trainer.train(dataset="gsm8k")

Supported Models

The following models have been verified and onboarded. You can find ready-to-use scripts in the examples/onboarded_models/ directory.

| Model | Params | Task | Dataset | GPU Setup |
|---|---|---|---|---|
| Qwen2.5-0.5B-Instruct | 0.5B | Math | GSM8K | 1 Train + 1 Rollout |
| Qwen2.5-1.5B-Instruct | 1.5B | Math | GSM8K | 1 Train + 1 Rollout |
| Qwen2.5-Coder-1.5B-Instruct | 1.5B | Code | MBPP | 1 Train + 1 Rollout |
| SmolLM2-1.7B-Instruct | 1.7B | Instruct | UltraFeedback | 1 Train + 1 Rollout |
| Qwen2.5-3B-Instruct | 3.0B | Math | MATH | 1 Train + 1 Rollout |
| Qwen2.5-7B-Instruct | 7.0B | Math | GSM8K | 2 Train + 2 Rollout |
| DeepSeek-R1-Distill-Qwen-7B | 7.0B | Reasoning | MATH | 2 Train + 2 Rollout |
| Mistral-7B-Instruct-v0.3 | 7.0B | Instruct | UltraFeedback | 2 Train + 2 Rollout |

System Architecture

┌─────────────────────────────────────────────────────────────────┐
│                         oxRL Framework                          │
├─────────────────────┬───────────────────┬───────────────────────┤
│   Training Engines  │  Rollout Engines  │    Config + Data      │
│   (Ray + DeepSpeed) │  (Ray + vLLM)     │    (Pydantic + HF)    │
├─────────────────────┼───────────────────┼───────────────────────┤
│                     │                   │                       │
│  algs/grpo.py       │ rollouts/         │ configs/load.py       │
│    SGRPO loss       │   vllm_engine.py  │ configs/*.yaml        │
│    CISPO loss       │   replay_buffer.py│                       │
│  algs/PPO/ppo.py    │                   │ datasets/             │
│  algs/SFT/sft.py    │                   │   prompt_only.py      │
│                     │                   │   prompt_response.py  │
│                     │                   │   mixed_ratio_sampler │
├─────────────────────┴───────────────────┴───────────────────────┤
│  utils/setup.py  │  utils/logging.py  │  rewards/compute_score  │
└──────────────────┴────────────────────┴─────────────────────────┘

RL Training Workflow

┌──────────────┐     ┌───────────────────┐     ┌──────────────────┐
│ Load Config  │────▶│  Initialize Ray   │────▶│ Create Engines   │
│ (YAML file)  │     │  Cluster          │     │ N train + M roll │
└──────────────┘     └───────────────────┘     └────────┬─────────┘
                                                        │
                     ┌──────────────────────────────────┘
                     ▼
          ┌─────────────────────┐
          │   For each epoch:   │
          │                     │
          │  ┌───────────────┐  │    Rollout engines generate responses
          │  │ 1. Rollouts   │  │    using vLLM, compute rewards,
          │  │    (vLLM)     │  │    store in replay buffer
          │  └───────┬───────┘  │
          │          │          │
          │  ┌───────▼───────┐  │    Training engines run forward/backward
          │  │ 2. Train      │  │    with DeepSpeed ZeRO-3 across GPUs
          │  │    (DeepSpeed)│  │
          │  └───────┬───────┘  │
          │          │          │
          │  ┌───────▼───────┐  │    Save model, update rollout engines
          │  │ 3. Checkpoint │  │    with new policy weights
          │  │    + Refresh  │  │
          │  └───────────────┘  │
          └─────────────────────┘
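
The control flow in the diagram can be sketched with stub engines. This is a minimal illustration of the sequential rollout, train, checkpoint-and-refresh order only; `StubRolloutEngine`, `StubTrainEngine`, and `run` are hypothetical names, not oxRL classes, and no Ray, vLLM, or DeepSpeed calls are made.

```python
from dataclasses import dataclass


@dataclass
class StubRolloutEngine:
    """Hypothetical stand-in for a vLLM-backed rollout engine."""
    policy_version: int = 0

    def generate(self, prompts):
        # A real engine would sample responses with vLLM and score them.
        return [
            {"prompt": p, "reward": 0.0, "policy_version": self.policy_version}
            for p in prompts
        ]

    def refresh(self, version):
        # Pick up the newly trained weights before the next round of rollouts.
        self.policy_version = version


@dataclass
class StubTrainEngine:
    """Hypothetical stand-in for a DeepSpeed training engine."""
    version: int = 0
    steps_taken: int = 0

    def train_step(self, batch):
        # A real engine would run forward/backward and an optimizer step.
        self.steps_taken += 1

    def checkpoint(self):
        # Saving a checkpoint yields a new policy version number.
        self.version += 1
        return self.version


def run(prompts, epochs, steps_per_epoch):
    rollout, trainer = StubRolloutEngine(), StubTrainEngine()
    log = []
    for epoch in range(epochs):
        buffer = rollout.generate(prompts)   # 1. rollouts fill the buffer
        for _ in range(steps_per_epoch):     # 2. train on that buffer
            trainer.train_step(buffer)
        version = trainer.checkpoint()       # 3. checkpoint ...
        rollout.refresh(version)             #    ... and refresh rollouts
        log.append((epoch, len(buffer), version))
    return log, trainer.steps_taken, rollout.policy_version
```

Because each phase finishes before the next begins, every batch in this sketch is tied to exactly one policy version.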

Quick Start

Installation

git clone https://github.com/warlockee/oxRL.git
cd oxRL
pip install -r req.txt

Dependencies: PyTorch, DeepSpeed, vLLM, Ray, Transformers, Pydantic.


Post-train a model in 3 steps (CLI)

Step 1. Prepare your data as a Parquet or JSONL file with chat-format prompts, one record per row:

{"prompt": [{"role": "user", "content": "What is 2+2?"}], "answer": "4"}
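
If your raw data is a list of question/answer pairs, a file in this shape can be produced with the standard library alone. The sketch below is one way to do it; `write_jsonl` and `read_jsonl` are hypothetical helper names, not part of oxRL.

```python
import json
from pathlib import Path


def write_jsonl(records, path):
    """Write (question, answer) pairs as chat-format JSONL, one object per line."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for question, answer in records:
            row = {
                "prompt": [{"role": "user", "content": question}],
                "answer": answer,
            }
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return path


def read_jsonl(path):
    """Read the file back as a list of dicts, skipping blank lines."""
    with Path(path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
```

For example, `write_jsonl([("What is 2+2?", "4")], "./data/train.jsonl")` writes a single row matching the record shown above.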

Step 2. Write a minimal config (everything else uses sensible defaults):

# config.yaml
run:
  experiment_id: "my-run"
  training_gpus: 2
  rollout_gpus: 2

train:
  alg_name: "sgrpo"           # or "cispo"
  total_number_of_epochs: 10
  train_steps_per_epoch: 20

model:
  name: "google/gemma-3-1b-it" # any HuggingFace model

data:
  train_dnames: ["my_data"]
  train_ratios: {"my_data": 1.0}
  train_files_path: "./data/train.parquet"
  val_files_path: "./data/val.parquet"
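
When more than one dataset is listed, `train_ratios` gives the mixing weights. One plausible reading is that each training example is drawn from a dataset with probability proportional to its ratio. The sketch below illustrates that idea only; `sample_dataset_names` is a hypothetical helper, and oxRL's own `mixed_ratio_sampler` may work differently.

```python
import random


def sample_dataset_names(train_ratios, n, seed=0):
    """Draw n dataset names with probability proportional to their ratio."""
    if not train_ratios:
        raise ValueError("train_ratios is empty")
    if any(w < 0 for w in train_ratios.values()):
        raise ValueError("ratios must be non-negative")
    total = sum(train_ratios.values())
    if total <= 0:
        raise ValueError("ratios must sum to a positive value")
    names = list(train_ratios)
    # Normalize so the ratios do not have to sum to 1.0 in the config.
    weights = [train_ratios[name] / total for name in names]
    rng = random.Random(seed)  # seeded for reproducible draws
    return rng.choices(names, weights=weights, k=n)
```

With the single-dataset config above, every draw is `"my_data"`.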

Step 3. Run:

python main_rl.py --config-file config.yaml

For supervised fine-tuning:

python main_sl.py --config-file configs/sl_args.yaml

See examples/quickstart.py for a complete runnable example.

Custom Reward Functions

Write a function in rewards/compute_score.py and set reward.reward_func in your config:

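import torch

def my_reward(prompt_ids, response_ids, finish_reason):
    # One reward slot per response token; only the last token is scored here.
    r = torch.zeros(len(response_ids), dtype=torch.float32)
    # Your scoring logic here; meets_criteria is a placeholder you define.
    r[-1] = 1.0 if meets_criteria(response_ids) else 0.0
    return r, False  # (reward_tensor, is_per_token)

# config.yaml
reward:
  reward_func: "my_reward"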
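
The scoring logic itself is task-specific. For a math-style task such as GSM8K, a common choice is exact match on the final number in the response. The sketch below shows only that scoring step on plain strings; `extract_final_number` and `exact_match_score` are hypothetical helpers, and decoding `response_ids` to text or obtaining the reference answer inside the reward function is not covered here and would depend on how oxRL passes that data.

```python
import re

# Matches integers and simple decimals, with an optional leading minus sign.
_NUMBER = re.compile(r"-?\d+(?:\.\d+)?")


def extract_final_number(text):
    """Return the last number that appears in text as a string, or None."""
    matches = _NUMBER.findall(text.replace(",", ""))  # drop thousands separators
    return matches[-1] if matches else None


def exact_match_score(response_text, answer):
    """Return 1.0 if the last number in the response equals the reference, else 0.0."""
    predicted = extract_final_number(response_text)
    if predicted is None:
        return 0.0
    try:
        return 1.0 if float(predicted) == float(answer) else 0.0
    except ValueError:
        # The reference answer was not numeric.
        return 0.0
```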

Algorithms

| Algorithm | File | Description |
|---|---|---|
| SGRPO | algs/grpo.py | Stable GRPO: clipped surrogate loss with optional KL regularization from a reference model |
| CISPO | algs/grpo.py | Clipped importance-sampling policy optimization: weighted log-probability loss |
| PPO | algs/PPO/ppo.py | Proximal Policy Optimization with GAE, value clipping, and an entropy bonus |
| SFT | algs/SFT/sft.py | Supervised fine-tuning with masked cross-entropy loss |

SGRPO and CISPO share the same training infrastructure (DeepSpeed + Ray actors) and differ only in the policy loss computation. Select with train.alg_name in your config.
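
To make the difference concrete, here is a per-token sketch of the two loss shapes in their generic textbook forms. It is not taken from `algs/grpo.py`: the function names, the symmetric `eps` clip range, and the omission of any KL term or batch aggregation are all assumptions made for illustration.

```python
import math


def clip(x, lo, hi):
    return max(lo, min(hi, x))


def clipped_surrogate_loss(logp_new, logp_old, advantage, eps=0.2):
    """PPO/GRPO-style per-token loss: -min(r * A, clip(r) * A)."""
    ratio = math.exp(logp_new - logp_old)  # importance ratio r
    unclipped = ratio * advantage
    clipped = clip(ratio, 1.0 - eps, 1.0 + eps) * advantage
    return -min(unclipped, clipped)


def cispo_style_loss(logp_new, logp_old, advantage, eps=0.2):
    """CISPO-style per-token loss: -clip(r) * A * logp_new.

    In an autograd framework the clipped ratio would be detached
    (stop-gradient), so gradients flow only through logp_new.
    """
    weight = clip(math.exp(logp_new - logp_old), 1.0 - eps, 1.0 + eps)
    return -weight * advantage * logp_new
```

In the first form, a token whose ratio leaves the clip range stops contributing gradient; in the second, the clipped ratio only rescales a log-probability term, so every token still contributes.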

Project Structure

oxRL/
├── main_rl.py              RL training loop (Ray + DeepSpeed)
├── main_sl.py              SL training loop (DeepSpeed)
├── algs/
│   ├── grpo.py             SGRPO + CISPO (unified, loss_variant selects)
│   ├── PPO/ppo.py          PPO with GAE + value function
│   └── SFT/sft.py          Supervised fine-tuning
├── configs/
│   ├── load.py             Pydantic config with sensible defaults
│   ├── rl_args.yaml        Full RL config example
│   └── sl_args.yaml        Full SL config example
├── datasets/
│   ├── prompt_only.py      RL prompts (chat format → tokens)
│   ├── prompt_response.py  SL prompt-response pairs
│   └── mixed_ratio_sampler Multi-dataset weighted sampling
├── rollouts/
│   ├── vllm_engine.py      vLLM inference with hot model refresh
│   └── replay_buffer.py    On-policy sample storage
├── rewards/
│   └── compute_score.py    Pluggable reward functions
├── utils/
│   ├── setup.py            Distributed setup (seeds, rank, tokenizer)
│   ├── logging.py          Rank-aware logging + MLflow
│   └── utils.py            Tensor helpers (dtype, padding)
├── preprocessing/
│   └── gsm8k.py            GSM8K dataset preparation
└── examples/
    └── quickstart.py       End-to-end example (48 lines)

Configuration

oxRL uses Pydantic for type-safe configuration. Every field has a sensible default — you only need to specify what's unique to your run.

Required fields (no defaults):

| Section | Field | Description |
|---|---|---|
| run | experiment_id | Name for this run |
| run | training_gpus | Number of GPUs for training |
| run | rollout_gpus | Number of GPUs for rollout generation |
| train | alg_name | Algorithm: sgrpo, cispo, sft |
| train | total_number_of_epochs | Training epochs |
| train | train_steps_per_epoch | Optimizer steps per epoch (RL) |
| model | name | Hugging Face model ID |
| data | train_dnames | Dataset name list |
| data | train_ratios | Dataset mixing ratios |
| data | train_files_path | Path to training data |
| data | val_files_path | Path to validation data |

Everything else (optimizer, scheduler, DeepSpeed ZeRO-3, vLLM rollouts, reward function) defaults to production-tested values. See configs/rl_args.yaml for the full reference.
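
oxRL validates the config with Pydantic when it loads. If you want a quick pre-flight check before launching a job, the required-field table can be mirrored in a few lines of plain Python. This is a standalone sketch, not part of oxRL: `REQUIRED_FIELDS` and `missing_required` are hypothetical names, and `config` is assumed to be the nested dict you get from parsing the YAML file.

```python
# Required fields per section, copied from the table above.
REQUIRED_FIELDS = {
    "run": ["experiment_id", "training_gpus", "rollout_gpus"],
    "train": ["alg_name", "total_number_of_epochs", "train_steps_per_epoch"],
    "model": ["name"],
    "data": ["train_dnames", "train_ratios", "train_files_path", "val_files_path"],
}


def missing_required(config):
    """Return 'section.field' strings for required fields absent from config."""
    missing = []
    for section, fields in REQUIRED_FIELDS.items():
        values = config.get(section) or {}  # tolerate a missing or empty section
        for name in fields:
            if values.get(name) is None:
                missing.append(f"{section}.{name}")
    return missing
```

An empty list means every required field is present; it does not check types or value ranges.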

Experiment Tracking

MLflow is supported but optional. Without it, training runs normally; you just won't get experiment tracking.

Without MLflow (default): training logs to console only. Nothing to configure.

With MLflow: install it and set the tracking URI in your config:

pip install mlflow

# config.yaml
run:
  tracking_uri: "http://localhost:5000"  # or your MLflow server

Start a local MLflow UI:

mlflow ui --port 5000

oxRL automatically logs hyperparameters, per-step losses, KL divergence, clip fractions, rewards, and epoch-level aggregates to MLflow.

Key Design Decisions

Sequential rollout → training. oxRL does not pipeline rollout generation with training. This is deliberate. Pipelined overlap improves GPU utilization but makes debugging significantly harder. When training diverges at step 4,000, you want to know exactly what happened, and with strictly sequential phases each training batch comes from a single, known policy version.

One class for SGRPO and CISPO. Both algorithms share 99% of their code — the only difference is 4 lines in the policy loss computation. A loss_variant parameter selects between them. No inheritance, no abstraction.

DeepSpeed ZeRO-3 by default. The config system auto-syncs optimizer, scheduler, dtype, and batch size settings to DeepSpeed — you configure once in the YAML and oxRL handles the rest.
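
As a rough picture of what such syncing involves, the sketch below derives a DeepSpeed config dict from a handful of run settings. The output keys are standard DeepSpeed config keys; the function name, its parameters, and the choice of AdamW are assumptions for illustration and do not reflect oxRL's actual field names or defaults.

```python
def build_deepspeed_config(lr, micro_batch_size, grad_accum_steps, world_size, dtype="bf16"):
    """Derive a DeepSpeed ZeRO-3 config dict from a few run settings."""
    if dtype not in ("bf16", "fp16"):
        raise ValueError(f"unsupported dtype: {dtype}")
    return {
        # DeepSpeed expects train_batch_size == micro * accum * world_size.
        "train_batch_size": micro_batch_size * grad_accum_steps * world_size,
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "gradient_accumulation_steps": grad_accum_steps,
        "zero_optimization": {"stage": 3},
        # Exactly one of the two mixed-precision modes is enabled.
        "bf16": {"enabled": dtype == "bf16"},
        "fp16": {"enabled": dtype == "fp16"},
        "optimizer": {"type": "AdamW", "params": {"lr": lr}},
    }
```

Deriving the batch size in one place avoids the mismatch that arises when the three batch-related values are set by hand and drift apart.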

Strict on-policy enforcement. Optional mode that validates rollouts were generated by the current policy version. Catches silent distribution shift bugs that waste GPU-days.
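
The idea behind such a check can be sketched as follows: tag each rollout sample with the policy version that generated it, then refuse to train on anything that does not match the current version. `Sample`, `StalePolicyError`, and `assert_on_policy` are hypothetical names for this illustration; the README does not show how oxRL implements the check.

```python
from dataclasses import dataclass


class StalePolicyError(RuntimeError):
    """Raised when a batch contains samples from an older policy."""


@dataclass(frozen=True)
class Sample:
    response_ids: tuple
    reward: float
    policy_version: int  # version of the weights that generated this sample


def assert_on_policy(samples, current_version):
    """Return the number of samples if all match current_version, else raise."""
    stale = [s for s in samples if s.policy_version != current_version]
    if stale:
        versions = sorted({s.policy_version for s in stale})
        raise StalePolicyError(
            f"{len(stale)} of {len(samples)} samples came from policy "
            f"versions {versions}, expected {current_version}"
        )
    return len(samples)
```

Failing loudly at this point is cheaper than discovering later that a run trained on stale rollouts.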

Contributing

Contributions are welcome. The bar: keep changes readable, testable, and debuggable. Follow the existing style. If your change adds complexity, it should be worth it.

FAQ

Check out the FAQ for common questions and answers.

Acknowledgments

Some components of this codebase are inspired by practices from open source projects. We try to cite sources wherever we directly reuse exact code. If we missed a citation, please let us know and we will credit the source.

