Reinforcement Pre-Training for Language Models - Implementation of arXiv:2506.08007
Reinforcement Pre-Training (RPT)
A Python package implementing Reinforcement Pre-Training techniques for language models, based on the paper "Reinforcement Pre-Training (RPT)".
Overview
This package implements RPT, an approach that uses reinforcement learning to train language models by treating next-token prediction as a reasoning task with a verifiable reward: the ground-truth next token in the training corpus determines whether a prediction is correct. RPT reframes standard language modeling to incentivize reasoning before each next-token prediction. The paper reports that this improves next-token prediction accuracy and provides a stronger foundation for subsequent reinforcement fine-tuning.
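The reward is "verifiable" because the correct answer, the actual next token, is already present in the training text, so no learned reward model is needed. As a simplified illustration, the sketch below scores predictions by exact token-ID match. The helper name `next_token_rewards` is hypothetical and not part of this package's API, and the comparison may differ from the paper's exact matching rule.

```python
from typing import Sequence


def next_token_rewards(predicted: Sequence[int], targets: Sequence[int]) -> list[float]:
    """Binary verifiable reward (hypothetical helper): 1.0 when the predicted
    token ID equals the ground-truth next token from the corpus, else 0.0."""
    if len(predicted) != len(targets):
        raise ValueError("predicted and targets must have the same length")
    return [1.0 if p == t else 0.0 for p, t in zip(predicted, targets)]


# Three next-token guesses checked against the corpus; two are correct.
rewards = next_token_rewards([464, 3290, 318], [464, 3797, 318])
print(rewards)  # [1.0, 0.0, 1.0]
```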
Key Features
- 🧠 Next-Token Reasoning: Treats each next-token prediction as a reasoning task with a verifiable reward
- 🚀 Scalable Training: Built-in support for distributed training and memory optimization
- 📊 Comprehensive Metrics: Detailed evaluation tools for reasoning quality assessment
- 🔧 Easy Integration: Simple API that works with existing Hugging Face models
- ⚡ Performance Optimized: Mixed precision training, gradient checkpointing, and efficient batching
- 📈 Visualization Tools: Built-in plotting and analysis capabilities
Installation
From PyPI
```bash
pip install reinforcement-pretraining
```
From Source
```bash
git clone https://github.com/your-username/reinforcement-pretraining.git
cd reinforcement-pretraining
pip install -e .
```
With Optional Dependencies
```bash
# For visualization (quotes keep shells such as zsh from expanding the brackets)
pip install "reinforcement-pretraining[viz]"
# For distributed training
pip install "reinforcement-pretraining[distributed]"
# For development
pip install "reinforcement-pretraining[dev]"
# Everything
pip install "reinforcement-pretraining[all]"
```
Quick Start
Basic Usage
```python
import torch
from transformers import AutoTokenizer, AutoModel
from rpt import RPTTrainer, RPTModel, RewardSystem, DataProcessor

# Load a pre-trained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 ships without a padding token; reuse EOS so batches can be padded
tokenizer.pad_token = tokenizer.eos_token
base_model = AutoModel.from_pretrained("gpt2")

# Create RPT model with value head
rpt_model = RPTModel(
    base_model=base_model,
    tokenizer=tokenizer,
    add_value_head=True
)

# Set up reward system
reward_system = RewardSystem(
    reward_type="hybrid",  # Combines accuracy and confidence
    reward_scale=1.0
)

# Prepare your data; create_dataset returns a (train, validation) pair
data_processor = DataProcessor(tokenizer=tokenizer)
texts = ["Your training texts here...", "Another example..."]
train_dataset, val_dataset = data_processor.create_dataset(texts, split_ratio=0.9)

# Create data loaders
train_loader = data_processor.create_dataloader(
    train_dataset,
    batch_size=8,
    shuffle=True
)
val_loader = data_processor.create_dataloader(
    val_dataset,
    batch_size=8,
    shuffle=False
)

# Set up optimizer
optimizer = torch.optim.AdamW(rpt_model.parameters(), lr=5e-5)

# Create trainer
trainer = RPTTrainer(
    model=rpt_model,
    reward_system=reward_system,
    optimizer=optimizer,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    max_epochs=3,
    output_dir="./rpt_output"
)

# Start training
results = trainer.train()
```
Advanced Usage with Scaling
```python
import torch
from rpt import ScalingUtils, ScalingConfig, RPTMetrics

# Configure scaling for large models
scaling_config = ScalingConfig(
    use_distributed=True,
    gradient_checkpointing=True,
    mixed_precision=True,
    auto_batch_size=True
)
scaling_utils = ScalingUtils(scaling_config)

# Set up distributed training
scaling_utils.setup_distributed_training()

# Optimize model for memory efficiency (rpt_model comes from the Basic Usage example)
rpt_model = scaling_utils.optimize_model_memory(rpt_model)

# Wrap for distributed training
rpt_model = scaling_utils.wrap_model_for_distributed(rpt_model)

# Find optimal batch size using a dummy batch of 512 random token IDs
optimal_batch_size = scaling_utils.get_optimal_batch_size(
    model=rpt_model,
    sample_input={"input_ids": torch.randint(0, 1000, (1, 512))},
    device=torch.device("cuda")
)
print(f"Optimal batch size: {optimal_batch_size}")
```
Data Processing
```python
# Load data from various formats
texts = data_processor.load_text_data("path/to/data.jsonl", data_format="jsonl")

# Apply quality filtering and create dataset (returns a (train, validation) pair)
dataset = data_processor.create_dataset(
    texts=texts,
    split_ratio=0.9,
    shuffle=True,
    filter_quality=True
)

# Get statistics for the training split
stats = data_processor.get_data_statistics(dataset[0])
print(f"Dataset stats: {stats}")
```
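The `load_text_data(..., data_format="jsonl")` call above reads one JSON object per line. This README does not say which field holds the text, so the sketch below assumes a `"text"` key (a hypothetical field name) and skips blank or malformed lines rather than aborting. `read_jsonl_texts` is an illustrative helper, not the package's implementation.

```python
import io
import json
from typing import Iterable


def read_jsonl_texts(lines: Iterable[str], field: str = "text") -> list[str]:
    """Collect the string stored under `field` in each JSON line (hypothetical helper)."""
    texts = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines between records
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            continue  # skip malformed lines instead of failing the whole load
        value = record.get(field) if isinstance(record, dict) else None
        if isinstance(value, str) and value.strip():
            texts.append(value)
    return texts


# In practice `lines` would be an open file: open("path/to/data.jsonl", encoding="utf-8")
sample = io.StringIO('{"text": "First document."}\n\n{"text": "Second."}\nnot json\n{"id": 3}\n')
result = read_jsonl_texts(sample)
print(result)  # ['First document.', 'Second.']
```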
Metrics and Evaluation
```python
# Initialize metrics tracker
metrics = RPTMetrics(track_detailed_stats=True)

# During training, update metrics. model_outputs, target_tokens, computed_rewards,
# and attention_mask stand for the current batch's values inside your training loop.
step_metrics = metrics.update_metrics(
    step=100,
    predictions=model_outputs["logits"],
    targets=target_tokens,
    rewards=computed_rewards,
    attention_mask=attention_mask
)

# Get summary statistics
summary = metrics.get_summary_stats(window_size=100)

# Plot training progress
metrics.plot_metrics(
    metrics_to_plot=["token_accuracy", "avg_reward", "perplexity"],
    save_path="training_metrics.png"
)

# Export metrics for analysis
metrics.export_metrics("training_results.json")
```
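Of the plotted metrics, `perplexity` is conventionally the exponential of the mean per-token negative log-likelihood over non-padding positions, and `token_accuracy` is the fraction of those positions where the predicted token matches the target. The sketch below computes both from plain Python lists, assuming the caller has already extracted the log-probability of each target token. Whether `RPTMetrics` masks and averages in exactly this way is an assumption.

```python
import math
from typing import Sequence


def masked_perplexity(target_logprobs: Sequence[float], mask: Sequence[int]) -> float:
    """exp(mean negative log-likelihood) over positions where mask == 1."""
    nll = [-lp for lp, m in zip(target_logprobs, mask) if m]
    if not nll:
        raise ValueError("mask selects no tokens")
    return math.exp(sum(nll) / len(nll))


def masked_token_accuracy(pred_ids: Sequence[int], target_ids: Sequence[int],
                          mask: Sequence[int]) -> float:
    """Fraction of unmasked positions where the predicted ID equals the target ID."""
    hits = [p == t for p, t, m in zip(pred_ids, target_ids, mask) if m]
    if not hits:
        raise ValueError("mask selects no tokens")
    return sum(hits) / len(hits)


# Two real tokens followed by one padding position (mask 0 is ignored).
logprobs = [math.log(0.5), math.log(0.25), math.log(0.01)]
mask = [1, 1, 0]
ppl = masked_perplexity(logprobs, mask)
acc = masked_token_accuracy([10, 20, 30], [10, 21, 30], mask)
print(round(ppl, 4), acc)  # 2.8284 0.5
```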
Examples
Check out the examples/ directory for complete training scripts:
- basic_pretraining.py: Simple RPT training setup
- advanced_scaling.py: Large-scale distributed training
- custom_rewards.py: Implementing custom reward functions
- evaluation_analysis.py: Comprehensive model evaluation
Core Components
RPTModel
Wraps existing language models and adds RL training capabilities:
- Value head for advantage estimation
- Reasoning-based generation
- Easy saving/loading
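A value head is typically a small layer mapping the model's hidden state to a scalar estimate V(s) of expected reward; the advantage used in the policy update is then the observed reward minus that baseline, A = r - V(s), often standardized per batch. The sketch below shows only the advantage step in plain Python, assuming one scalar reward and one value estimate per predicted token. It does not claim to mirror how `RPTModel` computes advantages internally, and `advantages` is a hypothetical helper.

```python
import statistics
from typing import Sequence


def advantages(rewards: Sequence[float], values: Sequence[float],
               normalize: bool = True, eps: float = 1e-8) -> list[float]:
    """One-step advantage A_t = r_t - V(s_t), optionally standardized to
    zero mean and unit variance across the batch."""
    if len(rewards) != len(values):
        raise ValueError("rewards and values must have the same length")
    adv = [r - v for r, v in zip(rewards, values)]
    if normalize and len(adv) > 1:
        mean = statistics.fmean(adv)
        std = statistics.pstdev(adv)
        adv = [(a - mean) / (std + eps) for a in adv]
    return adv


# Binary next-token rewards against a value head that predicts 0.5 everywhere.
raw = advantages([1.0, 0.0, 1.0, 0.0], [0.5, 0.5, 0.5, 0.5], normalize=False)
print(raw)  # [0.5, -0.5, 0.5, -0.5]
norm = advantages([1.0, 0.0, 1.0, 0.0], [0.5, 0.5, 0.5, 0.5])
```

Standardizing keeps the scale of the policy-gradient step roughly constant even as the average reward drifts during training.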
RewardSystem
Implements verifiable rewards for next-token prediction:
- Accuracy-based rewards
- Confidence-based rewards
- Adaptive reward scaling
- Custom reward functions
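This README does not document the exact formula behind `reward_type="hybrid"`. As one plausible reading of "combines accuracy and confidence", the sketch below mixes a binary correctness term with the probability the model assigned to the ground-truth token, then multiplies by a factor analogous to `reward_scale`. The weighting parameter `alpha` and the function name are hypothetical.

```python
def hybrid_reward(predicted_id: int, target_id: int, target_prob: float,
                  alpha: float = 0.5, reward_scale: float = 1.0) -> float:
    """Hypothetical hybrid reward: reward_scale * (alpha * accuracy + (1 - alpha) * confidence).

    accuracy   = 1.0 if the predicted token matches the ground truth, else 0.0
    confidence = probability the model assigned to the ground-truth token
    """
    if not 0.0 <= target_prob <= 1.0:
        raise ValueError("target_prob must be a probability")
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    accuracy = 1.0 if predicted_id == target_id else 0.0
    return reward_scale * (alpha * accuracy + (1.0 - alpha) * target_prob)


print(hybrid_reward(42, 42, target_prob=0.8))  # correct and confident
print(hybrid_reward(7, 42, target_prob=0.1))   # wrong, low confidence
```

The confidence term gives a graded signal even when the prediction is wrong, which a purely binary reward cannot.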
RPTTrainer
Main training loop with PPO-style optimization:
- Gradient accumulation
- Mixed precision training
- Comprehensive logging
- Automatic checkpointing
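"PPO-style" usually refers to the clipped surrogate objective from Proximal Policy Optimization: the probability ratio between the new and old policy is clipped to [1 - ε, 1 + ε], and the smaller of the clipped and unclipped terms is used. The sketch below computes that loss in plain Python from per-sample log-probabilities and advantages. Whether `RPTTrainer` uses exactly this form, and the `clip_eps` default of 0.2, are assumptions.

```python
import math
from typing import Sequence


def ppo_clipped_loss(new_logprobs: Sequence[float], old_logprobs: Sequence[float],
                     advantages: Sequence[float], clip_eps: float = 0.2) -> float:
    """Negative PPO clipped surrogate, averaged over samples (lower is better)."""
    if not (len(new_logprobs) == len(old_logprobs) == len(advantages)) or not advantages:
        raise ValueError("inputs must be non-empty and equal in length")
    terms = []
    for new_lp, old_lp, adv in zip(new_logprobs, old_logprobs, advantages):
        ratio = math.exp(new_lp - old_lp)              # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        terms.append(min(ratio * adv, clipped * adv))  # pessimistic (lower) bound
    return -sum(terms) / len(terms)


# Unchanged policy: every ratio is 1, so the loss is just -mean(advantage).
print(ppo_clipped_loss([-1.0, -2.0], [-1.0, -2.0], [1.0, 0.5]))  # -0.75
```

Taking the minimum means a large ratio cannot inflate the objective for positive advantages, which keeps each update close to the policy that generated the samples.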
ScalingUtils
Tools for scaling to large models and datasets:
- Distributed training setup
- Memory optimization
- Automatic batch sizing
- System monitoring
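Automatic batch sizing is commonly done by probing: try a batch size and back off if it runs out of memory. The sketch below uses a doubling phase followed by a binary search, with a caller-supplied `fits(batch_size) -> bool` predicate standing in for a real forward/backward pass (in PyTorch, such a predicate would typically catch `torch.cuda.OutOfMemoryError`). The function name and search strategy are assumptions, not a description of `get_optimal_batch_size`.

```python
from typing import Callable


def find_max_batch_size(fits: Callable[[int], bool], start: int = 1, limit: int = 4096) -> int:
    """Largest batch size in [start, limit] for which fits() is True,
    assuming fits is monotone (if size n fails, every size above n fails)."""
    if start < 1 or not fits(start):
        raise RuntimeError("even the starting batch size does not fit")
    low = high = start  # low is always a size known to fit
    # Doubling phase: grow until a size fails or the limit is reached.
    while high < limit:
        candidate = min(high * 2, limit)
        if not fits(candidate):
            high = candidate  # first size known to fail
            break
        low = high = candidate
    else:
        return low  # reached the limit without a failure
    # Binary search between the last success (low) and the first failure (high).
    while high - low > 1:
        mid = (low + high) // 2
        if fits(mid):
            low = mid
        else:
            high = mid
    return low


# Pretend each sample needs 0.5 GB and the device has a 24 GB budget.
print(find_max_batch_size(lambda b: b * 0.5 <= 24.0))  # 48
```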
DataProcessor
Utilities for preparing text data:
- Multiple format support
- Quality filtering
- Reasoning augmentation
- Efficient batching
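The criteria behind `filter_quality=True` are not spelled out in this README. The sketch below shows the kind of heuristic filter commonly applied to pre-training text: a minimum word count, a cap on the share of non-alphabetic characters, and exact-duplicate removal after whitespace normalization. All thresholds and the function name are illustrative assumptions.

```python
from typing import Iterable


def filter_texts(texts: Iterable[str], min_words: int = 5,
                 max_non_alpha_ratio: float = 0.3) -> list[str]:
    """Keep texts that are long enough, mostly alphabetic, and not exact duplicates."""
    kept, seen = [], set()
    for text in texts:
        normalized = " ".join(text.split())  # collapse runs of whitespace
        if len(normalized.split()) < min_words:
            continue  # too short to be useful
        chars = [c for c in normalized if not c.isspace()]
        non_alpha = sum(1 for c in chars if not c.isalpha())
        if not chars or non_alpha / len(chars) > max_non_alpha_ratio:
            continue  # likely markup, tables, or noise
        if normalized in seen:
            continue  # exact duplicate after normalization
        seen.add(normalized)
        kept.append(normalized)
    return kept


docs = [
    "The quick brown fox jumps over the lazy dog.",
    "The   quick brown fox jumps over the lazy dog.",  # duplicate after normalization
    "Too short.",
    "| 1 | 2 | 3 | 4 | 5 | 6 |",                     # mostly symbols and digits
]
print(filter_texts(docs))  # ['The quick brown fox jumps over the lazy dog.']
```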
Citation
If you use this package in your research, please cite the original paper:
```bibtex
@article{rpt2025,
  title={Reinforcement Pre-Training (RPT)},
  author={[Authors]},
  journal={arXiv preprint arXiv:2506.08007},
  year={2025}
}
```
Contributing
We welcome contributions! Please see CONTRIBUTING.md for guidelines.
License
This project is licensed under the MIT License - see the LICENSE file for details.
Support
- 📖 Documentation: Read the Docs
- 🐛 Issues: GitHub Issues
- 💬 Discussions: GitHub Discussions
Acknowledgments
- Original RPT paper authors
- Hugging Face Transformers team
- PyTorch team
- The open-source AI community
Download files
File details
Details for the file reinforcement_pretraining-0.2.0.tar.gz.
File metadata
- Download URL: reinforcement_pretraining-0.2.0.tar.gz
- Upload date:
- Size: 57.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1ded7d4e4c8cfb5ff718edc86fd931894543d587036628d2922aa8009960e440 |
| MD5 | abce2ce9e67bb4492062b5b784dbfd33 |
| BLAKE2b-256 | 24cb5ef441f28ceb50d12ecd567d6d44125690d214105c81cb92ca19a30e06c6 |
File details
Details for the file reinforcement_pretraining-0.2.0-py3-none-any.whl.
File metadata
- Download URL: reinforcement_pretraining-0.2.0-py3-none-any.whl
- Upload date:
- Size: 50.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 15293a9d0f8d35745a6141de57fa7adf2138727d7036a934d3f2c5293a3ee44c |
| MD5 | 06f6c7c79165ffba4044751e83abdff3 |
| BLAKE2b-256 | 591fb08afbb93a7dd24ad7dd7808ff245de725287331008947bbbd592ab12ab4 |