# RWKVTune

RWKV Model Training Toolkit: a comprehensive library for pre-training, supervised fine-tuning, and GRPO training of RWKV language models.
## Features

- **Three Training Paradigms**
  - `PretrainTrainer`: continue pre-training from existing models
  - `SFTTrainer`: supervised fine-tuning for instruction following
  - `GRPOTrainer`: Group Relative Policy Optimization (GRPO) for RLHF
- **Efficient Training** (see the config sketch after this list)
  - Multi-GPU training with DeepSpeed ZeRO optimization
  - Gradient checkpointing for memory efficiency
  - Mixed precision training (bf16/fp16/fp32)
- **Parameter-Efficient Fine-Tuning**
  - LoRA support with customizable target modules
  - Easy adapter merging and saving
- **Advanced Capabilities**
  - Infinite context training support
  - HuggingFace Datasets integration
  - Checkpoint resume and elastic training
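Gradient checkpointing, mixed precision, and the DeepSpeed strategy are all driven by the training config. A minimal sketch using documented `SFTConfig` fields (see the tables under Configuration Options below):

```python
from rwkvtune.training import SFTConfig

# Memory-lean multi-GPU setup built from documented config fields:
# grad_cp=1 turns on gradient checkpointing, precision selects bf16,
# and strategy requests DeepSpeed ZeRO Stage 2.
config = SFTConfig(
    grad_cp=1,
    precision="bf16",
    devices=2,
    strategy="deepspeed_stage_2",
)
```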
## Installation

### From PyPI (Recommended)

```bash
pip install rwkvtune
```
### From Source

```bash
git clone https://github.com/rwkv-community/rwkvtune.git
cd rwkvtune
pip install -e .
```
### With DeepSpeed Support

```bash
pip install rwkvtune[deepspeed]
```
### Development Installation

```bash
pip install rwkvtune[dev]
```
## Quick Start

### Supervised Fine-Tuning (SFT)
```python
from rwkvtune import AutoModel, AutoTokenizer
from rwkvtune.training import SFTConfig, SFTTrainer
from datasets import Dataset

# Load model and tokenizer
model = AutoModel.from_pretrained("/path/to/model")
tokenizer = AutoTokenizer.from_pretrained("/path/to/model")

# Prepare dataset (must have 'input_ids' and 'labels')
def prepare_data(examples):
    # Your data preprocessing logic
    return {"input_ids": [...], "labels": [...]}

dataset = Dataset.from_list([...])
dataset = dataset.map(prepare_data)

# Configure training
config = SFTConfig(
    ctx_len=2048,
    micro_bsz=4,
    epoch_count=3,
    lr_init=1e-4,
    devices=1,
    precision="bf16",
)

# Create trainer and train
trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
```
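The `prepare_data` stub above is where prompt masking usually happens. A minimal sketch, assuming the tokenizer exposes an `encode` method and an `eos_token_id`, and that the trainer follows the common convention of ignoring label positions set to -100; the `instruction`/`response` field names are illustrative, not required by the library:

```python
def prepare_data(example):
    # Illustrative field names; adapt to your dataset schema.
    prompt_ids = tokenizer.encode(f"User: {example['instruction']}\n\nAssistant: ")
    response_ids = tokenizer.encode(example["response"]) + [tokenizer.eos_token_id]
    # Mask prompt tokens with -100 so loss is computed only on the response
    # (a common convention; verify it matches SFTTrainer's loss handling).
    return {
        "input_ids": prompt_ids + response_ids,
        "labels": [-100] * len(prompt_ids) + response_ids,
    }
```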
### SFT with LoRA
```python
from rwkvtune import AutoModel
from rwkvtune.peft import LoraConfig, get_peft_model
from rwkvtune.training import SFTConfig, SFTTrainer

# Load model
model = AutoModel.from_pretrained("/path/to/model")

# Apply LoRA
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.0,
)
model = get_peft_model(model, lora_config)

# Configure training
config = SFTConfig(
    ctx_len=2048,
    micro_bsz=4,
    epoch_count=3,
)

# Train (reuses `dataset` prepared as in the SFT example above)
trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=dataset,
)
trainer.train()
```
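After training, the adapter can be merged back into the base weights with the `rwkvtune-merge-lora` command shown under Command Line Tools below.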
### GRPO Training
```python
from rwkvtune.training import GRPOConfig, GRPOTrainer
from datasets import Dataset

# Define reward function
def reward_func(prompts, completions, **kwargs):
    rewards = []
    for completion in completions:
        # Your reward logic
        rewards.append(1.0 if "correct" in completion else 0.0)
    return rewards

# Prepare dataset (must have 'prompt' and 'input_ids')
dataset = Dataset.from_list([
    {"prompt": "What is 2+2?", "input_ids": [...]},
    ...
])

# Configure GRPO
config = GRPOConfig(
    ctx_len=2048,
    micro_bsz=2,
    num_generations=4,
    epoch_count=1,
)

# Create trainer
trainer = GRPOTrainer(
    model="/path/to/model",
    reward_funcs=reward_func,
    args=config,
    train_dataset=dataset,
)
trainer.train()
```
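Substring checks like the one above are brittle. A slightly more robust sketch extracts the last number in each completion and compares it to a reference answer; it assumes extra dataset columns (here a hypothetical `answer` column) are forwarded to the reward function as keyword arguments, which you should verify against `GRPOTrainer`'s behavior:

```python
import re

def numeric_reward(prompts, completions, answer=None, **kwargs):
    # `answer` is a hypothetical per-example column; adapt to your schema.
    rewards = []
    for completion, ref in zip(completions, answer):
        numbers = re.findall(r"-?\d+(?:\.\d+)?", completion)
        # Reward 1.0 only if the last number in the completion matches.
        rewards.append(1.0 if numbers and numbers[-1] == str(ref) else 0.0)
    return rewards
```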
### Continue Pre-training
```python
from rwkvtune.training import PretrainConfig, PretrainTrainer
from datasets import Dataset

# Prepare dataset (must have 'input_ids' and 'labels')
dataset = Dataset.from_list([
    {"input_ids": [...], "labels": [...]},
    ...
])

# Configure pre-training
config = PretrainConfig(
    ctx_len=4096,
    micro_bsz=8,
    epoch_count=1,
    lr_init=3e-4,
)

# Create trainer
trainer = PretrainTrainer(
    model="/path/to/model",
    args=config,
    train_dataset=dataset,
)
trainer.train()
```
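For pre-training, rows are typically produced by slicing one long token stream into fixed-length windows. A minimal sketch, assuming `token_stream` is a flat list of token ids and that the trainer expects labels pre-shifted by one position (verify this against `PretrainTrainer`; some trainers shift internally):

```python
from datasets import Dataset

ctx_len = 4096
rows = []
# Slide over the stream in ctx_len-sized steps; each window of
# ctx_len + 1 tokens yields one (input, shifted-label) pair.
for i in range(0, len(token_stream) - ctx_len, ctx_len):
    window = token_stream[i : i + ctx_len + 1]
    rows.append({"input_ids": window[:-1], "labels": window[1:]})

dataset = Dataset.from_list(rows)
```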
## Command Line Tools

### Merge LoRA Weights
```bash
rwkvtune-merge-lora \
    --base-model /path/to/base \
    --lora-model /path/to/lora \
    --output /path/to/merged
```
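The merged output is a full standalone checkpoint; it should load with `AutoModel.from_pretrained` like any base model (worth verifying against your rwkvtune version).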
## Multi-GPU Training
RWKVTune supports multi-GPU training with DeepSpeed:
```python
config = SFTConfig(
    devices=4,                     # Number of GPUs
    strategy="deepspeed_stage_2",  # DeepSpeed ZeRO Stage 2
    precision="bf16",
)
```
Then launch the training script with a distributed launcher:
```bash
# Using torchrun
torchrun --nproc_per_node=4 train.py

# Using the DeepSpeed launcher
deepspeed --num_gpus=4 train.py
```
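Tying the two together, a minimal `train.py` for either launcher might look like this (a sketch; it assumes `dataset` has been prepared as in the Quick Start examples):

```python
# train.py -- minimal multi-GPU SFT entry point (sketch).
from rwkvtune import AutoModel, AutoTokenizer
from rwkvtune.training import SFTConfig, SFTTrainer

model = AutoModel.from_pretrained("/path/to/model")
tokenizer = AutoTokenizer.from_pretrained("/path/to/model")

config = SFTConfig(
    devices=4,                     # must match --nproc_per_node / --num_gpus
    strategy="deepspeed_stage_2",
    precision="bf16",
    micro_bsz=4,                   # per-GPU batch size
)

trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=dataset,         # prepared as in the Quick Start
    processing_class=tokenizer,
)
trainer.train()
```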
## Configuration Options

### SFTConfig / PretrainConfig
| Parameter | Type | Default | Description |
|---|---|---|---|
| `ctx_len` | int | 1024 | Context length |
| `micro_bsz` | int | 4 | Batch size per GPU |
| `epoch_count` | int | 10 | Number of epochs |
| `lr_init` | float | 3e-4 | Initial learning rate |
| `lr_final` | float | 1e-5 | Final learning rate |
| `warmup_steps` | int | 50 | Warmup steps |
| `grad_cp` | int | 0 | Gradient checkpointing (0 = off, 1 = on) |
| `devices` | int | 1 | Number of GPUs |
| `precision` | str | `"bf16"` | Training precision (bf16/fp16/fp32) |
| `strategy` | str | `"auto"` | Training strategy |
### GRPOConfig
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_generations` | int | 4 | Completions per prompt |
| `beta` | float | 0.04 | KL penalty coefficient |
| `temperature` | float | 1.0 | Sampling temperature |
| `max_new_tokens` | int | 256 | Max tokens to generate |
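For orientation on how `num_generations` and `beta` interact: GRPO samples `num_generations` completions per prompt, normalizes their rewards within that group to obtain advantages, and regularizes the policy toward a reference model with a KL penalty weighted by `beta`. A sketch of the standard group normalization (not necessarily the trainer's exact internal implementation):

```python
import statistics

def group_advantages(rewards):
    # rewards: the num_generations scalar rewards for one prompt's group.
    mu = statistics.mean(rewards)
    sigma = statistics.pstdev(rewards)
    # Guard against zero variance when all completions score the same.
    return [(r - mu) / (sigma + 1e-8) for r in rewards]
```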
### LoraConfig
| Parameter | Type | Default | Description |
|---|---|---|---|
| `r` | int | 64 | LoRA rank |
| `lora_alpha` | int | 128 | LoRA alpha (scaling numerator) |
| `lora_dropout` | float | 0.0 | LoRA dropout |
| `target_modules` | list | auto | Modules to apply LoRA to |
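Under the standard LoRA convention, the low-rank update is scaled as `delta_W = (lora_alpha / r) * B @ A`, so the defaults above (`r=64`, `lora_alpha=128`) scale the adapter by a factor of 2.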
## Model Support
Currently supported models:
- RWKV-7 (all sizes: 0.1B, 0.4B, 1.5B, 2.9B, 7.2B, 13.3B)
## Requirements
- Python >= 3.8
- PyTorch >= 2.0.0
- Lightning >= 2.0.0
- CUDA (recommended for training)
## License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
## Citation
If you use RWKVTune in your research, please cite:
```bibtex
@software{rwkvtune,
  title = {RWKVTune: RWKV Model Training Toolkit},
  year = {2024},
  url = {https://github.com/rwkv-community/rwkvtune}
}
```
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
## Acknowledgments