RL for Vision Language Models
vlmrl
A reinforcement learning framework for vision-language models, written in JAX.
Core components:
- models/qwen25vl: Qwen2.5-VL with mRoPE, KV cache, and grouped-query attention
- core/sampling.py: Inference
- core/grpo.py: Training (GRPO)
- core/eval.py: Evaluation
- envs/base.py: Vision environments for captioning, multimodal reasoning, etc.
Quickstart
Install
uv sync
Convert HF → JAX (defaults to Qwen/Qwen2.5-VL-7B-Instruct)
python -m utils.hf_to_jax --model_dir checkpoints/qwen25vl_7b
Sample
python -m core.sampling \
--ckpt_dir checkpoints/qwen25vl_7b \
--image imgs/f35_takeoff.png \
--prompt "Describe the image"
Train (GRPO)
python core/grpo.py \
--model_dir=checkpoints/qwen25vl_7b \
--env_name=vision \
--groups_per_batch=8 \
--group_size=1 \
--lr=5e-7 \
--total_steps=10000 \
--wandb_project=vlm-rl
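For readers new to GRPO, the sketch below illustrates how the `--groups_per_batch` and `--group_size` flags are commonly interpreted: each prompt is sampled `group_size` times, and every sample's reward is normalized against the other samples in its group. This is a generic illustration in plain Python, not the code in core/grpo.py; the function name `group_relative_advantages` and the `eps` constant are assumptions made for the example.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, group_size, eps=1e-6):
    """Normalize rewards within consecutive groups of `group_size` samples.

    `rewards` is a flat list laid out group after group, so its length is
    groups_per_batch * group_size. Hypothetical helper, for illustration only.
    """
    if group_size <= 0 or len(rewards) % group_size != 0:
        raise ValueError("len(rewards) must be a positive multiple of group_size")

    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu = mean(group)        # baseline: mean reward of the group
        sigma = pstdev(group)   # population std; 0.0 when all rewards are equal
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


if __name__ == "__main__":
    # Two groups of two samples each.
    print(group_relative_advantages([1.0, 0.0, 0.0, 1.0], group_size=2))
```

Under this particular normalization, a group of one sample has a baseline equal to its own reward, so its advantage is zero. Whether core/grpo.py treats `--group_size=1` the same way is not stated here.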
Eval
python core/eval.py \
--model_dir checkpoints/qwen25vl_7b \
--env_name=vision \
--num_generation_tokens=128 \
--inference_batch_per_device=1 \
--vlm_max_pixels=1048576 \
--top_k=5
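To get a feel for `--vlm_max_pixels=1048576`, the sketch below works out the visual token budget it implies. It relies on one fact about Qwen2.5-VL (14x14 pixel patches merged 2x2, so one visual token covers 28x28 pixels) and one assumption: that the flag caps the resized image area in the same way as the `max_pixels` setting of the Hugging Face processor. The helper names are hypothetical, and the resize logic is a simplified downscale-only version.

```python
import math

PATCH_SIZE = 14                    # Qwen2.5-VL vision patch edge, in pixels
MERGE_SIZE = 2                     # 2x2 patches are merged into one token
FACTOR = PATCH_SIZE * MERGE_SIZE   # 28 pixels per token edge


def visual_token_budget(max_pixels):
    """Upper bound on visual tokens for a given pixel cap."""
    return max_pixels // (FACTOR * FACTOR)


def fit_to_pixel_budget(height, width, max_pixels):
    """Round to multiples of 28 and shrink (keeping aspect ratio) if too large."""
    h = max(FACTOR, round(height / FACTOR) * FACTOR)
    w = max(FACTOR, round(width / FACTOR) * FACTOR)
    if h * w > max_pixels:
        scale = math.sqrt((height * width) / max_pixels)
        h = max(FACTOR, math.floor(height / scale / FACTOR) * FACTOR)
        w = max(FACTOR, math.floor(width / scale / FACTOR) * FACTOR)
    return h, w


def visual_tokens(height, width):
    """Tokens produced by an image whose sides are multiples of 28."""
    return (height // FACTOR) * (width // FACTOR)


if __name__ == "__main__":
    print(visual_token_budget(1048576))            # 1048576 = 1024 * 1024 pixels
    h, w = fit_to_pixel_budget(2048, 2048, 1048576)
    print(h, w, visual_tokens(h, w))
```

Under these assumptions, the cap corresponds to at most 1337 visual tokens per image, and a 2048x2048 input would be resized to 1008x1008 (1296 tokens).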
Environments
Extend envs.base.BaseEnv to add custom vision environments.
Built-in:
- vision/vision_caption: Single-image captioning; reward = keyword hits
- nlvr2: Two-image True/False reasoning
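The rewards behind the built-in environments can be pictured with the sketch below. It is not the repository's implementation, and the `envs.base.BaseEnv` interface is not shown here; `keyword_hit_reward` and `true_false_reward` are hypothetical stand-alone functions that score a model response the way the descriptions above suggest (fraction of expected keywords found, and exact match on a True/False answer).

```python
import re


def keyword_hit_reward(caption: str, keywords: list[str]) -> float:
    """Fraction of expected keywords that appear as whole words in the caption."""
    if not keywords:
        return 0.0
    tokens = set(re.findall(r"[a-z0-9]+", caption.lower()))
    hits = sum(1 for kw in keywords if kw.lower() in tokens)
    return hits / len(keywords)


def true_false_reward(response: str, label: bool) -> float:
    """1.0 if the first 'true'/'false' in the response matches the label."""
    match = re.search(r"\b(true|false)\b", response.lower())
    if match is None:
        return 0.0  # no parsable answer
    return 1.0 if (match.group(1) == "true") == label else 0.0


if __name__ == "__main__":
    print(keyword_hit_reward("A jet taking off from a runway", ["jet", "runway", "ocean"]))
    print(true_false_reward("The answer is True.", True))
```

A custom environment would typically wrap a scoring function like these together with its dataset and prompt formatting; how that maps onto `BaseEnv` methods depends on the actual base class.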
Requirements
- Python 3.10+
- Linux, CUDA 12, NVIDIA GPU (~60 GB VRAM for 7B)
- JAX 0.6.1 (CUDA 12 build)
References
- lmpo — kvfrans/lmpo
- Qwen model base — jax-ml/jax-llm-examples
- NLVR2 dataset — HuggingFaceM4/NLVR2
License
See LICENSE and NOTICE.
Download files
Source Distribution
vlmrl-0.0.1.tar.gz (47.9 kB)
Built Distribution
vlmrl-0.0.1-py3-none-any.whl (52.9 kB)
File details
Details for the file vlmrl-0.0.1.tar.gz.
File metadata
- Download URL: vlmrl-0.0.1.tar.gz
- Upload date:
- Size: 47.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | befa7f4331e04f426429176bb4c4a61886f0132cea72819a1efddade01f3aad8 |
| MD5 | 25748d3f3a2668209e4a12382d45712f |
| BLAKE2b-256 | 74723e94b077b741ac2ea3bb9c59576cae2a7904102c4cbe84690a42ea9b1b05 |
File details
Details for the file vlmrl-0.0.1-py3-none-any.whl.
File metadata
- Download URL: vlmrl-0.0.1-py3-none-any.whl
- Upload date:
- Size: 52.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6ba23cef505b2f0f677aa8217e581ea5ccbb37e7c647bc9600b9b73fbd18e9d4 |
| MD5 | cb54fefe9247e303b8884b934af63b72 |
| BLAKE2b-256 | 04188a6afcfaa43da51ce8da65ecadc06aea2d31d01ea3ae226975c14d2ec1d3 |
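The digests listed for each file can be used to check a download before installing it. The sketch below is one way to do that with Python's standard library; the function names and the local file path are assumptions for the example.

```python
import hashlib
import hmac


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path, expected_hex):
    """True if the file's SHA256 digest matches the published value."""
    actual = sha256_of_file(path)
    return hmac.compare_digest(actual, expected_hex.strip().lower())
```

For example, with the source distribution saved locally (path assumed), `verify_sha256("vlmrl-0.0.1.tar.gz", "befa7f4331e04f426429176bb4c4a61886f0132cea72819a1efddade01f3aad8")` should return `True` for an intact download.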