Reinforcement Pre-Training for Language Models - Implementation of arXiv:2506.08007

Reinforcement Pre-Training (RPT)

A Python package implementing Reinforcement Pre-Training techniques for language models, based on the paper "Reinforcement Pre-Training (RPT)".

Overview

This package implements RPT, an approach that uses reinforcement learning to train language models by treating next-token prediction as a reasoning task with a verifiable reward: the model is rewarded when the token it predicts matches the token that actually follows in the corpus. According to the paper, reframing language modeling this way incentivizes next-token reasoning and improves next-token prediction accuracy.
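
To make the idea of a verifiable reward concrete, the sketch below scores several sampled predictions for one context against the token that actually follows in the text. The function name `next_token_rewards` and the binary 0/1 scoring are illustrative assumptions, not part of this package's API.

```python
from typing import List


def next_token_rewards(sampled_predictions: List[str], ground_truth: str) -> List[float]:
    """Return 1.0 for each sampled prediction equal to the true next token, else 0.0."""
    return [1.0 if pred == ground_truth else 0.0 for pred in sampled_predictions]


# Context: "The capital of France is", true continuation: " Paris".
samples = [" Paris", " Lyon", " Paris", " the"]
rewards = next_token_rewards(samples, " Paris")
mean_reward = sum(rewards) / len(rewards)
print(rewards, mean_reward)
```

Because the reward is checked directly against the corpus, no learned reward model or human label is needed for this signal.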

Key Features

  • 🧠 Next-Token Reasoning: Treats token prediction as a reasoning task with verifiable rewards
  • 🚀 Scalable Training: Built-in support for distributed training and memory optimization
  • 📊 Comprehensive Metrics: Detailed evaluation tools for reasoning quality assessment
  • 🔧 Easy Integration: Simple API that works with existing Hugging Face models
  • Performance Optimized: Mixed precision training, gradient checkpointing, and efficient batching
  • 📈 Visualization Tools: Built-in plotting and analysis capabilities

Installation

From PyPI

pip install reinforcement-pretraining

From Source

git clone https://github.com/your-username/reinforcement-pretraining.git
cd reinforcement-pretraining
pip install -e .

With Optional Dependencies

# Quote the extras so shells such as zsh do not expand the brackets

# For visualization
pip install "reinforcement-pretraining[viz]"

# For distributed training
pip install "reinforcement-pretraining[distributed]"

# For development
pip install "reinforcement-pretraining[dev]"

# Everything
pip install "reinforcement-pretraining[all]"

Quick Start

Basic Usage

import torch
from transformers import AutoTokenizer, AutoModel
from rpt import RPTTrainer, RPTModel, RewardSystem, DataProcessor

# Load a pre-trained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
base_model = AutoModel.from_pretrained("gpt2")

# Create RPT model with value head
rpt_model = RPTModel(
    base_model=base_model,
    tokenizer=tokenizer,
    add_value_head=True
)

# Setup reward system
reward_system = RewardSystem(
    reward_type="hybrid",  # Combines accuracy and confidence
    reward_scale=1.0
)

# Prepare your data
data_processor = DataProcessor(tokenizer=tokenizer)
texts = ["Your training texts here...", "Another example..."]
dataset = data_processor.create_dataset(texts, split_ratio=0.9)
train_dataset, val_dataset = dataset

# Create data loaders
train_loader = data_processor.create_dataloader(
    train_dataset, 
    batch_size=8, 
    shuffle=True
)
val_loader = data_processor.create_dataloader(
    val_dataset, 
    batch_size=8, 
    shuffle=False
)

# Setup optimizer
optimizer = torch.optim.AdamW(rpt_model.parameters(), lr=5e-5)

# Create trainer
trainer = RPTTrainer(
    model=rpt_model,
    reward_system=reward_system,
    optimizer=optimizer,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    max_epochs=3,
    output_dir="./rpt_output"
)

# Start training
results = trainer.train()

Advanced Usage with Scaling

import torch
from rpt import ScalingUtils, ScalingConfig, RPTMetrics

# Configure scaling for large models
scaling_config = ScalingConfig(
    use_distributed=True,
    gradient_checkpointing=True,
    mixed_precision=True,
    auto_batch_size=True
)

scaling_utils = ScalingUtils(scaling_config)

# Setup distributed training
scaling_utils.setup_distributed_training()

# Optimize model for memory efficiency
rpt_model = scaling_utils.optimize_model_memory(rpt_model)

# Wrap for distributed training
rpt_model = scaling_utils.wrap_model_for_distributed(rpt_model)

# Find optimal batch size
optimal_batch_size = scaling_utils.get_optimal_batch_size(
    model=rpt_model,
    sample_input={"input_ids": torch.randint(0, 1000, (1, 512))},
    device=torch.device("cuda")
)

print(f"Optimal batch size: {optimal_batch_size}")

Data Processing

# Load data from various formats
texts = data_processor.load_text_data("path/to/data.jsonl", data_format="jsonl")

# Apply quality filtering and create dataset
dataset = data_processor.create_dataset(
    texts=texts,
    split_ratio=0.9,
    shuffle=True,
    filter_quality=True
)

# Get dataset statistics
stats = data_processor.get_data_statistics(dataset[0])
print(f"Dataset stats: {stats}")

Metrics and Evaluation

# Initialize metrics tracker
metrics = RPTMetrics(track_detailed_stats=True)

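# During training, update metrics. model_outputs, target_tokens,
# computed_rewards, and attention_mask come from your own training step.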
step_metrics = metrics.update_metrics(
    step=100,
    predictions=model_outputs["logits"],
    targets=target_tokens,
    rewards=computed_rewards,
    attention_mask=attention_mask
)

# Get summary statistics
summary = metrics.get_summary_stats(window_size=100)

# Plot training progress
metrics.plot_metrics(
    metrics_to_plot=["token_accuracy", "avg_reward", "perplexity"],
    save_path="training_metrics.png"
)

# Export metrics for analysis
metrics.export_metrics("training_results.json")
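
For readers unfamiliar with the metric names used above, the following sketch shows how token accuracy and perplexity are commonly defined. It is a plain-Python illustration with hypothetical helper names (`token_accuracy`, `perplexity`) and does not show how `RPTMetrics` computes them internally.

```python
import math
from typing import List


def token_accuracy(predicted_ids: List[int], target_ids: List[int]) -> float:
    """Fraction of positions where the predicted token id equals the target id."""
    if len(predicted_ids) != len(target_ids):
        raise ValueError("predictions and targets must have the same length")
    correct = sum(p == t for p, t in zip(predicted_ids, target_ids))
    return correct / len(target_ids)


def perplexity(target_logprobs: List[float]) -> float:
    """exp of the mean negative log-probability assigned to the target tokens."""
    mean_nll = -sum(target_logprobs) / len(target_logprobs)
    return math.exp(mean_nll)


acc = token_accuracy([1, 2, 3, 4], [1, 2, 0, 4])
ppl = perplexity([math.log(0.5)] * 4)
print(acc, ppl)
```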

Examples

Check out the examples/ directory for complete training scripts:

  • basic_pretraining.py: Simple RPT training setup
  • advanced_scaling.py: Large-scale distributed training
  • custom_rewards.py: Implementing custom reward functions
  • evaluation_analysis.py: Comprehensive model evaluation

Core Components

RPTModel

Wraps existing language models and adds RL training capabilities:

  • Value head for advantage estimation
  • Reasoning-based generation
  • Easy saving/loading
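
As an illustration of what a value head is used for, the sketch below computes advantages with generalized advantage estimation (GAE) from per-step rewards and value predictions. GAE is a common choice for this step; whether `RPTModel` and `RPTTrainer` use it is an assumption, and `gae_advantages` is a hypothetical helper.

```python
from typing import List


def gae_advantages(
    rewards: List[float],
    values: List[float],
    gamma: float = 0.99,
    lam: float = 0.95,
) -> List[float]:
    """Compute GAE advantages.

    `values` must hold one more entry than `rewards`: the last entry is the
    bootstrap value of the state after the final step (0.0 if terminal).
    """
    if len(values) != len(rewards) + 1:
        raise ValueError("values must have len(rewards) + 1 entries")
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step can reuse the discounted sum of later steps.
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


adv = gae_advantages([0.0, 0.0, 1.0], [0.5, 0.5, 0.5, 0.0], gamma=1.0, lam=1.0)
print(adv)
```

A positive advantage means the outcome was better than the value head expected, so the corresponding action is reinforced.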

RewardSystem

Implements verifiable rewards for next-token prediction:

  • Accuracy-based rewards
  • Confidence-based rewards
  • Adaptive reward scaling
  • Custom reward functions
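
One way to picture a reward that combines accuracy and confidence is a weighted blend of a 0/1 correctness term and the probability the model assigned to the true token. The sketch below is a minimal illustration; the function `hybrid_reward`, the `accuracy_weight` parameter, and the linear blend are assumptions and may differ from what `RewardSystem(reward_type="hybrid")` does.

```python
def hybrid_reward(
    correct: bool,
    prob_of_truth: float,
    accuracy_weight: float = 0.5,
    reward_scale: float = 1.0,
) -> float:
    """Blend an accuracy term (0 or 1) with a confidence term (probability of the true token)."""
    if not 0.0 <= prob_of_truth <= 1.0:
        raise ValueError("prob_of_truth must be a probability in [0, 1]")
    accuracy_term = 1.0 if correct else 0.0
    confidence_term = prob_of_truth
    blended = accuracy_weight * accuracy_term + (1.0 - accuracy_weight) * confidence_term
    return reward_scale * blended


confident_hit = hybrid_reward(correct=True, prob_of_truth=0.8)
near_miss = hybrid_reward(correct=False, prob_of_truth=0.2)
print(confident_hit, near_miss)
```

A blend like this gives partial credit when the true token was likely but not the top prediction, which makes the signal less sparse than accuracy alone.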

RPTTrainer

Main training loop with PPO-style optimization:

  • Gradient accumulation
  • Mixed precision training
  • Comprehensive logging
  • Automatic checkpointing
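
"PPO-style optimization" usually refers to the clipped surrogate objective, which limits how far one update can move the policy away from the one that generated the samples. The sketch below computes that loss in plain Python for a handful of sampled actions; `ppo_clipped_loss` is a hypothetical helper and the trainer's actual loss may include further terms (for example a value loss or an entropy bonus).

```python
import math
from typing import List


def ppo_clipped_loss(
    new_logprobs: List[float],
    old_logprobs: List[float],
    advantages: List[float],
    clip_eps: float = 0.2,
) -> float:
    """Negative mean of min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)."""
    terms = []
    for new_lp, old_lp, adv in zip(new_logprobs, old_logprobs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a) / pi_old(a)
        clipped = max(1.0 - clip_eps, min(1.0 + clip_eps, ratio))
        terms.append(min(ratio * adv, clipped * adv))
    return -sum(terms) / len(terms)


# The new policy doubles the probability of an action with positive advantage:
# the gain is capped at (1 + clip_eps) * advantage.
loss = ppo_clipped_loss([math.log(0.4)], [math.log(0.2)], [1.0])
print(loss)
```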

ScalingUtils

Tools for scaling to large models and datasets:

  • Distributed training setup
  • Memory optimization
  • Automatic batch sizing
  • System monitoring
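
Automatic batch sizing is often implemented as a search that keeps growing the batch until a trial step no longer fits in memory. The sketch below shows the doubling variant of that search. The `fits` callback is a stand-in: in practice it would run one forward and backward pass at the given batch size and return False on an out-of-memory error. This is an assumed strategy, not a description of `ScalingUtils.get_optimal_batch_size`.

```python
from typing import Callable


def find_max_batch_size(fits: Callable[[int], bool], start: int = 1, limit: int = 1024) -> int:
    """Double the batch size from `start` while `fits` succeeds and `limit` is not exceeded.

    Returns the largest size that fit, or 0 if even `start` did not fit.
    """
    best = 0
    size = start
    while size <= limit and fits(size):
        best = size
        size *= 2
    return best


# Stand-in for a real memory probe: pretend anything up to 48 samples fits.
best_size = find_max_batch_size(lambda batch_size: batch_size <= 48)
print(best_size)
```

Doubling finds a workable size in a logarithmic number of trials; a binary search between the last success and the first failure could refine it further.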

DataProcessor

Utilities for preparing text data:

  • Multiple format support
  • Quality filtering
  • Reasoning augmentation
  • Efficient batching
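
The steps listed above (reading a format such as JSONL, filtering low-quality texts, and splitting into train and validation sets) can be sketched with the standard library alone. The helper names, the `"text"` key, and the word-count filter are illustrative assumptions; the criteria used by `DataProcessor` are not documented here.

```python
import json
import random
from typing import Iterable, List, Tuple


def parse_jsonl(lines: Iterable[str], text_key: str = "text") -> List[str]:
    """Extract the text field from each non-empty JSON line."""
    texts = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        if text_key in record:
            texts.append(record[text_key])
    return texts


def passes_quality(text: str, min_words: int = 3) -> bool:
    """Very simple quality check: keep texts with at least `min_words` words."""
    return len(text.split()) >= min_words


def split_dataset(texts: List[str], split_ratio: float = 0.9, seed: int = 0) -> Tuple[List[str], List[str]]:
    """Shuffle deterministically, then split into (train, validation)."""
    shuffled = list(texts)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * split_ratio)
    return shuffled[:cut], shuffled[cut:]


raw = ['{"text": "a short but valid sentence"}', "", '{"text": "hi"}', '{"id": 3}']
kept = [t for t in parse_jsonl(raw) if passes_quality(t)]
train, val = split_dataset([f"example text number {i}" for i in range(10)], split_ratio=0.5)
print(kept, len(train), len(val))
```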

Citation

If you use this package in your research, please cite the original paper:

@article{rpt2025,
  title={Reinforcement Pre-Training},
  author={[Authors]},
  journal={arXiv preprint arXiv:2506.08007},
  year={2025}
}

Contributing

We welcome contributions! Please see CONTRIBUTING.md for guidelines.

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

  • Original RPT paper authors
  • Hugging Face Transformers team
  • PyTorch team
  • The open-source AI community
