A lightweight post-training framework for LLMs
oxRL
Post-train any model in under 10 lines of code.
A lightweight post-training framework for LLMs, VLMs, and VLAs, built for developer speed. It scales to billions of parameters with DeepSpeed, vLLM, and Ray.
Usage (Python API)
Post-train any model in under 10 lines of code. oxRL auto-detects your hardware and can auto-prepare common datasets.
from oxrl import Trainer
# 1. Initialize with your model
trainer = Trainer(model="Qwen/Qwen2.5-0.5B-Instruct")
# 2. Start training (auto-downloads and preps dataset)
trainer.train(dataset="gsm8k")
Supported Models
The following models have been verified and onboarded. You can find ready-to-use scripts in the examples/onboarded_models/ directory.
| Model | Params | Task | Dataset | GPU Setup |
|---|---|---|---|---|
| Qwen2.5-0.5B-Instruct | 0.5B | Math | GSM8K | 1 Train + 1 Rollout |
| Qwen2.5-1.5B-Instruct | 1.5B | Math | GSM8K | 1 Train + 1 Rollout |
| Qwen2.5-Coder-1.5B-Instruct | 1.5B | Code | MBPP | 1 Train + 1 Rollout |
| SmolLM2-1.7B-Instruct | 1.7B | Instruct | UltraFeedback | 1 Train + 1 Rollout |
| Qwen2.5-3B-Instruct | 3.0B | Math | MATH | 1 Train + 1 Rollout |
| Qwen2.5-7B-Instruct | 7.0B | Math | GSM8K | 2 Train + 2 Rollout |
| DeepSeek-R1-Distill-Qwen-7B | 7.0B | Reasoning | MATH | 2 Train + 2 Rollout |
| Mistral-7B-Instruct-v0.3 | 7.0B | Instruct | UltraFeedback | 2 Train + 2 Rollout |
System Architecture
┌─────────────────────────────────────────────────────────────────┐
│ oxRL Framework │
├─────────────────────┬───────────────────┬───────────────────────┤
│ Training Engines │ Rollout Engines │ Config + Data │
│ (Ray + DeepSpeed) │ (Ray + vLLM) │ (Pydantic + HF) │
├─────────────────────┼───────────────────┼───────────────────────┤
│ │ │ │
│ algs/grpo.py │ rollouts/ │ configs/load.py │
│ SGRPO loss │ vllm_engine.py │ configs/*.yaml │
│ CISPO loss │ replay_buffer.py│ │
│ algs/PPO/ppo.py │ │ datasets/ │
│ algs/SFT/sft.py │ │ prompt_only.py │
│ │ │ prompt_response.py │
│ │ │ mixed_ratio_sampler │
├─────────────────────┴───────────────────┴───────────────────────┤
│ utils/setup.py │ utils/logging.py │ rewards/compute_score │
└──────────────────┴────────────────────┴─────────────────────────┘
RL Training Workflow
┌──────────────┐ ┌───────────────────┐ ┌──────────────────┐
│ Load Config │────▶│ Initialize Ray │────▶│ Create Engines │
│ (YAML file) │ │ Cluster │ │ N train + M roll │
└──────────────┘ └───────────────────┘ └────────┬─────────┘
│
┌──────────────────────────────────┘
▼
┌─────────────────────┐
│ For each epoch: │
│ │
│ ┌───────────────┐ │ Rollout engines generate responses
│ │ 1. Rollouts │ │ using vLLM, compute rewards,
│ │ (vLLM) │ │ store in replay buffer
│ └───────┬───────┘ │
│ │ │
│ ┌───────▼───────┐ │ Training engines run forward/backward
│ │ 2. Train │ │ with DeepSpeed ZeRO-3 across GPUs
│ │ (DeepSpeed)│ │
│ └───────┬───────┘ │
│ │ │
│ ┌───────▼───────┐ │ Save model, update rollout engines
│ │ 3. Checkpoint │ │ with new policy weights
│ │ + Refresh │ │
│ └───────────────┘ │
└─────────────────────┘
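Below is a minimal single-process sketch of that loop that makes the ordering concrete. `Sample`, `ReplayBuffer`, `run_rl` and the `generate` / `reward_fn` / `train_step` callables are hypothetical stand-ins for the Ray-hosted vLLM and DeepSpeed engines. This is not oxRL's code. It only illustrates the sequential rollout → train → refresh cycle.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Sample:
    prompt: str
    response: str
    reward: float
    policy_version: int  # which weights generated this sample


@dataclass
class ReplayBuffer:
    samples: List[Sample] = field(default_factory=list)

    def add(self, sample: Sample) -> None:
        self.samples.append(sample)

    def drain(self) -> List[Sample]:
        # On-policy: every sample is used once, then discarded
        out, self.samples = self.samples, []
        return out


def run_rl(prompts: List[str], generate: Callable, reward_fn: Callable,
           train_step: Callable, num_epochs: int) -> int:
    """Hypothetical driver: strictly sequential, no rollout/train overlap."""
    policy_version = 0
    buffer = ReplayBuffer()
    for _ in range(num_epochs):
        # 1. Rollouts: generate with the current weights, score, store
        for p in prompts:
            resp = generate(p, policy_version)
            buffer.add(Sample(p, resp, reward_fn(p, resp), policy_version))
        # 2. Train: consume the whole buffer before anything else runs
        train_step(buffer.drain())
        # 3. Checkpoint + refresh: rollout engines pick up the new weights
        policy_version += 1
    return policy_version
```

Because each epoch drains the buffer before the version counter advances, every training batch contains only samples from the policy that is about to be updated.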
Quick Start
Installation
git clone https://github.com/warlockee/oxRL.git
cd oxRL
pip install -r req.txt
Dependencies: PyTorch, DeepSpeed, vLLM, Ray, Transformers, Pydantic.
Post-train a model in 3 steps (CLI)
Step 1. Prepare your data as a parquet or JSONL file with chat-format prompts:
{"prompt": [{"role": "user", "content": "What is 2+2?"}], "answer": "4"}
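The following stdlib sketch writes such a file and reads it back. The helper names and the two toy records are hypothetical. The only assumption about the format is the one shown above: a chat-format `prompt` list plus an `answer` field, one JSON object per line. The Step 2 config points at parquet files. If you prefer parquet, pandas' `DataFrame(records).to_parquet(path)` is one option, though it needs `pyarrow` or `fastparquet` installed.

```python
import json
from pathlib import Path

# Hypothetical toy records; replace with your own prompts and answers
records = [
    {"prompt": [{"role": "user", "content": "What is 2+2?"}], "answer": "4"},
    {"prompt": [{"role": "user", "content": "What is 3*5?"}], "answer": "15"},
]


def write_jsonl(rows, path):
    """Write one JSON object per line, creating parent directories."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")
    return path


def read_jsonl(path):
    """Read the file back, skipping blank lines."""
    with Path(path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
```

For example, `write_jsonl(records, "data/train.jsonl")` produces a file in the format shown above.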
Step 2. Write a minimal config (everything else uses sensible defaults):
# config.yaml
run:
  experiment_id: "my-run"
  training_gpus: 2
  rollout_gpus: 2
train:
  alg_name: "sgrpo"  # or "cispo"
  total_number_of_epochs: 10
  train_steps_per_epoch: 20
model:
  name: "google/gemma-3-1b-it"  # any HuggingFace model
data:
  train_dnames: ["my_data"]
  train_ratios: {"my_data": 1.0}
  train_files_path: "./data/train.parquet"
  val_files_path: "./data/val.parquet"
Step 3. Run:
python main_rl.py --config-file config.yaml
For supervised fine-tuning:
python main_sl.py --config-file configs/sl_args.yaml
See examples/quickstart.py for a complete runnable example.
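The `train_ratios` field in the Step 2 config controls how multiple datasets are mixed. With a single dataset, the ratio is simply 1.0. As a rough illustration of ratio-weighted mixing, the hypothetical helper below fills one batch according to the ratios. It is not oxRL's `datasets/mixed_ratio_sampler`, whose normalization and sampling strategy may differ.

```python
import random
from typing import Dict, List


def mixed_ratio_batch(datasets: Dict[str, List], ratios: Dict[str, float],
                      batch_size: int, seed: int = 0) -> List:
    """Draw one batch whose composition follows `ratios` (hypothetical helper)."""
    rng = random.Random(seed)
    names = sorted(ratios)
    total = sum(ratios[n] for n in names)
    # Integer quota per dataset; leftover slots go to the largest ratios
    quotas = {n: int(batch_size * ratios[n] / total) for n in names}
    leftover = batch_size - sum(quotas.values())
    for n in sorted(names, key=lambda n: ratios[n], reverse=True)[:leftover]:
        quotas[n] += 1
    batch = []
    for n in names:
        # Sampled with replacement so small datasets can still fill their quota
        batch.extend(rng.choices(datasets[n], k=quotas[n]))
    rng.shuffle(batch)
    return batch
```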
Custom Reward Functions
Write a function in rewards/compute_score.py and set reward.reward_func in your config:
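import torch

def my_reward(prompt_ids, response_ids, finish_reason):
    # One reward slot per response token; only the last token is scored
    r = torch.zeros(len(response_ids), dtype=torch.float32)
    # your scoring logic here (meets_criteria is a placeholder you define)
    r[-1] = 1.0 if meets_criteria(response_ids) else 0.0
    return r, False  # (reward_tensor, is_per_token)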
reward:
  reward_func: "my_reward"
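The reward function receives token IDs. In practice, you would therefore decode `response_ids` with your tokenizer before scoring. The signature above does not pass the reference answer, so how you look it up is up to you. For a GSM8K-style task, `meets_criteria` could reduce to a string check like the one below. `last_number` and `is_correct` are hypothetical helpers, not part of oxRL.

```python
import re
from typing import Optional

# Integers or decimals, optionally negative, with optional thousands separators
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def last_number(text: str) -> Optional[str]:
    """Return the last number in `text` with thousands separators removed."""
    matches = _NUMBER.findall(text)
    return matches[-1].replace(",", "") if matches else None


def is_correct(response_text: str, answer: str) -> bool:
    """Compare numerically so "4", "4.0" and "4.00" all match."""
    pred = last_number(response_text)
    if pred is None:
        return False
    try:
        return float(pred) == float(answer.replace(",", ""))
    except ValueError:
        return False
```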
Algorithms
| Algorithm | File | Description |
|---|---|---|
| SGRPO | algs/grpo.py | Stable GRPO: clipped surrogate loss with optional KL regularization from a reference model |
| CISPO | algs/grpo.py | Clipped importance-sampling policy optimization: weighted log-probability loss |
| PPO | algs/PPO/ppo.py | Proximal Policy Optimization with GAE, value clipping, and an entropy bonus |
| SFT | algs/SFT/sft.py | Supervised fine-tuning with masked cross-entropy loss |
SGRPO and CISPO share the same training infrastructure (DeepSpeed + Ray actors) and differ only in the policy loss computation. Select with train.alg_name in your config.
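To show what "differ only in the policy loss" means, here is a per-token sketch of the two loss shapes in plain Python, following the commonly published formulations. The first is a PPO/GRPO-style clipped surrogate. The second is a CISPO-style loss, where a clipped importance weight scales the log-probability. The function names, clip values, and group-normalized advantage are illustrative assumptions. oxRL's exact SGRPO details may differ, including its optional KL term, which is omitted here.

```python
import math
from typing import List


def group_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """GRPO-style advantage: normalize rewards within one prompt's group."""
    mean = sum(rewards) / len(rewards)
    std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / len(rewards))
    return [(r - mean) / (std + eps) for r in rewards]


def clipped_surrogate(logp_new: float, logp_old: float, adv: float,
                      clip: float = 0.2) -> float:
    """PPO/GRPO clipped surrogate for one token (a loss to minimize)."""
    ratio = math.exp(logp_new - logp_old)
    clipped = min(max(ratio, 1 - clip), 1 + clip)
    return -min(ratio * adv, clipped * adv)


def cispo(logp_new: float, logp_old: float, adv: float,
          clip_high: float = 0.2) -> float:
    """CISPO-style loss for one token (a loss to minimize).

    The importance weight is clipped from above. In an autodiff framework it
    would be detached (stop-gradient), so gradients flow only via logp_new.
    """
    weight = min(math.exp(logp_new - logp_old), 1 + clip_high)
    return -weight * adv * logp_new
```

The structural difference is where clipping acts. The surrogate zeroes the gradient of clipped tokens. CISPO keeps every token's log-probability gradient and only caps how much each one is weighted.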
Project Structure
oxRL/
├── main_rl.py RL training loop (Ray + DeepSpeed)
├── main_sl.py SL training loop (DeepSpeed)
├── algs/
│ ├── grpo.py SGRPO + CISPO (unified, loss_variant selects)
│ ├── PPO/ppo.py PPO with GAE + value function
│ └── SFT/sft.py Supervised fine-tuning
├── configs/
│ ├── load.py Pydantic config with sensible defaults
│ ├── rl_args.yaml Full RL config example
│ └── sl_args.yaml Full SL config example
├── datasets/
│ ├── prompt_only.py RL prompts (chat format → tokens)
│ ├── prompt_response.py SL prompt-response pairs
│ └── mixed_ratio_sampler Multi-dataset weighted sampling
├── rollouts/
│ ├── vllm_engine.py vLLM inference with hot model refresh
│ └── replay_buffer.py On-policy sample storage
├── rewards/
│ └── compute_score.py Pluggable reward functions
├── utils/
│ ├── setup.py Distributed setup (seeds, rank, tokenizer)
│ ├── logging.py Rank-aware logging + MLflow
│ └── utils.py Tensor helpers (dtype, padding)
├── preprocessing/
│ └── gsm8k.py GSM8K dataset preparation
└── examples/
└── quickstart.py End-to-end example (48 lines)
Configuration
oxRL uses Pydantic for type-safe configuration. Every field has a sensible default — you only need to specify what's unique to your run.
Required fields (no defaults):
| Section | Field | Description |
|---|---|---|
| run | experiment_id | Name for this run |
| run | training_gpus | Number of GPUs for training |
| run | rollout_gpus | Number of GPUs for rollout generation |
| train | alg_name | Algorithm: sgrpo, cispo, or sft |
| train | total_number_of_epochs | Training epochs |
| train | train_steps_per_epoch | Optimizer steps per epoch (RL) |
| model | name | HuggingFace model ID |
| data | train_dnames | Dataset name list |
| data | train_ratios | Dataset mixing ratios |
| data | train_files_path | Path to training data |
| data | val_files_path | Path to validation data |
Everything else (optimizer, scheduler, DeepSpeed ZeRO-3, vLLM rollouts, reward function) defaults to production-tested values. See configs/rl_args.yaml for the full reference.
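oxRL enforces these requirements with Pydantic. The following dependency-free sketch shows the same idea with a plain check over the required-field list. It assumes the YAML has already been parsed into nested dicts (for example with `yaml.safe_load`). `REQUIRED`, `ALGORITHMS` and `validate` are hypothetical names, not oxRL's `configs/load.py`.

```python
from typing import Any, Dict, List

# Required fields per section, mirroring the required-fields list
REQUIRED: Dict[str, List[str]] = {
    "run": ["experiment_id", "training_gpus", "rollout_gpus"],
    "train": ["alg_name", "total_number_of_epochs", "train_steps_per_epoch"],
    "model": ["name"],
    "data": ["train_dnames", "train_ratios", "train_files_path", "val_files_path"],
}
ALGORITHMS = {"sgrpo", "cispo", "sft"}


def validate(config: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means the config looks usable."""
    problems = []
    for section, fields in REQUIRED.items():
        values = config.get(section) or {}
        for name in fields:
            if name not in values:
                problems.append(f"missing {section}.{name}")
    alg = (config.get("train") or {}).get("alg_name")
    if alg is not None and alg not in ALGORITHMS:
        problems.append(f"unknown train.alg_name: {alg!r}")
    return problems
```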
Experiment Tracking
MLflow is supported but optional. oxRL works out of the box without it — training runs fine, you just won't get experiment tracking.
Without MLflow (default): training logs to console only. Nothing to configure.
With MLflow: install it and set the tracking URI in your config:
pip install mlflow
run:
  tracking_uri: "http://localhost:5000"  # or your MLflow server
Start a local MLflow UI:
mlflow ui --port 5000
oxRL automatically logs hyperparameters, per-step losses, KL divergence, clip fractions, rewards, and epoch-level aggregates to MLflow.
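The optional-dependency pattern behind this behavior looks roughly like the sketch below. It uses MLflow when MLflow is installed and a tracking URI is set, and prints to the console otherwise. `Tracker` is a hypothetical helper, not oxRL's `utils/logging.py`. The MLflow calls (`set_tracking_uri`, `start_run`, `log_params`, `log_metrics`, `end_run`) are the standard MLflow Python API.

```python
from typing import Dict, Optional

try:
    import mlflow  # optional dependency
except ImportError:
    mlflow = None


class Tracker:
    """Log to MLflow when available and configured, else to the console."""

    def __init__(self, experiment_id: str, tracking_uri: Optional[str] = None):
        self.enabled = mlflow is not None and tracking_uri is not None
        if self.enabled:
            mlflow.set_tracking_uri(tracking_uri)
            mlflow.start_run(run_name=experiment_id)

    def log_params(self, params: Dict) -> None:
        if self.enabled:
            mlflow.log_params(params)
        else:
            print(f"params: {params}")

    def log_metrics(self, metrics: Dict[str, float], step: int) -> None:
        if self.enabled:
            mlflow.log_metrics(metrics, step=step)
        else:
            print(f"step {step}: {metrics}")

    def close(self) -> None:
        if self.enabled:
            mlflow.end_run()
```

Without a `tracking_uri`, the tracker never touches MLflow, even if it is installed. This matches the console-only default described above.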
Key Design Decisions
Sequential rollout → training. oxRL does not pipeline rollout generation with training. This is deliberate. Pipelined overlap improves GPU utilization but makes debugging significantly harder. When training diverges at step 4,000, you want to know exactly what happened.
One class for SGRPO and CISPO. Both algorithms share 99% of their code — the only difference is 4 lines in the policy loss computation. A loss_variant parameter selects between them. No inheritance, no abstraction.
DeepSpeed ZeRO-3 by default. The config system auto-syncs optimizer, scheduler, dtype, and batch size settings to DeepSpeed — you configure once in the YAML and oxRL handles the rest.
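One reason to derive the DeepSpeed config instead of writing it by hand is DeepSpeed's batch-size constraint: `train_batch_size` must equal `train_micro_batch_size_per_gpu × gradient_accumulation_steps × world_size`. The sketch below derives a ZeRO-3 config dict from a few training settings. The function and its argument names are hypothetical, not oxRL's actual sync logic. The output keys (`zero_optimization`, `bf16`, `optimizer`, `scheduler` with `WarmupLR`) are standard DeepSpeed config keys.

```python
from typing import Any, Dict


def build_deepspeed_config(micro_batch_per_gpu: int, grad_accum_steps: int,
                           world_size: int, lr: float, warmup_steps: int,
                           dtype: str = "bf16") -> Dict[str, Any]:
    """Derive a DeepSpeed ZeRO-3 config dict (hypothetical input names)."""
    if dtype not in ("bf16", "fp16", "fp32"):
        raise ValueError(f"unsupported dtype: {dtype}")
    return {
        # Computed, so the three batch numbers can never disagree
        "train_batch_size": micro_batch_per_gpu * grad_accum_steps * world_size,
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum_steps,
        "zero_optimization": {"stage": 3},
        "bf16": {"enabled": dtype == "bf16"},
        "fp16": {"enabled": dtype == "fp16"},
        "optimizer": {"type": "AdamW", "params": {"lr": lr}},
        "scheduler": {
            "type": "WarmupLR",
            "params": {"warmup_min_lr": 0.0, "warmup_max_lr": lr,
                       "warmup_num_steps": warmup_steps},
        },
    }
```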
Strict on-policy enforcement. Optional mode that validates rollouts were generated by the current policy version. Catches silent distribution shift bugs that waste GPU-days.
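A check of this kind can be as simple as tagging each rollout with the weight version the rollout engine had loaded, then refusing to train on mismatches. The sketch below assumes an integer version counter that increments on every weight refresh. `Rollout`, `StalePolicyError` and `check_on_policy` are hypothetical names, and oxRL's actual mechanism may differ.

```python
from dataclasses import dataclass
from typing import List


class StalePolicyError(RuntimeError):
    """Raised when a batch contains rollouts from an older policy."""


@dataclass
class Rollout:
    response_ids: List[int]
    policy_version: int  # weights version the rollout engine had loaded


def check_on_policy(batch: List[Rollout], current_version: int) -> None:
    """Fail fast instead of silently training on off-policy samples."""
    stale = [r.policy_version for r in batch if r.policy_version != current_version]
    if stale:
        raise StalePolicyError(
            f"{len(stale)}/{len(batch)} rollouts come from policy versions "
            f"{sorted(set(stale))}, expected {current_version}"
        )
```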
Contributing
Contributions are welcome. The bar: keep changes readable, testable, and debuggable. Follow the existing style. If your change adds complexity, it should be worth it.
FAQ
Check out the FAQ for common questions and answers.
Acknowledgments
Some components of this codebase are inspired by practices from open source projects. We try to cite sources wherever we directly reuse exact code. If we missed a citation, please let us know and we will credit the source.
File details
Details for the file oxrl-0.8.1.tar.gz.
File metadata
- Download URL: oxrl-0.8.1.tar.gz
- Upload date:
- Size: 18.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6154b1fb66094a89fd28f6c43e3b31ba7e91a71aa977464fb9e305d30021f503 |
| MD5 | 170417c5f7ff01772ce66d5ea8b21c8c |
| BLAKE2b-256 | 3fd74e4dbbcf2011309b78fe485de2f4f4675cc285ec7fda3f1840ec64c9228a |
File details
Details for the file oxrl-0.8.1-py3-none-any.whl.
File metadata
- Download URL: oxrl-0.8.1-py3-none-any.whl
- Upload date:
- Size: 15.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 58094607c7934f29498a8f550bb6e0f917f70800d54bf9457368c6debe1e0201 |
| MD5 | 4844a7bcb9730210079ff0f22fb8d71d |
| BLAKE2b-256 | ba8f9a0d4fbcc0727fc40e22387c02c5487dd74a825e763fd90053fce108e5cc |