
Unified API for training and inference


🧸 SkyRL tx: Unifying LLM training and inference

GitHub • Tinker Docs • Tinker Cookbook • Slack

SkyRL tx is an open-source library that implements a backend for the Tinker API, allowing you to set up your own Tinker-like service running on your own hardware. It provides a unified interface for both training and inference, enabling seamless online learning, cost-effective multi-tenancy through LoRA, and simplified ML infrastructure.

✨ Key Features

  • Unified Training & Inference – Single engine for forward passes, backward passes, and sampling
  • Multi-User LoRA Support – Efficient GPU sharing across users with individual adapters
  • SFT & RL Support – Supervised fine-tuning and reinforcement learning with PPO and custom loss functions
  • Multi-Node Training – FSDP and tensor parallelism for distributed training
  • Multiple Model Architectures – Support for Qwen3 (dense & MoE), Llama 3, and DeepSeek V3
  • External Inference Engine – Optional vLLM integration for optimized inference
  • Production Ready – PostgreSQL support, cloud storage checkpoints, and database migrations
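The multi-tenancy story rests on LoRA's structure: every user shares a single copy of the base weights, and each adapter is only a small low-rank delta applied on top. A toy NumPy sketch of the idea (illustrative only, not SkyRL tx's internals):

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 8, 2                      # hidden size, LoRA rank

W = rng.normal(size=(d, d))      # shared base weight (one copy for all users)
adapters = {                     # per-user low-rank factors: only d*r + r*d params each
    "alice": (rng.normal(size=(d, r)), rng.normal(size=(r, d))),
    "bob":   (rng.normal(size=(d, r)), rng.normal(size=(r, d))),
}

def apply(x, user, alpha=16):
    """Apply the effective weight W' = W + (alpha/r) * B @ A for one user."""
    B, A = adapters[user]
    return x @ (W + (alpha / r) * (B @ A)).T

x = rng.normal(size=(1, d))
print(apply(x, "alice").shape)   # (1, 8)
```

Because the adapters are tiny relative to the base model, many of them fit on one GPU, which is what makes per-user fine-tuning on shared hardware economical.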

๐Ÿ—๏ธ Architecture

SkyRL tx consists of four main components:

┌─────────────────────────────────────────────────────────────────┐
│                        REST API Server                          │
│                    (FastAPI - handles requests)                 │
└─────────────────────────────────────────────────────────────────┘
                                 │
                                 ▼
┌─────────────────────────────────────────────────────────────────┐
│                           Database                              │
│         (SQLite/PostgreSQL - metadata, job queue)               │
└─────────────────────────────────────────────────────────────────┘
                                 │
                                 ▼
┌─────────────────────────────────────────────────────────────────┐
│                            Engine                               │
│        (Scheduling & batching across users/adapters)            │
└─────────────────────────────────────────────────────────────────┘
                                 │
                                 ▼
┌─────────────────────────────────────────────────────────────────┐
│                            Worker                               │
│       (Model execution, forward/backward, optimizer)            │
└─────────────────────────────────────────────────────────────────┘
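One way to picture the Engine's role is as a batcher that packs pending requests from different users, each tagged with its own LoRA adapter, into a single batch so that one forward pass on the Worker serves all of them. A simplified, hypothetical sketch (the `Request` type and `build_batch` function here are invented for illustration):

```python
from dataclasses import dataclass

@dataclass
class Request:
    user: str
    adapter_id: int   # which LoRA adapter this user's request uses
    tokens: list

def build_batch(queue, max_batch_size):
    """Greedily pack pending requests from multiple users/adapters
    into one batch; the Worker selects the adapter per batch row."""
    batch = queue[:max_batch_size]
    inputs = [r.tokens for r in batch]
    adapter_ids = [r.adapter_id for r in batch]
    return inputs, adapter_ids

queue = [
    Request("alice", 0, [1, 2, 3]),
    Request("bob",   1, [4, 5]),
    Request("carol", 0, [6]),
]
inputs, adapter_ids = build_batch(queue, max_batch_size=8)
print(adapter_ids)  # [0, 1, 0]
```

The payoff of this design is that training and sampling requests for many adapters share the same base-model compute instead of each user paying for a dedicated replica.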

🚀 Quick Start

Installation

git clone https://github.com/NovaSky-AI/SkyRL
cd SkyRL/skyrl-tx

# Start the API server (GPU)
uv run --extra gpu --extra tinker -m tx.tinker.api --base-model <model>

# Start the API server (TPU)
uv run --extra tpu --extra tinker -m tx.tinker.api --base-model <model>

Basic Training Example (Pig Latin)

Start the server:

uv run --extra gpu --extra tinker -m tx.tinker.api --base-model "Qwen/Qwen3-0.6B"

Run a simple training loop:

import tinker
import numpy as np
from tinker import types

# Connect to the local server
service_client = tinker.ServiceClient(base_url="http://localhost:8000", api_key="tml-dummy")
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-0.6B")
tokenizer = training_client.get_tokenizer()

# Training examples
examples = [
    {"input": "banana split", "output": "anana-bay plit-say"},
    {"input": "quantum physics", "output": "uantum-qay ysics-phay"},
    {"input": "coding wizard", "output": "oding-cay izard-way"},
]

def process_example(example, tokenizer):
    prompt = f"English: {example['input']}\nPig Latin:"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)

    tokens = prompt_tokens + completion_tokens
    weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=tokens[:-1]),
        loss_fn_inputs=dict(weights=weights[1:], target_tokens=tokens[1:])
    )

processed = [process_example(ex, tokenizer) for ex in examples]

# Training loop
for _ in range(6):
    fwdbwd = training_client.forward_backward(processed, "cross_entropy").result()
    training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()

    logprobs = np.concatenate([o['logprobs'].tolist() for o in fwdbwd.loss_fn_outputs])
    weights = np.concatenate([e.loss_fn_inputs['weights'].tolist() for e in processed])
    print(f"Loss: {-np.dot(logprobs, weights) / weights.sum():.4f}")
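The masking in `process_example` is standard next-token prediction: the model input is `tokens[:-1]`, the targets are `tokens[1:]`, and the shifted weights zero out prompt positions so only completion tokens contribute to the loss. A self-contained illustration with toy token IDs:

```python
# Toy sequence: 3 prompt tokens followed by 2 completion tokens.
prompt_tokens = [101, 7, 8]
completion_tokens = [9, 10]

tokens = prompt_tokens + completion_tokens
weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)

model_input = tokens[:-1]    # what the model sees at each position
target_tokens = tokens[1:]   # what it should predict next
loss_weights = weights[1:]   # mask: only completion targets count

print(model_input)    # [101, 7, 8, 9]
print(target_tokens)  # [7, 8, 9, 10]
print(loss_weights)   # [0, 0, 1, 1]
```

The printed loss in the training loop is the same masked average: the negative dot product of per-token logprobs and weights, divided by the weight sum.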

Sampling

# After training, create a sampling client
sampling_client = training_client.save_weights_and_get_sampling_client(name='my-model')

# Sample from the model
prompt = types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:"))
params = types.SamplingParams(max_tokens=20, temperature=0.0)
result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8).result()

for i, seq in enumerate(result.sequences):
    print(f"{i}: {tokenizer.decode(seq.tokens)}")

📖 Usage Examples

Dense Model Training (Qwen3-8B on 8×H100)

# Start the server
uv run --extra gpu --extra tinker -m tx.tinker.api \
    --base-model Qwen/Qwen3-8B \
    --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "tensor_parallel_size": 8, "train_micro_batch_size": 1}'

# Run training (using tinker-cookbook)
export TINKER_API_KEY="tml-dummy"
uv run --with wandb --with tinker sl_loop.py \
    base_url=http://localhost:8000 \
    model_name=Qwen/Qwen3-8B lora_rank=1

MoE Model Training (Qwen/Qwen3-30B-A3B)

# Start the server
uv run --extra gpu --extra tinker -m tx.tinker.api \
    --base-model Qwen/Qwen3-30B-A3B \
    --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "tensor_parallel_size": 8, "train_micro_batch_size": 1, "shard_attention_heads": false}'

# Run training (using tinker-cookbook)
export TINKER_API_KEY="tml-dummy"
uv run --with wandb --with tinker sl_loop.py \
    base_url=http://localhost:8000 \
    model_name=Qwen/Qwen3-30B-A3B lora_rank=1 max_length=512

Reinforcement Learning (Qwen/Qwen3-8B)

# Start server
uv run --extra gpu --extra tinker -m tx.tinker.api \
    --base-model Qwen/Qwen3-8B \
    --backend-config '{"max_lora_adapters": 3, "max_lora_rank": 1, "tensor_parallel_size": 8, "train_micro_batch_size": 8, "sample_max_num_sequences": 256}' > out.log

# Run RL loop
uv run --with wandb --with tinker rl_loop.py \
    base_url=http://localhost:8000 \
    model_name="Qwen/Qwen3-8B" \
    lora_rank=1 max_length=1024
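As background for the RL loop, the PPO objective mentioned under Key Features can be written in a few lines of NumPy. This is the generic clipped surrogate loss, shown for illustration only – not the actual loss code used by SkyRL tx or `rl_loop.py`:

```python
import numpy as np

def ppo_clip_loss(logprobs_new, logprobs_old, advantages, clip_eps=0.2):
    """Standard PPO clipped surrogate loss (to be minimized)."""
    ratio = np.exp(logprobs_new - logprobs_old)          # importance ratio
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    return -np.mean(np.minimum(unclipped, clipped))      # pessimistic bound

# Toy example: one token with positive and one with negative advantage.
loss = ppo_clip_loss(
    logprobs_new=np.array([-0.5, -1.0]),
    logprobs_old=np.array([-0.7, -0.9]),
    advantages=np.array([1.0, -1.0]),
)
print(round(float(loss), 4))  # -0.1476
```

Custom loss functions (also listed as a supported feature) would slot in at the same place in the training step that the cross-entropy loss occupies in the SFT example above.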

Multi-Node Training

# Node 0 (coordinator + API server)
CUDA_VISIBLE_DEVICES=0,1,2,3 uv run --extra gpu --extra tinker -m tx.tinker.api \
    --base-model Qwen/Qwen3-8B \
    --backend-config '{
        "max_lora_adapters": 3,
        "max_lora_rank": 1,
        "tensor_parallel_size": 4,
        "fully_sharded_data_parallel_size": 2,
        "train_micro_batch_size": 8,
        "sample_max_num_sequences": 256,
        "coordinator_address": "node0:7777",
        "num_processes": 2
    }' > out.log

# Node 1 (worker)
CUDA_VISIBLE_DEVICES=4,5,6,7 uv run --extra gpu --extra tinker -m tx.tinker.backends.jax \
    --coordinator-address "node0:7777" \
    --num-processes 2 \
    --process-id 1

With External vLLM Inference

# Start vLLM
VLLM_ALLOW_RUNTIME_LORA_UPDATING=True \
VLLM_PLUGINS=lora_filesystem_resolver \
VLLM_LORA_RESOLVER_CACHE_DIR=/tmp/lora_models/ \
CUDA_VISIBLE_DEVICES=4,5,6,7 uv run --with vllm vllm serve Qwen/Qwen3-4B \
    --tensor-parallel-size 4 --port 7999 --enable-lora

# Start SkyRL tx with external inference
CUDA_VISIBLE_DEVICES=0,1,2,3 uv run --extra gpu --extra tinker -m tx.tinker.api \
    --base-model Qwen/Qwen3-4B \
    --external-inference-url "http://0.0.0.0:7999" \
    --backend-config '{"max_lora_adapters": 3, "max_lora_rank": 1, "tensor_parallel_size": 4, "train_micro_batch_size": 8}' > out.log

🎯 Supported Features

Feature                        Status
Qwen3 Dense Models             ✅
Qwen3 MoE Models               ✅
Llama 3 Models                 ✅
DeepSeek V3 Models             ✅
Multi-User LoRA                ✅
LoRA (all layers)              ✅
Forward/Backward               ✅
Sampling                       ✅
Gradient Accumulation          ✅
Gradient Checkpointing         ✅
JIT Compilation                ✅
Tensor Parallelism             ✅
FSDP                           ✅
Multi-Node                     ✅
PostgreSQL                     ✅
Cloud Storage Checkpoints      ✅
Custom Loss Functions          ✅
External Inference (vLLM)      ✅
Local Model Loading            ✅
๐Ÿ—บ๏ธ Roadmap

  • Performance – Expert parallelism, context parallelism, optimized kernels
  • Models – More architectures, PyTorch model definitions via torchax
  • API Coverage – Full Tinker API compatibility
  • Operations – Dashboard/frontend, improved logging and metrics
  • Integration – SkyRL-train Tinkerification

๐Ÿค Contributing

We welcome contributions! The project is early and hackable โ€” now is a great time to get involved.

Ways to contribute:

📚 Resources

📝 Blog Posts

📬 Contact

📄 License

See LICENSE for details.
