Unified API for training and inference
🧸 SkyRL tx: Unifying LLM training and inference
GitHub • Tinker Docs • Tinker Cookbook • Slack
SkyRL tx is an open-source library that implements a backend for the Tinker API, allowing you to set up your own Tinker-like service running on your own hardware. It provides a unified interface for both training and inference, enabling seamless online learning, cost-effective multi-tenancy through LoRA, and simplified ML infrastructure.
✨ Key Features
- Unified Training & Inference – Single engine for forward passes, backward passes, and sampling
- Multi-User LoRA Support – Efficient GPU sharing across users with individual adapters
- SFT & RL Support – Supervised fine-tuning and reinforcement learning with PPO and custom loss functions
- Multi-Node Training – FSDP and tensor parallelism for distributed training
- Multiple Model Architectures – Support for Qwen3 (dense & MoE), Llama 3, and DeepSeek V3
- External Inference Engine – Optional vLLM integration for optimized inference
- Production Ready – PostgreSQL support, cloud storage checkpoints, and database migrations
🏗️ Architecture
SkyRL tx consists of four main components:
```
┌───────────────────────────────────────────────┐
│                REST API Server                │
│         (FastAPI - handles requests)          │
└───────────────────────────────────────────────┘
                        │
                        ▼
┌───────────────────────────────────────────────┐
│                   Database                    │
│   (SQLite/PostgreSQL - metadata, job queue)   │
└───────────────────────────────────────────────┘
                        │
                        ▼
┌───────────────────────────────────────────────┐
│                    Engine                     │
│ (Scheduling & batching across users/adapters) │
└───────────────────────────────────────────────┘
                        │
                        ▼
┌───────────────────────────────────────────────┐
│                    Worker                     │
│(Model execution, forward/backward, optimizer) │
└───────────────────────────────────────────────┘
```
🚀 Quick Start
Installation
```bash
git clone https://github.com/NovaSky-AI/SkyRL
cd SkyRL/skyrl-tx

# For GPU
uv run --extra gpu --extra tinker -m tx.tinker.api --base-model <model>

# For TPU
uv run --extra tpu --extra tinker -m tx.tinker.api --base-model <model>
```
Basic Training Example (Pig Latin)
Start the server:
```bash
uv run --extra gpu --extra tinker -m tx.tinker.api --base-model "Qwen/Qwen3-0.6B"
```
Run a simple training loop:
```python
import tinker
import numpy as np
from tinker import types

# Connect to the local server
service_client = tinker.ServiceClient(base_url="http://localhost:8000", api_key="tml-dummy")
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-0.6B")
tokenizer = training_client.get_tokenizer()

# Training examples
examples = [
    {"input": "banana split", "output": "anana-bay plit-say"},
    {"input": "quantum physics", "output": "uantum-qay ysics-phay"},
    {"input": "coding wizard", "output": "oding-cay izard-way"},
]

def process_example(example, tokenizer):
    prompt = f"English: {example['input']}\nPig Latin:"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
    tokens = prompt_tokens + completion_tokens
    weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=tokens[:-1]),
        loss_fn_inputs=dict(weights=weights[1:], target_tokens=tokens[1:]),
    )

processed = [process_example(ex, tokenizer) for ex in examples]

# Training loop
for _ in range(6):
    fwdbwd = training_client.forward_backward(processed, "cross_entropy").result()
    training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()
    logprobs = np.concatenate([o['logprobs'].tolist() for o in fwdbwd.loss_fn_outputs])
    weights = np.concatenate([e.loss_fn_inputs['weights'].tolist() for e in processed])
    print(f"Loss: {-np.dot(logprobs, weights) / weights.sum():.4f}")
```
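The off-by-one shift in `process_example` is what turns this into next-token prediction: position *i* of the model input is trained to predict token *i+1*, so the inputs are `tokens[:-1]`, the targets are `tokens[1:]`, and the shifted weights mask the loss to completion tokens only. A standalone sketch of that alignment (plain Python, made-up token IDs, no tokenizer needed):

```python
# Illustrative token IDs: 3 prompt tokens followed by 2 completion tokens
prompt_tokens = [101, 7, 8]
completion_tokens = [42, 43]

tokens = prompt_tokens + completion_tokens
weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)

inputs = tokens[:-1]           # model sees positions 0..n-2
targets = tokens[1:]           # position i is trained to predict tokens[i+1]
target_weights = weights[1:]   # loss applies only where weight == 1

assert len(inputs) == len(targets) == len(target_weights)
# Only the completion tokens contribute to the loss:
supervised = [t for t, w in zip(targets, target_weights) if w == 1]
print(supervised)  # [42, 43]
```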
Sampling
```python
# After training, create a sampling client
sampling_client = training_client.save_weights_and_get_sampling_client(name='my-model')

# Sample from the model
prompt = types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:"))
params = types.SamplingParams(max_tokens=20, temperature=0.0)
result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8).result()

for i, seq in enumerate(result.sequences):
    print(f"{i}: {tokenizer.decode(seq.tokens)}")
```
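With `temperature=0.0`, decoding is deterministic greedy argmax, so all eight samples above should come out identical; higher temperatures flatten the token distribution and diversify samples. A self-contained sketch of temperature scaling (NumPy only, toy logits; an illustration of the concept, not the engine's actual sampler):

```python
import numpy as np

def sample_token(logits, temperature, rng=None):
    """Pick a token id from logits; temperature 0 means greedy argmax."""
    logits = np.asarray(logits, dtype=np.float64)
    if temperature == 0.0:
        return int(np.argmax(logits))
    # Softmax over temperature-scaled logits (subtract max for stability)
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    rng = rng or np.random.default_rng(0)
    return int(rng.choice(len(probs), p=probs))

logits = [1.0, 3.5, 0.2]
print(sample_token(logits, 0.0))  # 1 (always the highest-logit token)
```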
📚 Usage Examples
Dense Model Training (Qwen3-8B on 8×H100)
```bash
# Start the server
uv run --extra gpu --extra tinker -m tx.tinker.api \
    --base-model Qwen/Qwen3-8B \
    --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "tensor_parallel_size": 8, "train_micro_batch_size": 1}'

# Run training (using tinker-cookbook)
export TINKER_API_KEY="tml-dummy"
uv run --with wandb --with tinker sl_loop.py \
    base_url=http://localhost:8000 \
    model_name=Qwen/Qwen3-8B lora_rank=1
```
MoE Model Training (Qwen/Qwen3-30B-A3B)
```bash
# Start the server
uv run --extra gpu --extra tinker -m tx.tinker.api \
    --base-model Qwen/Qwen3-30B-A3B \
    --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "tensor_parallel_size": 8, "train_micro_batch_size": 1, "shard_attention_heads": false}'

# Run training (using tinker-cookbook)
export TINKER_API_KEY="tml-dummy"
uv run --with wandb --with tinker sl_loop.py \
    base_url=http://localhost:8000 \
    model_name=Qwen/Qwen3-30B-A3B lora_rank=1 max_length=512
```
Reinforcement Learning (Qwen/Qwen3-8B)
```bash
# Start server
uv run --extra gpu --extra tinker -m tx.tinker.api \
    --base-model Qwen/Qwen3-8B \
    --backend-config '{"max_lora_adapters": 3, "max_lora_rank": 1, "tensor_parallel_size": 8, "train_micro_batch_size": 8, "sample_max_num_sequences": 256}' > out.log

# Run RL loop
uv run --with wandb --with tinker rl_loop.py \
    base_url=http://localhost:8000 \
    model_name="Qwen/Qwen3-8B" \
    lora_rank=1 max_length=1024
```
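The RL loop above trains with PPO. For orientation, the standard PPO clipped surrogate objective that such loops minimize can be sketched in a few lines of NumPy (this is the textbook formula, not SkyRL tx's internal implementation):

```python
import numpy as np

def ppo_clip_loss(logprobs, old_logprobs, advantages, clip_eps=0.2):
    """Standard PPO clipped surrogate loss (to be minimized)."""
    ratio = np.exp(np.asarray(logprobs) - np.asarray(old_logprobs))
    advantages = np.asarray(advantages)
    unclipped = ratio * advantages
    # Clipping caps how far the ratio can move the objective per step
    clipped = np.clip(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    return float(-np.mean(np.minimum(unclipped, clipped)))

# If the new policy matches the old one, ratio == 1 and loss == -mean(advantages):
print(ppo_clip_loss([-1.0, -2.0], [-1.0, -2.0], [1.0, 0.5]))  # -0.75
```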
Multi-Node Training
```bash
# Node 0 (coordinator + API server)
CUDA_VISIBLE_DEVICES=0,1,2,3 uv run --extra gpu --extra tinker -m tx.tinker.api \
    --base-model Qwen/Qwen3-8B \
    --backend-config '{
        "max_lora_adapters": 3,
        "max_lora_rank": 1,
        "tensor_parallel_size": 4,
        "fully_sharded_data_parallel_size": 2,
        "train_micro_batch_size": 8,
        "sample_max_num_sequences": 256,
        "coordinator_address": "node0:7777",
        "num_processes": 2
    }' > out.log

# Node 1 (worker)
CUDA_VISIBLE_DEVICES=4,5,6,7 uv run --extra gpu --extra tinker -m tx.tinker.backends.jax \
    --coordinator-address "node0:7777" \
    --num-processes 2 \
    --process-id 1
```
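The sharding sizes have to multiply out to the available hardware: `tensor_parallel_size` × `fully_sharded_data_parallel_size` should equal the total GPU count, spread evenly over `num_processes` (here 4 × 2 = 8 GPUs as 2 processes with 4 visible GPUs each). A quick sanity-check sketch (plain Python; the helper name is ours, not part of the API):

```python
def check_topology(tensor_parallel_size, fsdp_size, num_processes, gpus_per_process):
    """Hypothetical sanity check: the device mesh must cover exactly the visible GPUs."""
    total_gpus = num_processes * gpus_per_process
    mesh_devices = tensor_parallel_size * fsdp_size
    assert mesh_devices == total_gpus, f"mesh needs {mesh_devices} GPUs, have {total_gpus}"
    return total_gpus

# Matches the example above: 2 nodes x 4 GPUs, TP=4, FSDP=2
print(check_topology(tensor_parallel_size=4, fsdp_size=2, num_processes=2, gpus_per_process=4))  # 8
```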
With External vLLM Inference
```bash
# Start vLLM
VLLM_ALLOW_RUNTIME_LORA_UPDATING=True \
VLLM_PLUGINS=lora_filesystem_resolver \
VLLM_LORA_RESOLVER_CACHE_DIR=/tmp/lora_models/ \
CUDA_VISIBLE_DEVICES=4,5,6,7 uv run --with vllm vllm serve Qwen/Qwen3-4B \
    --tensor-parallel-size 4 --port 7999 --enable-lora

# Start SkyRL tx with external inference
CUDA_VISIBLE_DEVICES=0,1,2,3 uv run --extra gpu --extra tinker -m tx.tinker.api \
    --base-model Qwen/Qwen3-4B \
    --external-inference-url "http://0.0.0.0:7999" \
    --backend-config '{"max_lora_adapters": 3, "max_lora_rank": 1, "tensor_parallel_size": 4, "train_micro_batch_size": 8}' > out.log
```
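With the `lora_filesystem_resolver` plugin, adapters placed in `VLLM_LORA_RESOLVER_CACHE_DIR` become addressable by name through vLLM's OpenAI-compatible API. A hedged sketch of what a completion request against that server looks like (the adapter name `my-model` is a placeholder; only the payload is built here, and the actual POST is left commented out):

```python
import json
from urllib import request

# OpenAI-compatible completions payload; "my-model" is a placeholder adapter name
payload = {
    "model": "my-model",  # LoRA adapter resolved from the cache dir
    "prompt": "English: coffee break\nPig Latin:",
    "max_tokens": 20,
    "temperature": 0.0,
}
body = json.dumps(payload).encode()
print(sorted(payload))  # ['max_tokens', 'model', 'prompt', 'temperature']

# Uncomment to send against the vLLM server from the example above:
# req = request.Request("http://0.0.0.0:7999/v1/completions", data=body,
#                       headers={"Content-Type": "application/json"})
# print(request.urlopen(req).read())
```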
🎯 Supported Features
| Feature | Status |
|---|---|
| Qwen3 Dense Models | ✅ |
| Qwen3 MoE Models | ✅ |
| Llama 3 Models | ✅ |
| DeepSeek V3 Models | ✅ |
| Multi-User LoRA | ✅ |
| LoRA (all layers) | ✅ |
| Forward/Backward | ✅ |
| Sampling | ✅ |
| Gradient Accumulation | ✅ |
| Gradient Checkpointing | ✅ |
| JIT Compilation | ✅ |
| Tensor Parallelism | ✅ |
| FSDP | ✅ |
| Multi-Node | ✅ |
| PostgreSQL | ✅ |
| Cloud Storage Checkpoints | ✅ |
| Custom Loss Functions | ✅ |
| External Inference (vLLM) | ✅ |
| Local Model Loading | ✅ |
🗺️ Roadmap
- Performance – Expert parallelism, context parallelism, optimized kernels
- Models – More architectures, PyTorch model definitions via torchax
- API Coverage – Full Tinker API compatibility
- Operations – Dashboard/frontend, improved logging and metrics
- Integration – SkyRL-train Tinkerification
🤝 Contributing
We welcome contributions! The project is early and hackable – now is a great time to get involved.
Ways to contribute:
- Try examples from the Tinker documentation or cookbook
- Fix issues or implement features from our issue tracker
- Improve documentation
- Add support for more models
- Performance optimizations
🔗 Resources
- Ray Summit Talk – SkyRL tx: A unified training and inference engine
- Slides – Presentation slides
- Tinker Documentation – Official Tinker API docs
- Tinker Cookbook – Example recipes
📝 Blog Posts
- Introducing SkyRL tx
- SkyRL tx v0.0.2
- SkyRL tx v0.0.3
- SkyRL tx v0.1.0
- SkyRL tx v0.2.0
- SkyRL tx v0.2.1
💬 Contact
- Slack: #skyrl-tx
- GitHub: NovaSky-AI/SkyRL/skyrl-tx
- Twitter/X: @NovaSkyAI
📄 License
See LICENSE for details.